spark-sql-engine: enable session configuration
This commit is contained in:
parent
07c9d9578d
commit
d8a27c914f
@ -92,10 +92,7 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
|
||||
warn(s"Ignore exception in terminal state with $statementId: $e")
|
||||
} else {
|
||||
setState(OperationState.ERROR)
|
||||
e match {
|
||||
case k: KyuubiSQLException => throw k
|
||||
case _ => throw KyuubiSQLException(s"Error operating $opType: ${e.getMessage}", e)
|
||||
}
|
||||
throw KyuubiSQLException(s"Error operating $opType: ${e.getMessage}", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.kyuubi.engine.spark.session
|
||||
|
||||
import scala.util.control.NonFatal
|
||||
import scala.util.matching.Regex
|
||||
|
||||
import org.apache.hive.service.rpc.thrift.TProtocolVersion
|
||||
import org.apache.spark.sql.SparkSession
|
||||
@ -37,6 +38,8 @@ import org.apache.kyuubi.session.{SessionHandle, SessionManager}
|
||||
class SparkSQLSessionManager private (name: String, spark: SparkSession)
|
||||
extends SessionManager(name) {
|
||||
|
||||
import SparkSQLSessionManager._
|
||||
|
||||
def this(spark: SparkSession) = this(classOf[SparkSQLSessionManager].getSimpleName, spark)
|
||||
|
||||
val operationManager = new SparkSQLOperationManager()
|
||||
@ -51,12 +54,17 @@ class SparkSQLSessionManager private (name: String, spark: SparkSession)
|
||||
val handle = sessionImpl.handle
|
||||
try {
|
||||
val sparkSession = spark.newSession()
|
||||
conf.foreach { case (key, value) => spark.conf.set(key, value)}
|
||||
operationManager.setSparkSession(handle, sparkSession)
|
||||
conf.foreach {
|
||||
case (HIVE_VAR_PREFIX(key), value) => sparkSession.conf.set(key, value)
|
||||
case (HIVE_CONF_PREFIX(key), value) => sparkSession.conf.set(key, value)
|
||||
case ("use:database", database) => sparkSession.catalog.setCurrentDatabase(database)
|
||||
case (key, value) => sparkSession.conf.set(key, value)
|
||||
}
|
||||
sessionImpl.open()
|
||||
info(s"$user's session with $handle is opened, current opening sessions" +
|
||||
s" $getOpenSessionCount")
|
||||
operationManager.setSparkSession(handle, sparkSession)
|
||||
setSession(handle, sessionImpl)
|
||||
info(s"$user's session with $handle is opened, current opening sessions" +
|
||||
s" $getOpenSessionCount")
|
||||
handle
|
||||
} catch {
|
||||
case NonFatal(e) =>
|
||||
@ -65,7 +73,12 @@ class SparkSQLSessionManager private (name: String, spark: SparkSession)
|
||||
} catch {
|
||||
case t: Throwable => warn(s"Error closing session $handle for $user", t)
|
||||
}
|
||||
throw KyuubiSQLException(s"Error opening session $handle for $user", e)
|
||||
throw KyuubiSQLException(s"Error opening session $handle for $user: ${e.getMessage}", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object SparkSQLSessionManager {
|
||||
val HIVE_VAR_PREFIX: Regex = """set:hivevar:([^=]+)""".r
|
||||
val HIVE_CONF_PREFIX: Regex = """set:hiveconf:([^=]+)""".r
|
||||
}
|
||||
|
||||
@ -19,8 +19,10 @@ package org.apache.kyuubi.engine.spark.operation
|
||||
|
||||
import java.sql.{Date, SQLException, Timestamp}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.hive.service.cli.HiveSQLException
|
||||
import org.apache.hive.service.rpc.thrift.TOpenSessionReq
|
||||
import org.apache.hive.service.rpc.thrift.{TCloseSessionReq, TExecuteStatementReq, TFetchResultsReq, TOpenSessionReq, TStatus, TStatusCode}
|
||||
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
|
||||
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
|
||||
import org.apache.spark.sql.types._
|
||||
@ -698,17 +700,104 @@ class SparkOperationSuite extends WithSparkSQLEngine {
|
||||
}
|
||||
}
|
||||
|
||||
test("get functions operation") {
|
||||
test("basic open | execute | close") {
|
||||
withThriftClient { client =>
|
||||
val req = new TOpenSessionReq()
|
||||
req.setUsername("kentyao")
|
||||
req.setPassword("anonymous")
|
||||
// req.setClient_protocol(BackendService.SERVER_VERSION)
|
||||
req
|
||||
val resp = client.OpenSession(req)
|
||||
val sessionHandle = resp.getSessionHandle
|
||||
}
|
||||
val tOpenSessionResp = client.OpenSession(req)
|
||||
|
||||
val tExecuteStatementReq = new TExecuteStatementReq()
|
||||
tExecuteStatementReq.setSessionHandle( tOpenSessionResp.getSessionHandle)
|
||||
tExecuteStatementReq.setStatement("set -v")
|
||||
val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq)
|
||||
|
||||
val tFetchResultsReq = new TFetchResultsReq()
|
||||
tFetchResultsReq.setOperationHandle(tExecuteStatementResp.getOperationHandle)
|
||||
tFetchResultsReq.setFetchType(1)
|
||||
tFetchResultsReq.setMaxRows(1000)
|
||||
val tFetchResultsResp = client.FetchResults(tFetchResultsReq)
|
||||
val logs = tFetchResultsResp.getResults.getColumns.get(0).getStringVal.getValues.asScala
|
||||
assert(logs.exists(_.contains(classOf[ExecuteStatement].getCanonicalName)))
|
||||
|
||||
tFetchResultsReq.setFetchType(0)
|
||||
val tFetchResultsResp1 = client.FetchResults(tFetchResultsReq)
|
||||
val rs = tFetchResultsResp1.getResults.getColumns.get(0).getStringVal.getValues.asScala
|
||||
assert(rs.contains("spark.sql.shuffle.partitions"))
|
||||
|
||||
val tCloseSessionReq = new TCloseSessionReq()
|
||||
tCloseSessionReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
|
||||
val tCloseSessionResp = client.CloseSession(tCloseSessionReq)
|
||||
assert(tCloseSessionResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
|
||||
}
|
||||
}
|
||||
|
||||
test("set session conf") {
|
||||
withThriftClient { client =>
|
||||
val req = new TOpenSessionReq()
|
||||
req.setUsername("kentyao")
|
||||
req.setPassword("anonymous")
|
||||
val conf = Map(
|
||||
"use:database" -> "default",
|
||||
"spark.sql.shuffle.partitions" -> "4",
|
||||
"set:hiveconf:spark.sql.autoBroadcastJoinThreshold" -> "-1",
|
||||
"set:hivevar:spark.sql.adaptive.enabled" -> "true")
|
||||
req.setConfiguration(conf.asJava)
|
||||
val tOpenSessionResp = client.OpenSession(req)
|
||||
|
||||
val tExecuteStatementReq = new TExecuteStatementReq()
|
||||
tExecuteStatementReq.setSessionHandle( tOpenSessionResp.getSessionHandle)
|
||||
tExecuteStatementReq.setStatement("set")
|
||||
val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq)
|
||||
|
||||
val tFetchResultsReq = new TFetchResultsReq()
|
||||
tFetchResultsReq.setOperationHandle(tExecuteStatementResp.getOperationHandle)
|
||||
tFetchResultsReq.setFetchType(0)
|
||||
tFetchResultsReq.setMaxRows(1000)
|
||||
val tFetchResultsResp1 = client.FetchResults(tFetchResultsReq)
|
||||
val columns = tFetchResultsResp1.getResults.getColumns
|
||||
val rs = columns.get(0).getStringVal.getValues.asScala.zip(
|
||||
columns.get(1).getStringVal.getValues.asScala)
|
||||
rs foreach {
|
||||
case ("spark.sql.shuffle.partitions", v) => assert(v === "4")
|
||||
case ("spark.sql.autoBroadcastJoinThreshold", v) => assert(v === "-1")
|
||||
case ("spark.sql.adaptive.enabled", v) => assert(v.toBoolean)
|
||||
case _ =>
|
||||
}
|
||||
assert(spark.conf.get("spark.sql.shuffle.partitions") === "200")
|
||||
|
||||
val tCloseSessionReq = new TCloseSessionReq()
|
||||
tCloseSessionReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
|
||||
val tCloseSessionResp = client.CloseSession(tCloseSessionReq)
|
||||
assert(tCloseSessionResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
|
||||
}
|
||||
}
|
||||
|
||||
test("set session conf - static") {
|
||||
withThriftClient { client =>
|
||||
val req = new TOpenSessionReq()
|
||||
req.setUsername("kentyao")
|
||||
req.setPassword("anonymous")
|
||||
val conf = Map("use:database" -> "default", "spark.sql.globalTempDatabase" -> "temp")
|
||||
req.setConfiguration(conf.asJava)
|
||||
val tOpenSessionResp = client.OpenSession(req)
|
||||
val status = tOpenSessionResp.getStatus
|
||||
assert(status.getStatusCode === TStatusCode.ERROR_STATUS)
|
||||
assert(status.getErrorMessage.contains("spark.sql.globalTempDatabase"))
|
||||
}
|
||||
}
|
||||
|
||||
test("set session conf - wrong database") {
|
||||
withThriftClient { client =>
|
||||
val req = new TOpenSessionReq()
|
||||
req.setUsername("kentyao")
|
||||
req.setPassword("anonymous")
|
||||
val conf = Map("use:database" -> "default2")
|
||||
req.setConfiguration(conf.asJava)
|
||||
val tOpenSessionResp = client.OpenSession(req)
|
||||
val status = tOpenSessionResp.getStatus
|
||||
assert(status.getStatusCode === TStatusCode.ERROR_STATUS)
|
||||
assert(status.getErrorMessage.contains("Database 'default2' does not exist"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user