[KYUUBI #4019] Binding python/sql spark session
### _Why are the changes needed?_
Bind the Python and SQL Spark sessions so that variables set on the Python side can be accessed from the SQL side.
After this PR, we can switch the execution mode from Python to SQL by running:
```python
spark.sql("SET kyuubi.operation.language=SQL").show()
```

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [ ] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before making a pull request
Closes #4019 from cfmcgrady/binding-sql.
Closes #4019
2fd16a8e2 [Fu Chen] address comment
2136dfd64 [Fu Chen] fix style
cf8a612ee [Fu Chen] fix ut
57c592ed6 [Fu Chen] fix ut
fed7614dd [Fu Chen] binding python/sql spark session
Authored-by: Fu Chen <cfmcgrady@gmail.com>
Signed-off-by: Fu Chen <cfmcgrady@gmail.com>
This commit is contained in:
parent
182227bd16
commit
c28cc6b3b3
@ -240,7 +240,9 @@ def execute_request(content):
|
||||
|
||||
|
||||
# get or create spark session
|
||||
spark_session = kyuubi_util.get_spark_session()
|
||||
spark_session = kyuubi_util.get_spark_session(
|
||||
os.environ.get("KYUUBI_SPARK_SESSION_UUID")
|
||||
)
|
||||
global_dict["spark"] = spark_session
|
||||
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ import os
|
||||
|
||||
from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
|
||||
from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
|
||||
from pyspark.conf import SparkConf
|
||||
from pyspark.context import SparkContext
|
||||
from pyspark.serializers import read_int, UTF8Deserializer
|
||||
from pyspark.sql import SparkSession
|
||||
@ -61,18 +62,23 @@ def connect_to_exist_gateway() -> "JavaGateway":
|
||||
return gateway
|
||||
|
||||
|
||||
def _get_exist_spark_context(self, jconf):
|
||||
"""
|
||||
Initialize SparkContext in function to allow subclass specific initialization
|
||||
"""
|
||||
return self._jvm.JavaSparkContext(
|
||||
self._jvm.org.apache.spark.SparkContext.getOrCreate(jconf)
|
||||
)
|
||||
|
||||
|
||||
def get_spark_session() -> "SparkSession":
|
||||
SparkContext._initialize_context = _get_exist_spark_context
|
||||
def get_spark_session(uuid=None) -> "SparkSession":
|
||||
gateway = connect_to_exist_gateway()
|
||||
SparkContext._ensure_initialized(gateway=gateway)
|
||||
spark = SparkSession.builder.master("local").appName("test").getOrCreate()
|
||||
return spark
|
||||
jjsc = gateway.jvm.JavaSparkContext(
|
||||
gateway.jvm.org.apache.spark.SparkContext.getOrCreate()
|
||||
)
|
||||
conf = SparkConf()
|
||||
conf.setMaster("dummy").setAppName("kyuubi-python")
|
||||
sc = SparkContext(conf=conf, gateway=gateway, jsc=jjsc)
|
||||
if uuid is None:
|
||||
# Note that in this mode, all Python Spark sessions share the root Spark session.
|
||||
return (
|
||||
SparkSession.builder.master("dummy").appName("kyuubi-python").getOrCreate()
|
||||
)
|
||||
else:
|
||||
session = (
|
||||
gateway.jvm.org.apache.kyuubi.engine.spark.SparkSQLEngine.getSparkSession(
|
||||
uuid
|
||||
)
|
||||
)
|
||||
return SparkSession(sparkContext=sc, jsparkSession=session)
|
||||
|
||||
@ -39,10 +39,12 @@ import org.apache.kyuubi.config.KyuubiConf._
|
||||
import org.apache.kyuubi.config.KyuubiReservedKeys.KYUUBI_ENGINE_SUBMIT_TIME_KEY
|
||||
import org.apache.kyuubi.engine.spark.SparkSQLEngine.{countDownLatch, currentEngine}
|
||||
import org.apache.kyuubi.engine.spark.events.{EngineEvent, EngineEventsStore, SparkEventHandlerRegister}
|
||||
import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
|
||||
import org.apache.kyuubi.events.EventBus
|
||||
import org.apache.kyuubi.ha.HighAvailabilityConf._
|
||||
import org.apache.kyuubi.ha.client.RetryPolicies
|
||||
import org.apache.kyuubi.service.Serverable
|
||||
import org.apache.kyuubi.session.SessionHandle
|
||||
import org.apache.kyuubi.util.{SignalRegister, ThreadUtils}
|
||||
|
||||
case class SparkSQLEngine(spark: SparkSession) extends Serverable("SparkSQLEngine") {
|
||||
@ -166,6 +168,22 @@ object SparkSQLEngine extends Logging {
|
||||
SignalRegister.registerLogger(logger)
|
||||
setupConf()
|
||||
|
||||
/**
|
||||
* Get the SparkSession by the session identifier. It is currently used to initialize the
|
||||
* PySpark session; see
|
||||
* externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py::get_spark_session
|
||||
* for details
|
||||
*/
|
||||
def getSparkSession(uuid: String): SparkSession = {
|
||||
assert(currentEngine.isDefined)
|
||||
currentEngine.get
|
||||
.backendService
|
||||
.sessionManager
|
||||
.getSession(SessionHandle.fromUUID(uuid))
|
||||
.asInstanceOf[SparkSessionImpl]
|
||||
.spark
|
||||
}
|
||||
|
||||
def setupConf(): Unit = {
|
||||
_sparkConf = new SparkConf()
|
||||
_kyuubiConf = KyuubiConf()
|
||||
|
||||
@ -243,6 +243,7 @@ object ExecutePython extends Logging {
|
||||
"SPARK_HOME",
|
||||
getSparkPythonHomeFromArchive(spark, session).getOrElse(defaultSparkHome)))
|
||||
}
|
||||
env.put("KYUUBI_SPARK_SESSION_UUID", session.handle.identifier.toString)
|
||||
env.put("PYTHON_GATEWAY_CONNECTION_INFO", KyuubiPythonGatewayServer.CONNECTION_FILE_PATH)
|
||||
logger.info(
|
||||
s"""
|
||||
|
||||
@ -90,6 +90,50 @@ class PySparkTests extends WithSparkSQLEngine with HiveJDBCTestHelper {
|
||||
}
|
||||
}
|
||||
|
||||
test("binding python/sql spark session") {
|
||||
checkPythonRuntimeAndVersion()
|
||||
withMultipleConnectionJdbcStatement()({ statement =>
|
||||
statement.executeQuery("SET kyuubi.operation.language=PYTHON")
|
||||
|
||||
// set hello=kyuubi in python
|
||||
val set1 =
|
||||
"""
|
||||
|spark.sql("set hello=kyuubi").show()
|
||||
|""".stripMargin
|
||||
val output1 =
|
||||
"""|+-----+------+
|
||||
|| key| value|
|
||||
|+-----+------+
|
||||
||hello|kyuubi|
|
||||
|+-----+------+""".stripMargin
|
||||
val resultSet1 = statement.executeQuery(set1)
|
||||
assert(resultSet1.next())
|
||||
assert(resultSet1.getString("status") === "ok")
|
||||
assert(resultSet1.getString("output") === output1)
|
||||
|
||||
val set2 =
|
||||
"""
|
||||
|spark.sql("SET kyuubi.operation.language=SQL").show(truncate = False)
|
||||
|""".stripMargin
|
||||
val output2 =
|
||||
"""|+-------------------------+-----+
|
||||
||key |value|
|
||||
|+-------------------------+-----+
|
||||
||kyuubi.operation.language|SQL |
|
||||
|+-------------------------+-----+""".stripMargin
|
||||
val resultSet2 = statement.executeQuery(set2)
|
||||
assert(resultSet2.next())
|
||||
assert(resultSet2.getString("status") === "ok")
|
||||
assert(resultSet2.getString("output") === output2)
|
||||
|
||||
// get hello value in sql
|
||||
val resultSet3 = statement.executeQuery("set hello")
|
||||
assert(resultSet3.next())
|
||||
assert(resultSet3.getString("key") === "hello")
|
||||
assert(resultSet3.getString("value") === "kyuubi")
|
||||
})
|
||||
}
|
||||
|
||||
private def runPySparkTest(
|
||||
pyCode: String,
|
||||
output: String): Unit = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user