From c28cc6b3b3bb7a654a07afacec45dab3d2e4bf79 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Mon, 26 Dec 2022 10:16:49 +0800 Subject: [PATCH] [KYUUBI #4019] Binding python/sql spark session ### _Why are the changes needed?_ Bind the Python and SQL Spark sessions so that variables set on the Python side are visible on the SQL side. After this PR, we can switch the execution mode from Python to SQL by running ```python spark.sql("SET kyuubi.operation.language=SQL").show() ``` ![5091671606960_ pic_hd](https://user-images.githubusercontent.com/8537877/208873580-bf6d8a09-63ad-4788-bce7-c1fe2705f0b2.jpg) ### _How was this patch tested?_ - [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible - [ ] Add screenshots for manual tests if appropriate - [ ] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before making a pull request Closes #4019 from cfmcgrady/binding-sql. Closes #4019 2fd16a8e2 [Fu Chen] address comment 2136dfd64 [Fu Chen] fix style cf8a612ee [Fu Chen] fix ut 57c592ed6 [Fu Chen] fix ut fed7614dd [Fu Chen] binding python/sql spark session Authored-by: Fu Chen Signed-off-by: Fu Chen --- .../main/resources/python/execute_python.py | 4 +- .../src/main/resources/python/kyuubi_util.py | 34 ++++++++------ .../kyuubi/engine/spark/SparkSQLEngine.scala | 18 ++++++++ .../spark/operation/ExecutePython.scala | 1 + .../engine/spark/operation/PySparkTests.scala | 44 +++++++++++++++++++ 5 files changed, 86 insertions(+), 15 deletions(-) diff --git a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py index 2d7ce4e0f..e6fe7f92b 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py +++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py @@ -240,7 +240,9 @@ def execute_request(content): # get or create spark 
session -spark_session = kyuubi_util.get_spark_session() +spark_session = kyuubi_util.get_spark_session( + os.environ.get("KYUUBI_SPARK_SESSION_UUID") +) global_dict["spark"] = spark_session diff --git a/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py index a76b3b726..35ab88511 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py +++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py @@ -19,6 +19,7 @@ import os from py4j.clientserver import ClientServer, JavaParameters, PythonParameters from py4j.java_gateway import java_import, JavaGateway, GatewayParameters +from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.serializers import read_int, UTF8Deserializer from pyspark.sql import SparkSession @@ -61,18 +62,23 @@ def connect_to_exist_gateway() -> "JavaGateway": return gateway -def _get_exist_spark_context(self, jconf): - """ - Initialize SparkContext in function to allow subclass specific initialization - """ - return self._jvm.JavaSparkContext( - self._jvm.org.apache.spark.SparkContext.getOrCreate(jconf) - ) - - -def get_spark_session() -> "SparkSession": - SparkContext._initialize_context = _get_exist_spark_context +def get_spark_session(uuid=None) -> "SparkSession": gateway = connect_to_exist_gateway() - SparkContext._ensure_initialized(gateway=gateway) - spark = SparkSession.builder.master("local").appName("test").getOrCreate() - return spark + jjsc = gateway.jvm.JavaSparkContext( + gateway.jvm.org.apache.spark.SparkContext.getOrCreate() + ) + conf = SparkConf() + conf.setMaster("dummy").setAppName("kyuubi-python") + sc = SparkContext(conf=conf, gateway=gateway, jsc=jjsc) + if uuid is None: + # note that in this mode, all the python's spark sessions share the root spark session. 
+ return ( + SparkSession.builder.master("dummy").appName("kyuubi-python").getOrCreate() + ) + else: + session = ( + gateway.jvm.org.apache.kyuubi.engine.spark.SparkSQLEngine.getSparkSession( + uuid + ) + ) + return SparkSession(sparkContext=sc, jsparkSession=session) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala index aed0cb68c..a5fc5cfac 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala @@ -39,10 +39,12 @@ import org.apache.kyuubi.config.KyuubiConf._ import org.apache.kyuubi.config.KyuubiReservedKeys.KYUUBI_ENGINE_SUBMIT_TIME_KEY import org.apache.kyuubi.engine.spark.SparkSQLEngine.{countDownLatch, currentEngine} import org.apache.kyuubi.engine.spark.events.{EngineEvent, EngineEventsStore, SparkEventHandlerRegister} +import org.apache.kyuubi.engine.spark.session.SparkSessionImpl import org.apache.kyuubi.events.EventBus import org.apache.kyuubi.ha.HighAvailabilityConf._ import org.apache.kyuubi.ha.client.RetryPolicies import org.apache.kyuubi.service.Serverable +import org.apache.kyuubi.session.SessionHandle import org.apache.kyuubi.util.{SignalRegister, ThreadUtils} case class SparkSQLEngine(spark: SparkSession) extends Serverable("SparkSQLEngine") { @@ -166,6 +168,22 @@ object SparkSQLEngine extends Logging { SignalRegister.registerLogger(logger) setupConf() + /** + * get the SparkSession by the session identifier, it was used for the initial PySpark session + * now, see + * externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py::get_spark_session + * for details + */ + def getSparkSession(uuid: String): SparkSession = { + assert(currentEngine.isDefined) + currentEngine.get + .backendService + 
.sessionManager + .getSession(SessionHandle.fromUUID(uuid)) + .asInstanceOf[SparkSessionImpl] + .spark + } + def setupConf(): Unit = { _sparkConf = new SparkConf() _kyuubiConf = KyuubiConf() diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala index 35c172b96..a23a9f36f 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala @@ -243,6 +243,7 @@ object ExecutePython extends Logging { "SPARK_HOME", getSparkPythonHomeFromArchive(spark, session).getOrElse(defaultSparkHome))) } + env.put("KYUUBI_SPARK_SESSION_UUID", session.handle.identifier.toString) env.put("PYTHON_GATEWAY_CONNECTION_INFO", KyuubiPythonGatewayServer.CONNECTION_FILE_PATH) logger.info( s""" diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala index ee6450ed7..d47c64fb0 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala @@ -90,6 +90,50 @@ class PySparkTests extends WithSparkSQLEngine with HiveJDBCTestHelper { } } + test("binding python/sql spark session") { + checkPythonRuntimeAndVersion() + withMultipleConnectionJdbcStatement()({ statement => + statement.executeQuery("SET kyuubi.operation.language=PYTHON") + + // set hello=kyuubi in python + val set1 = + """ + |spark.sql("set hello=kyuubi").show() + |""".stripMargin + val output1 = + """|+-----+------+ + || key| value| + |+-----+------+ 
+ ||hello|kyuubi| + |+-----+------+""".stripMargin + val resultSet1 = statement.executeQuery(set1) + assert(resultSet1.next()) + assert(resultSet1.getString("status") === "ok") + assert(resultSet1.getString("output") === output1) + + val set2 = + """ + |spark.sql("SET kyuubi.operation.language=SQL").show(truncate = False) + |""".stripMargin + val output2 = + """|+-------------------------+-----+ + ||key |value| + |+-------------------------+-----+ + ||kyuubi.operation.language|SQL | + |+-------------------------+-----+""".stripMargin + val resultSet2 = statement.executeQuery(set2) + assert(resultSet2.next()) + assert(resultSet2.getString("status") === "ok") + assert(resultSet2.getString("output") === output2) + + // get hello value in sql + val resultSet3 = statement.executeQuery("set hello") + assert(resultSet3.next()) + assert(resultSet3.getString("key") === "hello") + assert(resultSet3.getString("value") === "kyuubi") + }) + } + private def runPySparkTest( pyCode: String, output: String): Unit = {