diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala index a23a9f36f..a27a8a023 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala @@ -23,6 +23,7 @@ import java.net.URI import java.nio.file.{Files, Path, Paths} import java.util.concurrent.RejectedExecutionException import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.locks.ReentrantLock import javax.ws.rs.core.UriBuilder import scala.collection.JavaConverters._ @@ -99,7 +100,7 @@ class ExecutePython( } } - override protected def runInternal(): Unit = withLocalProperties { + override protected def runInternal(): Unit = { addTimeoutMonitor(queryTimeout) if (shouldRunAsync) { val asyncOperation = new Runnable { @@ -129,14 +130,20 @@ class ExecutePython( override def setSparkLocalProperty: (String, String) => Unit = (key: String, value: String) => { val valueStr = if (value == null) "None" else s"'$value'" - worker.runCode(s"spark.sparkContext.setLocalProperty('$key', $valueStr)") + worker.runCode(s"spark.sparkContext.setLocalProperty('$key', $valueStr)", internal = true) () } override protected def withLocalProperties[T](f: => T): T = { try { - worker.runCode("spark.sparkContext.setJobGroup" + - s"($statementId, $redactedStatement, $forceCancel)") + // use a fixed job description instead of the statement text, so the generated setJobGroup python code cannot be broken + val jobDesc = s"Python statement: $statementId" + // python boolean literals are capitalized (True/False) + val pythonForceCancel = if (forceCancel) "True" else "False" + worker.runCode( + "spark.sparkContext.setJobGroup" + + s"('$statementId', '$jobDesc', $pythonForceCancel)", + internal = true) 
setSparkLocalProperty(KYUUBI_SESSION_USER_KEY, session.user) setSparkLocalProperty(KYUUBI_STATEMENT_ID_KEY, statementId) schedulerPool match { @@ -153,7 +160,8 @@ class ExecutePython( setSparkLocalProperty(KYUUBI_SESSION_USER_KEY, "") setSparkLocalProperty(KYUUBI_STATEMENT_ID_KEY, "") setSparkLocalProperty(SPARK_SCHEDULER_POOL_KEY, "") - worker.runCode("spark.sparkContext.clearJobGroup()") + // using cancelJobGroup for pyspark, see details in pyspark/context.py + worker.runCode(s"spark.sparkContext.cancelJobGroup('$statementId')", internal = true) if (isSessionUserSignEnabled) { clearSessionUserSign() } @@ -168,15 +176,38 @@ case class SessionPythonWorker( private val stdin: PrintWriter = new PrintWriter(workerProcess.getOutputStream) private val stdout: BufferedReader = new BufferedReader(new InputStreamReader(workerProcess.getInputStream), 1) + private val lock = new ReentrantLock() - def runCode(code: String): Option[PythonResponse] = { + private def withLockRequired[T](block: => T): T = { + try { + lock.lock() + block + } finally lock.unlock() + } + + /** + * Run the python code and return the response. This method may be invoked internally, + * e.g. for setJobGroup and cancelJobGroup. If the internal python code is malformed, + * it might impact correctness and even cause results to be out of sequence. To prevent that, + * keep the internal python code simple and set the internal flag, so that a failure of the + * internal python code is raised as an exception. 
+ * + * @param code the python code + * @param internal whether is internal python code + * @return the python response + */ + def runCode(code: String, internal: Boolean = false): Option[PythonResponse] = withLockRequired { val input = ExecutePython.toJson(Map("code" -> code, "cmd" -> "run_code")) // scalastyle:off println stdin.println(input) // scalastyle:on stdin.flush() - Option(stdout.readLine()) - .map(ExecutePython.fromJson[PythonResponse](_)) + val pythonResponse = Option(stdout.readLine()).map(ExecutePython.fromJson[PythonResponse](_)) + // throw exception if internal python code fails + if (internal && pythonResponse.map(_.content.status) != Some(PythonResponse.OK_STATUS)) { + throw KyuubiSQLException(s"Internal python code $code failure: $pythonResponse") + } + pythonResponse } def close(): Unit = { diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala index d47c64fb0..e2dd2609d 100644 --- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala +++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/PySparkTests.scala @@ -72,11 +72,7 @@ class PySparkTests extends WithSparkSQLEngine with HiveJDBCTestHelper { val statement = connection.createStatement().asInstanceOf[KyuubiStatement] statement.setQueryTimeout(5) try { - var code = - """ - |import time - |time.sleep(10) - |""".stripMargin + var code = "spark.sql(\"select java_method('java.lang.Thread', 'sleep', 10000L)\").show()" var e = intercept[SQLTimeoutException] { statement.executePython(code) }.getMessage