spark-sql-engine: enable session configuration

This commit is contained in:
Kent Yao 2020-09-14 18:05:08 +08:00
parent 07c9d9578d
commit d8a27c914f
3 changed files with 115 additions and 16 deletions

View File

@ -92,10 +92,7 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
warn(s"Ignore exception in terminal state with $statementId: $e")
} else {
setState(OperationState.ERROR)
e match {
case k: KyuubiSQLException => throw k
case _ => throw KyuubiSQLException(s"Error operating $opType: ${e.getMessage}", e)
}
throw KyuubiSQLException(s"Error operating $opType: ${e.getMessage}", e)
}
}
}

View File

@ -18,6 +18,7 @@
package org.apache.kyuubi.engine.spark.session
import scala.util.control.NonFatal
import scala.util.matching.Regex
import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.spark.sql.SparkSession
@ -37,6 +38,8 @@ import org.apache.kyuubi.session.{SessionHandle, SessionManager}
class SparkSQLSessionManager private (name: String, spark: SparkSession)
extends SessionManager(name) {
import SparkSQLSessionManager._
def this(spark: SparkSession) = this(classOf[SparkSQLSessionManager].getSimpleName, spark)
val operationManager = new SparkSQLOperationManager()
@ -51,12 +54,17 @@ class SparkSQLSessionManager private (name: String, spark: SparkSession)
val handle = sessionImpl.handle
try {
val sparkSession = spark.newSession()
conf.foreach { case (key, value) => spark.conf.set(key, value)}
operationManager.setSparkSession(handle, sparkSession)
conf.foreach {
case (HIVE_VAR_PREFIX(key), value) => sparkSession.conf.set(key, value)
case (HIVE_CONF_PREFIX(key), value) => sparkSession.conf.set(key, value)
case ("use:database", database) => sparkSession.catalog.setCurrentDatabase(database)
case (key, value) => sparkSession.conf.set(key, value)
}
sessionImpl.open()
info(s"$user's session with $handle is opened, current opening sessions" +
s" $getOpenSessionCount")
operationManager.setSparkSession(handle, sparkSession)
setSession(handle, sessionImpl)
info(s"$user's session with $handle is opened, current opening sessions" +
s" $getOpenSessionCount")
handle
} catch {
case NonFatal(e) =>
@ -65,7 +73,12 @@ class SparkSQLSessionManager private (name: String, spark: SparkSession)
} catch {
case t: Throwable => warn(s"Error closing session $handle for $user", t)
}
throw KyuubiSQLException(s"Error opening session $handle for $user", e)
throw KyuubiSQLException(s"Error opening session $handle for $user: ${e.getMessage}", e)
}
}
}
/**
 * Companion object holding the client-supplied configuration key patterns
 * recognized when a session is opened.
 */
object SparkSQLSessionManager {
  /** Matches `set:hivevar:<name>` keys, capturing `<name>` (no `=` allowed). */
  val HIVE_VAR_PREFIX: Regex = """set:hivevar:([^=]+)""".r

  /** Matches `set:hiveconf:<name>` keys, capturing `<name>` (no `=` allowed). */
  val HIVE_CONF_PREFIX: Regex = """set:hiveconf:([^=]+)""".r
}

View File

@ -19,8 +19,10 @@ package org.apache.kyuubi.engine.spark.operation
import java.sql.{Date, SQLException, Timestamp}
import scala.collection.JavaConverters._
import org.apache.hive.service.cli.HiveSQLException
import org.apache.hive.service.rpc.thrift.TOpenSessionReq
import org.apache.hive.service.rpc.thrift.{TCloseSessionReq, TExecuteStatementReq, TFetchResultsReq, TOpenSessionReq, TStatus, TStatusCode}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.types._
@ -698,17 +700,104 @@ class SparkOperationSuite extends WithSparkSQLEngine {
}
}
test("get functions operation") {
test("basic open | execute | close") {
withThriftClient { client =>
// Open a plain session for user "kentyao" with no session configuration.
val req = new TOpenSessionReq()
req.setUsername("kentyao")
req.setPassword("anonymous")
// req.setClient_protocol(BackendService.SERVER_VERSION)
req
// NOTE(review): this diff rendering interleaves removed and added lines; the
// bare `req` above, the resp/sessionHandle pair, and the following `}` look
// like remnants of the replaced test body (braces do not balance as shown) —
// confirm against the applied source file.
val resp = client.OpenSession(req)
val sessionHandle = resp.getSessionHandle
}
val tOpenSessionResp = client.OpenSession(req)
// Execute `set -v` in the newly opened session.
val tExecuteStatementReq = new TExecuteStatementReq()
tExecuteStatementReq.setSessionHandle( tOpenSessionResp.getSessionHandle)
tExecuteStatementReq.setStatement("set -v")
val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq)
// First fetch the operation log (fetchType 1) and check it mentions the
// ExecuteStatement operation class.
val tFetchResultsReq = new TFetchResultsReq()
tFetchResultsReq.setOperationHandle(tExecuteStatementResp.getOperationHandle)
tFetchResultsReq.setFetchType(1)
tFetchResultsReq.setMaxRows(1000)
val tFetchResultsResp = client.FetchResults(tFetchResultsReq)
val logs = tFetchResultsResp.getResults.getColumns.get(0).getStringVal.getValues.asScala
assert(logs.exists(_.contains(classOf[ExecuteStatement].getCanonicalName)))
// Then fetch the actual result rows (fetchType 0) and check a well-known
// Spark SQL config key is listed by `set -v`.
tFetchResultsReq.setFetchType(0)
val tFetchResultsResp1 = client.FetchResults(tFetchResultsReq)
val rs = tFetchResultsResp1.getResults.getColumns.get(0).getStringVal.getValues.asScala
assert(rs.contains("spark.sql.shuffle.partitions"))
// Close the session and expect a success status.
val tCloseSessionReq = new TCloseSessionReq()
tCloseSessionReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
val tCloseSessionResp = client.CloseSession(tCloseSessionReq)
assert(tCloseSessionResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
}
}
test("set session conf") {
  withThriftClient { client =>
    // Open a session carrying hivevar-/hiveconf-prefixed keys, a plain Spark
    // SQL key, and a `use:database` directive.
    val req = new TOpenSessionReq()
    req.setUsername("kentyao")
    req.setPassword("anonymous")
    val conf = Map(
      "use:database" -> "default",
      "spark.sql.shuffle.partitions" -> "4",
      "set:hiveconf:spark.sql.autoBroadcastJoinThreshold" -> "-1",
      "set:hivevar:spark.sql.adaptive.enabled" -> "true")
    req.setConfiguration(conf.asJava)
    val tOpenSessionResp = client.OpenSession(req)

    // Run `set` in the new session and fetch its key/value result set.
    val tExecuteStatementReq = new TExecuteStatementReq()
    tExecuteStatementReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
    tExecuteStatementReq.setStatement("set")
    val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq)
    val tFetchResultsReq = new TFetchResultsReq()
    tFetchResultsReq.setOperationHandle(tExecuteStatementResp.getOperationHandle)
    tFetchResultsReq.setFetchType(0)
    tFetchResultsReq.setMaxRows(1000)
    val tFetchResultsResp1 = client.FetchResults(tFetchResultsReq)
    val columns = tFetchResultsResp1.getResults.getColumns
    val rs = columns.get(0).getStringVal.getValues.asScala.zip(
      columns.get(1).getStringVal.getValues.asScala).toMap

    // Assert the session-level overrides took effect. Direct map lookups
    // (instead of the previous pattern-matching foreach with a catch-all)
    // make the test fail loudly if a key is missing from the `set` output,
    // rather than passing vacuously.
    assert(rs("spark.sql.shuffle.partitions") === "4")
    assert(rs("spark.sql.autoBroadcastJoinThreshold") === "-1")
    assert(rs("spark.sql.adaptive.enabled").toBoolean)

    // The engine-level SparkSession must be unaffected by per-session
    // overrides — isolation is the point of this feature.
    assert(spark.conf.get("spark.sql.shuffle.partitions") === "200")

    // Close the session and expect a success status.
    val tCloseSessionReq = new TCloseSessionReq()
    tCloseSessionReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
    val tCloseSessionResp = client.CloseSession(tCloseSessionReq)
    assert(tCloseSessionResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
  }
}
test("set session conf - static") {
  withThriftClient { client =>
    // Overriding a static SQL config at session-open time must be rejected:
    // the open should fail with an error status naming the offending key.
    val openReq = new TOpenSessionReq()
    openReq.setUsername("kentyao")
    openReq.setPassword("anonymous")
    val sessionConf =
      Map("use:database" -> "default", "spark.sql.globalTempDatabase" -> "temp")
    openReq.setConfiguration(sessionConf.asJava)
    val openResp = client.OpenSession(openReq)
    val openStatus = openResp.getStatus
    assert(openStatus.getStatusCode === TStatusCode.ERROR_STATUS)
    assert(openStatus.getErrorMessage.contains("spark.sql.globalTempDatabase"))
  }
}
test("set session conf - wrong database") {
  withThriftClient { client =>
    // A `use:database` directive pointing at a nonexistent database must make
    // the session open fail with an error status naming the database.
    val openReq = new TOpenSessionReq()
    openReq.setUsername("kentyao")
    openReq.setPassword("anonymous")
    openReq.setConfiguration(Map("use:database" -> "default2").asJava)
    val openResp = client.OpenSession(openReq)
    val openStatus = openResp.getStatus
    assert(openStatus.getStatusCode === TStatusCode.ERROR_STATUS)
    assert(openStatus.getErrorMessage.contains("Database 'default2' does not exist"))
  }
}
}