[KYUUBI #4019] Binding python/sql spark session

### _Why are the changes needed?_

Bind python and SQL spark session, then the variables we set on the python side can be visited on the SQL side

After this PR, we can change the execution mode from python to sql by running

```python
spark.sql("SET kyuubi.operation.language=SQL").show()
```

![5091671606960_ pic_hd](https://user-images.githubusercontent.com/8537877/208873580-bf6d8a09-63ad-4788-bce7-c1fe2705f0b2.jpg)

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [ ] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before make a pull request

Closes #4019 from cfmcgrady/binding-sql.

Closes #4019

2fd16a8e2 [Fu Chen] address comment
2136dfd64 [Fu Chen] fix style
cf8a612ee [Fu Chen] fix ut
57c592ed6 [Fu Chen] fix ut
fed7614dd [Fu Chen] binding python/sql spark session

Authored-by: Fu Chen <cfmcgrady@gmail.com>
Signed-off-by: Fu Chen <cfmcgrady@gmail.com>
This commit is contained in:
Fu Chen 2022-12-26 10:16:49 +08:00
parent 182227bd16
commit c28cc6b3b3
5 changed files with 86 additions and 15 deletions

View File

@ -240,7 +240,9 @@ def execute_request(content):
# get or create spark session
spark_session = kyuubi_util.get_spark_session()
spark_session = kyuubi_util.get_spark_session(
os.environ.get("KYUUBI_SPARK_SESSION_UUID")
)
global_dict["spark"] = spark_session

View File

@ -19,6 +19,7 @@ import os
from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.serializers import read_int, UTF8Deserializer
from pyspark.sql import SparkSession
@ -61,18 +62,23 @@ def connect_to_exist_gateway() -> "JavaGateway":
return gateway
def _get_exist_spark_context(self, jconf):
"""
Initialize SparkContext in function to allow subclass specific initialization
"""
return self._jvm.JavaSparkContext(
self._jvm.org.apache.spark.SparkContext.getOrCreate(jconf)
)
def get_spark_session() -> "SparkSession":
SparkContext._initialize_context = _get_exist_spark_context
def get_spark_session(uuid=None) -> "SparkSession":
gateway = connect_to_exist_gateway()
SparkContext._ensure_initialized(gateway=gateway)
spark = SparkSession.builder.master("local").appName("test").getOrCreate()
return spark
jjsc = gateway.jvm.JavaSparkContext(
gateway.jvm.org.apache.spark.SparkContext.getOrCreate()
)
conf = SparkConf()
conf.setMaster("dummy").setAppName("kyuubi-python")
sc = SparkContext(conf=conf, gateway=gateway, jsc=jjsc)
if uuid is None:
# note that in this mode, all the python's spark sessions share the root spark session.
return (
SparkSession.builder.master("dummy").appName("kyuubi-python").getOrCreate()
)
else:
session = (
gateway.jvm.org.apache.kyuubi.engine.spark.SparkSQLEngine.getSparkSession(
uuid
)
)
return SparkSession(sparkContext=sc, jsparkSession=session)

View File

@ -39,10 +39,12 @@ import org.apache.kyuubi.config.KyuubiConf._
import org.apache.kyuubi.config.KyuubiReservedKeys.KYUUBI_ENGINE_SUBMIT_TIME_KEY
import org.apache.kyuubi.engine.spark.SparkSQLEngine.{countDownLatch, currentEngine}
import org.apache.kyuubi.engine.spark.events.{EngineEvent, EngineEventsStore, SparkEventHandlerRegister}
import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
import org.apache.kyuubi.events.EventBus
import org.apache.kyuubi.ha.HighAvailabilityConf._
import org.apache.kyuubi.ha.client.RetryPolicies
import org.apache.kyuubi.service.Serverable
import org.apache.kyuubi.session.SessionHandle
import org.apache.kyuubi.util.{SignalRegister, ThreadUtils}
case class SparkSQLEngine(spark: SparkSession) extends Serverable("SparkSQLEngine") {
@ -166,6 +168,22 @@ object SparkSQLEngine extends Logging {
SignalRegister.registerLogger(logger)
setupConf()
/**
* get the SparkSession by the session identifier, it was used for the initial PySpark session
* now, see
* externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py::get_spark_session
* for details
*/
def getSparkSession(uuid: String): SparkSession = {
assert(currentEngine.isDefined)
currentEngine.get
.backendService
.sessionManager
.getSession(SessionHandle.fromUUID(uuid))
.asInstanceOf[SparkSessionImpl]
.spark
}
def setupConf(): Unit = {
_sparkConf = new SparkConf()
_kyuubiConf = KyuubiConf()

View File

@ -243,6 +243,7 @@ object ExecutePython extends Logging {
"SPARK_HOME",
getSparkPythonHomeFromArchive(spark, session).getOrElse(defaultSparkHome)))
}
env.put("KYUUBI_SPARK_SESSION_UUID", session.handle.identifier.toString)
env.put("PYTHON_GATEWAY_CONNECTION_INFO", KyuubiPythonGatewayServer.CONNECTION_FILE_PATH)
logger.info(
s"""

View File

@ -90,6 +90,50 @@ class PySparkTests extends WithSparkSQLEngine with HiveJDBCTestHelper {
}
}
test("binding python/sql spark session") {
checkPythonRuntimeAndVersion()
withMultipleConnectionJdbcStatement()({ statement =>
statement.executeQuery("SET kyuubi.operation.language=PYTHON")
// set hello=kyuubi in python
val set1 =
"""
|spark.sql("set hello=kyuubi").show()
|""".stripMargin
val output1 =
"""|+-----+------+
|| key| value|
|+-----+------+
||hello|kyuubi|
|+-----+------+""".stripMargin
val resultSet1 = statement.executeQuery(set1)
assert(resultSet1.next())
assert(resultSet1.getString("status") === "ok")
assert(resultSet1.getString("output") === output1)
val set2 =
"""
|spark.sql("SET kyuubi.operation.language=SQL").show(truncate = False)
|""".stripMargin
val output2 =
"""|+-------------------------+-----+
||key |value|
|+-------------------------+-----+
||kyuubi.operation.language|SQL |
|+-------------------------+-----+""".stripMargin
val resultSet2 = statement.executeQuery(set2)
assert(resultSet2.next())
assert(resultSet2.getString("status") === "ok")
assert(resultSet2.getString("output") === output2)
// get hello value in sql
val resultSet3 = statement.executeQuery("set hello")
assert(resultSet3.next())
assert(resultSet3.getString("key") === "hello")
assert(resultSet3.getString("value") === "kyuubi")
})
}
private def runPySparkTest(
pyCode: String,
output: String): Unit = {