diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
index 27c525948..f6966e141 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
@@ -92,10 +92,7 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
         warn(s"Ignore exception in terminal state with $statementId: $e")
       } else {
         setState(OperationState.ERROR)
-        e match {
-          case k: KyuubiSQLException => throw k
-          case _ => throw KyuubiSQLException(s"Error operating $opType: ${e.getMessage}", e)
-        }
+        throw KyuubiSQLException(s"Error operating $opType: ${e.getMessage}", e)
       }
     }
   }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/session/SparkSQLSessionManager.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/session/SparkSQLSessionManager.scala
index 8300ddc7a..772af9a61 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/session/SparkSQLSessionManager.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/session/SparkSQLSessionManager.scala
@@ -18,6 +18,7 @@
 package org.apache.kyuubi.engine.spark.session
 
 import scala.util.control.NonFatal
+import scala.util.matching.Regex
 
 import org.apache.hive.service.rpc.thrift.TProtocolVersion
 import org.apache.spark.sql.SparkSession
@@ -37,6 +38,8 @@ import org.apache.kyuubi.session.{SessionHandle, SessionManager}
 class SparkSQLSessionManager private (name: String, spark: SparkSession)
   extends SessionManager(name) {
 
+  import SparkSQLSessionManager._
+
   def this(spark: SparkSession) = this(classOf[SparkSQLSessionManager].getSimpleName, spark)
 
   val operationManager = new SparkSQLOperationManager()
@@ -51,12 +54,17 @@ class SparkSQLSessionManager private (name: String, spark: SparkSession)
     val handle = sessionImpl.handle
     try {
       val sparkSession = spark.newSession()
-      conf.foreach { case (key, value) => spark.conf.set(key, value)}
-      operationManager.setSparkSession(handle, sparkSession)
+      conf.foreach {
+        case (HIVE_VAR_PREFIX(key), value) => sparkSession.conf.set(key, value)
+        case (HIVE_CONF_PREFIX(key), value) => sparkSession.conf.set(key, value)
+        case ("use:database", database) => sparkSession.catalog.setCurrentDatabase(database)
+        case (key, value) => sparkSession.conf.set(key, value)
+      }
       sessionImpl.open()
-      info(s"$user's session with $handle is opened, current opening sessions" +
-        s" $getOpenSessionCount")
+      operationManager.setSparkSession(handle, sparkSession)
       setSession(handle, sessionImpl)
+      info(s"$user's session with $handle is opened, current opening sessions" +
+        s" $getOpenSessionCount")
       handle
     } catch {
       case NonFatal(e) =>
@@ -65,7 +73,12 @@ class SparkSQLSessionManager private (name: String, spark: SparkSession)
         } catch {
           case t: Throwable => warn(s"Error closing session $handle for $user", t)
         }
-        throw KyuubiSQLException(s"Error opening session $handle for $user", e)
+        throw KyuubiSQLException(s"Error opening session $handle for $user: ${e.getMessage}", e)
     }
   }
 }
+
+object SparkSQLSessionManager {
+  val HIVE_VAR_PREFIX: Regex = """set:hivevar:([^=]+)""".r
+  val HIVE_CONF_PREFIX: Regex = """set:hiveconf:([^=]+)""".r
+}
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala
index b2b95a990..c25473e25 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala
@@ -19,8 +19,10 @@ package org.apache.kyuubi.engine.spark.operation
 
 import java.sql.{Date, SQLException, Timestamp}
 
+import scala.collection.JavaConverters._
+
 import org.apache.hive.service.cli.HiveSQLException
-import org.apache.hive.service.rpc.thrift.TOpenSessionReq
+import org.apache.hive.service.rpc.thrift.{TCloseSessionReq, TExecuteStatementReq, TFetchResultsReq, TOpenSessionReq, TStatus, TStatusCode}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.catalog.CatalogTableType
 import org.apache.spark.sql.types._
@@ -698,17 +700,104 @@ class SparkOperationSuite extends WithSparkSQLEngine {
     }
   }
 
-  test("get functions operation") {
+  test("basic open | execute | close") {
     withThriftClient { client =>
       val req = new TOpenSessionReq()
       req.setUsername("kentyao")
       req.setPassword("anonymous")
-//      req.setClient_protocol(BackendService.SERVER_VERSION)
-      req
-      val resp = client.OpenSession(req)
-      val sessionHandle = resp.getSessionHandle
-    }
+      val tOpenSessionResp = client.OpenSession(req)
+      val tExecuteStatementReq = new TExecuteStatementReq()
+      tExecuteStatementReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
+      tExecuteStatementReq.setStatement("set -v")
+      val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq)
+
+      val tFetchResultsReq = new TFetchResultsReq()
+      tFetchResultsReq.setOperationHandle(tExecuteStatementResp.getOperationHandle)
+      tFetchResultsReq.setFetchType(1)
+      tFetchResultsReq.setMaxRows(1000)
+      val tFetchResultsResp = client.FetchResults(tFetchResultsReq)
+      val logs = tFetchResultsResp.getResults.getColumns.get(0).getStringVal.getValues.asScala
+      assert(logs.exists(_.contains(classOf[ExecuteStatement].getCanonicalName)))
+
+      tFetchResultsReq.setFetchType(0)
+      val tFetchResultsResp1 = client.FetchResults(tFetchResultsReq)
+      val rs = tFetchResultsResp1.getResults.getColumns.get(0).getStringVal.getValues.asScala
+      assert(rs.contains("spark.sql.shuffle.partitions"))
+
+      val tCloseSessionReq = new TCloseSessionReq()
+      tCloseSessionReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
+      val tCloseSessionResp = client.CloseSession(tCloseSessionReq)
+      assert(tCloseSessionResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
+    }
   }
 
+  test("set session conf") {
+    withThriftClient { client =>
+      val req = new TOpenSessionReq()
+      req.setUsername("kentyao")
+      req.setPassword("anonymous")
+      val conf = Map(
+        "use:database" -> "default",
+        "spark.sql.shuffle.partitions" -> "4",
+        "set:hiveconf:spark.sql.autoBroadcastJoinThreshold" -> "-1",
+        "set:hivevar:spark.sql.adaptive.enabled" -> "true")
+      req.setConfiguration(conf.asJava)
+      val tOpenSessionResp = client.OpenSession(req)
+
+      val tExecuteStatementReq = new TExecuteStatementReq()
+      tExecuteStatementReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
+      tExecuteStatementReq.setStatement("set")
+      val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq)
+
+      val tFetchResultsReq = new TFetchResultsReq()
+      tFetchResultsReq.setOperationHandle(tExecuteStatementResp.getOperationHandle)
+      tFetchResultsReq.setFetchType(0)
+      tFetchResultsReq.setMaxRows(1000)
+      val tFetchResultsResp1 = client.FetchResults(tFetchResultsReq)
+      val columns = tFetchResultsResp1.getResults.getColumns
+      val rs = columns.get(0).getStringVal.getValues.asScala.zip(
+        columns.get(1).getStringVal.getValues.asScala)
+      rs foreach {
+        case ("spark.sql.shuffle.partitions", v) => assert(v === "4")
+        case ("spark.sql.autoBroadcastJoinThreshold", v) => assert(v === "-1")
+        case ("spark.sql.adaptive.enabled", v) => assert(v.toBoolean)
+        case _ =>
+      }
+      assert(spark.conf.get("spark.sql.shuffle.partitions") === "200")
+
+      val tCloseSessionReq = new TCloseSessionReq()
+      tCloseSessionReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
+      val tCloseSessionResp = client.CloseSession(tCloseSessionReq)
+      assert(tCloseSessionResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
+    }
+  }
+
+  test("set session conf - static") {
+    withThriftClient { client =>
+      val req = new TOpenSessionReq()
+      req.setUsername("kentyao")
+      req.setPassword("anonymous")
+      val conf = Map("use:database" -> "default", "spark.sql.globalTempDatabase" -> "temp")
+      req.setConfiguration(conf.asJava)
+      val tOpenSessionResp = client.OpenSession(req)
+      val status = tOpenSessionResp.getStatus
+      assert(status.getStatusCode === TStatusCode.ERROR_STATUS)
+      assert(status.getErrorMessage.contains("spark.sql.globalTempDatabase"))
+    }
+  }
+
+  test("set session conf - wrong database") {
+    withThriftClient { client =>
+      val req = new TOpenSessionReq()
+      req.setUsername("kentyao")
+      req.setPassword("anonymous")
+      val conf = Map("use:database" -> "default2")
+      req.setConfiguration(conf.asJava)
+      val tOpenSessionResp = client.OpenSession(req)
+      val status = tOpenSessionResp.getStatus
+      assert(status.getStatusCode === TStatusCode.ERROR_STATUS)
+      assert(status.getErrorMessage.contains("Database 'default2' does not exist"))
+    }
+  }
 }
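Note (not part of the patch): a minimal, self-contained sketch of the key-classification idea behind the HIVE_VAR_PREFIX / HIVE_CONF_PREFIX extractors added to SparkSQLSessionManager above. The object and method names below are hypothetical and exist only for illustration; the two regex definitions are taken verbatim from the patch.

object SessionConfClassificationSketch extends App {
  import scala.util.matching.Regex

  // Same definitions as the patch's SparkSQLSessionManager companion object.
  val HIVE_VAR_PREFIX: Regex = """set:hivevar:([^=]+)""".r
  val HIVE_CONF_PREFIX: Regex = """set:hiveconf:([^=]+)""".r

  // Mirrors the pattern match in openSession: hivevar/hiveconf prefixes are
  // stripped, "use:database" switches the current database, and anything else
  // is applied to the Spark session conf unchanged.
  def classify(key: String, value: String): String = (key, value) match {
    case (HIVE_VAR_PREFIX(k), v)  => s"session conf (hivevar): $k = $v"
    case (HIVE_CONF_PREFIX(k), v) => s"session conf (hiveconf): $k = $v"
    case ("use:database", db)     => s"current database: $db"
    case (k, v)                   => s"session conf: $k = $v"
  }

  Seq(
    "set:hivevar:spark.sql.adaptive.enabled" -> "true",
    "set:hiveconf:spark.sql.autoBroadcastJoinThreshold" -> "-1",
    "use:database" -> "default",
    "spark.sql.shuffle.partitions" -> "4")
    .foreach { case (k, v) => println(classify(k, v)) }
}

Because the regexes are anchored over the whole key, only keys that truly start with set:hivevar: or set:hiveconf: are rewritten; plain Spark keys fall through to the last case, which is the same ordering the patch relies on.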