diff --git a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
index 2d7ce4e0f..e6fe7f92b 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
+++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
@@ -240,7 +240,9 @@ def execute_request(content):
 
 
 # get or create spark session
-spark_session = kyuubi_util.get_spark_session()
+spark_session = kyuubi_util.get_spark_session(
+    os.environ.get("KYUUBI_SPARK_SESSION_UUID")
+)
 global_dict["spark"] = spark_session
 
 
diff --git a/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py
index a76b3b726..35ab88511 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py
+++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py
@@ -19,6 +19,7 @@ import os
 
 from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
 from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
+from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.serializers import read_int, UTF8Deserializer
 from pyspark.sql import SparkSession
@@ -61,18 +62,23 @@ def connect_to_exist_gateway() -> "JavaGateway":
     return gateway
 
 
-def _get_exist_spark_context(self, jconf):
-    """
-    Initialize SparkContext in function to allow subclass specific initialization
-    """
-    return self._jvm.JavaSparkContext(
-        self._jvm.org.apache.spark.SparkContext.getOrCreate(jconf)
-    )
-
-
-def get_spark_session() -> "SparkSession":
-    SparkContext._initialize_context = _get_exist_spark_context
+def get_spark_session(uuid=None) -> "SparkSession":
     gateway = connect_to_exist_gateway()
-    SparkContext._ensure_initialized(gateway=gateway)
-    spark = SparkSession.builder.master("local").appName("test").getOrCreate()
-    return spark
+    jjsc = gateway.jvm.JavaSparkContext(
+        gateway.jvm.org.apache.spark.SparkContext.getOrCreate()
+    )
+    conf = SparkConf()
+    conf.setMaster("dummy").setAppName("kyuubi-python")
+    sc = SparkContext(conf=conf, gateway=gateway, jsc=jjsc)
+    if uuid is None:
+        # Note: in this mode, all Python Spark sessions share the root Spark session.
+        return (
+            SparkSession.builder.master("dummy").appName("kyuubi-python").getOrCreate()
+        )
+    else:
+        session = (
+            gateway.jvm.org.apache.kyuubi.engine.spark.SparkSQLEngine.getSparkSession(
+                uuid
+            )
+        )
+        return SparkSession(sparkContext=sc, jsparkSession=session)
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala
index aed0cb68c..a5fc5cfac 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala
@@ -39,10 +39,12 @@ import org.apache.kyuubi.config.KyuubiConf._
 import org.apache.kyuubi.config.KyuubiReservedKeys.KYUUBI_ENGINE_SUBMIT_TIME_KEY
 import org.apache.kyuubi.engine.spark.SparkSQLEngine.{countDownLatch, currentEngine}
 import org.apache.kyuubi.engine.spark.events.{EngineEvent, EngineEventsStore, SparkEventHandlerRegister}
+import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
 import org.apache.kyuubi.events.EventBus
 import org.apache.kyuubi.ha.HighAvailabilityConf._
 import org.apache.kyuubi.ha.client.RetryPolicies
 import org.apache.kyuubi.service.Serverable
+import org.apache.kyuubi.session.SessionHandle
 import org.apache.kyuubi.util.{SignalRegister, ThreadUtils}
 
 case class SparkSQLEngine(spark: SparkSession) extends Serverable("SparkSQLEngine") {
@@ -166,6 +168,22 @@ object SparkSQLEngine extends Logging {
   SignalRegister.registerLogger(logger)
   setupConf()
 
+  /**
+   * Get the SparkSession by its session identifier. It is now used to initialize the
+   * PySpark session, see
+   * externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py::get_spark_session
+   * for details.
+   */
+  def getSparkSession(uuid: String): SparkSession = {
+    assert(currentEngine.isDefined)
+    currentEngine.get
+      .backendService
+      .sessionManager
+      .getSession(SessionHandle.fromUUID(uuid))
+      .asInstanceOf[SparkSessionImpl]
+      .spark
+  }
+
   def setupConf(): Unit = {
     _sparkConf = new SparkConf()
     _kyuubiConf = KyuubiConf()
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
index 35c172b96..a23a9f36f 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
@@ -243,6 +243,7 @@ object ExecutePython extends Logging {
         "SPARK_HOME",
         getSparkPythonHomeFromArchive(spark, session).getOrElse(defaultSparkHome)))
     }
+    env.put("KYUUBI_SPARK_SESSION_UUID", session.handle.identifier.toString)
     env.put("PYTHON_GATEWAY_CONNECTION_INFO", KyuubiPythonGatewayServer.CONNECTION_FILE_PATH)
     logger.info(
       s"""
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala
index ee6450ed7..d47c64fb0 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala
@@ -90,6 +90,50 @@ class PySparkTests extends WithSparkSQLEngine with HiveJDBCTestHelper {
     }
   }
 
+  test("binding python/sql spark session") {
+    checkPythonRuntimeAndVersion()
+    withMultipleConnectionJdbcStatement()({ statement =>
+      statement.executeQuery("SET kyuubi.operation.language=PYTHON")
+
+      // set hello=kyuubi in python
+      val set1 =
+        """
+          |spark.sql("set hello=kyuubi").show()
+          |""".stripMargin
+      val output1 =
+        """|+-----+------+
+           ||  key| value|
+           |+-----+------+
+           ||hello|kyuubi|
+           |+-----+------+""".stripMargin
+      val resultSet1 = statement.executeQuery(set1)
+      assert(resultSet1.next())
+      assert(resultSet1.getString("status") === "ok")
+      assert(resultSet1.getString("output") === output1)
+
+      val set2 =
+        """
+          |spark.sql("SET kyuubi.operation.language=SQL").show(truncate = False)
+          |""".stripMargin
+      val output2 =
+        """|+-------------------------+-----+
+           ||key                      |value|
+           |+-------------------------+-----+
+           ||kyuubi.operation.language|SQL  |
+           |+-------------------------+-----+""".stripMargin
+      val resultSet2 = statement.executeQuery(set2)
+      assert(resultSet2.next())
+      assert(resultSet2.getString("status") === "ok")
+      assert(resultSet2.getString("output") === output2)
+
+      // get hello value in sql
+      val resultSet3 = statement.executeQuery("set hello")
+      assert(resultSet3.next())
+      assert(resultSet3.getString("key") === "hello")
+      assert(resultSet3.getString("value") === "kyuubi")
+    })
+  }
+
   private def runPySparkTest(
       pyCode: String,
       output: String): Unit = {
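
Not part of the patch, only an illustration of what the binding above enables: once get_spark_session(uuid) wraps the engine's existing JVM SparkSession, Python and SQL statements on the same Kyuubi connection share one session, so state created from Python is visible to later SQL. A minimal client-side JDBC sketch mirroring the test; the endpoint localhost:10009, the credentials, and the view name py_view are hypothetical.

import java.sql.DriverManager

object SharedPySparkSessionExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical Kyuubi endpoint; requires the Hive JDBC driver on the classpath.
    val conn = DriverManager.getConnection("jdbc:hive2://localhost:10009/default", "anonymous", "")
    val stmt = conn.createStatement()
    try {
      // Switch this Kyuubi session to Python mode.
      stmt.executeQuery("SET kyuubi.operation.language=PYTHON")
      // The Python statement runs against the SparkSession resolved via KYUUBI_SPARK_SESSION_UUID.
      stmt.executeQuery("spark.range(3).createOrReplaceTempView('py_view')")
      // Switch back to SQL through the shared session, as the test above does.
      stmt.executeQuery("spark.sql(\"SET kyuubi.operation.language=SQL\").show()")
      // The temp view created from Python is visible to SQL because both bind to one SparkSession.
      val rs = stmt.executeQuery("SELECT count(*) FROM py_view")
      while (rs.next()) println(rs.getLong(1)) // expected: 3
    } finally {
      stmt.close()
      conn.close()
    }
  }
}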