diff --git a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
index 299be587f..67539b3b9 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
+++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py
@@ -15,28 +15,57 @@
 # limitations under the License.
 #
 
-from glob import glob
 import ast
-import sys
 import io
 import json
-import traceback
-import re
+
 import os
+import re
+import sys
+import traceback
+from glob import glob
+
+if sys.version_info[0] < 3:
+    sys.exit('Python < 3 is unsupported.')
+
+spark_home = os.environ.get("SPARK_HOME", "")
+os.environ["PYSPARK_PYTHON"] = os.environ.get("PYSPARK_PYTHON", sys.executable)
+
+# add pyspark to sys.path
+
+if "pyspark" not in sys.modules:
+    spark_python = os.path.join(spark_home, "python")
+    try:
+        py4j = glob(os.path.join(spark_python, "lib", "py4j-*.zip"))[0]
+    except IndexError:
+        raise Exception(
+            "Unable to find py4j in {}, your SPARK_HOME may not be configured correctly".format(
+                spark_python
+            )
+        )
+    sys.path[:0] = sys_path = [spark_python, py4j]
+else:
+    # already imported, no need to patch sys.path
+    sys_path = None
+
+# import kyuubi_util after preparing sys.path
+import kyuubi_util
 
 # ast api is changed after python 3.8, see https://github.com/ipython/ipython/pull/11593
-if sys.version_info > (3,8):
-    from ast import Module
-else :
-    # mock the new API, ignore second argument
-    # see https://github.com/ipython/ipython/issues/11590
-    from ast import Module as OriginalModule
-    Module = lambda nodelist, type_ignores: OriginalModule(nodelist)
+if sys.version_info >= (3, 8):
+    from ast import Module
+else:
+    # mock the new API, ignore second argument
+    # see https://github.com/ipython/ipython/issues/11590
+    from ast import Module as OriginalModule
+
+    Module = lambda nodelist, type_ignores: OriginalModule(nodelist)
 
 TOP_FRAME_REGEX = re.compile(r'\s*File "".*in ')
 
 global_dict = {}
 
+
 class NormalNode(object):
     def __init__(self, code):
         self.code = compile(code, '', 'exec', ast.PyCF_ONLY_AST, 1)
@@ -54,21 +83,24 @@ class NormalNode(object):
                 mod = ast.Interactive([node])
                 code = compile(mod, '', 'single')
                 exec(code, global_dict)
-        except:
+        except Exception:
             # We don't need to log the exception because we're just executing user
            # code and passing the error along.
            raise ExecutionError(sys.exc_info())
 
+
 class ExecutionError(Exception):
     def __init__(self, exc_info):
         self.exc_info = exc_info
 
+
 class UnicodeDecodingStringIO(io.StringIO):
     def write(self, s):
         if isinstance(s, bytes):
             s = s.decode("utf-8")
         super(UnicodeDecodingStringIO, self).write(s)
 
+
 def clearOutputs():
     sys.stdout.close()
     sys.stderr.close()
@@ -81,16 +113,6 @@ def parse_code_into_nodes(code):
     try:
         nodes.append(NormalNode(code))
     except SyntaxError:
-        # It's possible we hit a syntax error because of a magic command. Split the code groups
-        # of 'normal code', and code that starts with a '%'. possibly magic code
-        # lines, and see if any of the lines
-        # Remove lines until we find a node that parses, then check if the next line is a magic
-        # line
-        # .
-
-        # Split the code into chunks of normal code, and possibly magic code, which starts with
-        # a '%'.
-
         normal = []
         chunks = []
         for i, line in enumerate(code.rstrip().split('\n')):
@@ -108,13 +130,15 @@ def parse_code_into_nodes(code):
     # Convert the chunks into AST nodes. Let exceptions propagate.
     for chunk in chunks:
-        if chunk.startswith('%'):
-            nodes.append(MagicNode(chunk))
-        else:
-            nodes.append(NormalNode(chunk))
+        # TODO: look back here when Jupyter and sparkmagic are supported
+        # if chunk.startswith('%'):
+        #     nodes.append(MagicNode(chunk))
+
+        nodes.append(NormalNode(chunk))
 
     return nodes
 
+
 def execute_reply(status, content):
     msg = {
         'msg_type': 'execute_reply',
         'content': content,
@@ -125,17 +149,15 @@ def execute_reply(status, content):
     }
     return json.dumps(msg)
 
+
 def execute_reply_ok(data):
     return execute_reply("ok", {
         "data": data,
     })
 
+
 def execute_reply_error(exc_type, exc_value, tb):
-    # LOG.error('execute_reply', exc_info=True)
-    if sys.version >= '3':
-        formatted_tb = traceback.format_exception(exc_type, exc_value, tb, chain=False)
-    else:
-        formatted_tb = traceback.format_exception(exc_type, exc_value, tb)
+    formatted_tb = traceback.format_exception(exc_type, exc_value, tb, chain=False)
     for i in range(len(formatted_tb)):
         if TOP_FRAME_REGEX.match(formatted_tb[i]):
             formatted_tb = formatted_tb[:1] + formatted_tb[i + 1:]
@@ -147,6 +169,15 @@ def execute_reply_error(exc_type, exc_value, tb):
         'traceback': formatted_tb,
     })
 
+
+def execute_reply_internal_error(message, exc_info=None):
+    return execute_reply('error', {
+        'ename': 'InternalError',
+        'evalue': message,
+        'traceback': [],
+    })
+
+
 def execute_request(content):
     try:
         code = content['code']
@@ -193,49 +224,25 @@ def execute_request(content):
     return execute_reply_ok(result)
 
-# import findspark
-# findspark.init()
-spark_home = os.environ.get("SPARK_HOME", "")
-os.environ["PYSPARK_PYTHON"] = os.environ.get("PYSPARK_PYTHON", sys.executable)
+# get or create spark session
+spark_session = kyuubi_util.get_spark_session()
+global_dict['spark'] = spark_session
 
-# add pyspark to sys.path
-
-if "pyspark" not in sys.modules:
-    spark_python = os.path.join(spark_home, "python")
-    try:
-        py4j = glob(os.path.join(spark_python, "lib", "py4j-*.zip"))[0]
-    except IndexError:
-        raise Exception(
-            "Unable to find py4j in {}, your SPARK_HOME may not be configured correctly".format(
-                spark_python
-            )
-        )
-    sys.path[:0] = sys_path = [spark_python, py4j]
-else:
-    # already imported, no need to patch sys.path
-    sys_path = None
-import kyuubi_util
-spark = kyuubi_util.get_spark()
-global_dict['spark'] = spark
 
 def main():
     sys_stdin = sys.stdin
     sys_stdout = sys.stdout
     sys_stderr = sys.stderr
 
-    if sys.version >= '3':
-        sys.stdin = io.StringIO()
-    else:
-        sys.stdin = cStringIO.StringIO()
-
+    sys.stdin = io.StringIO()
     sys.stdout = UnicodeDecodingStringIO()
     sys.stderr = UnicodeDecodingStringIO()
 
     stderr = sys.stderr.getvalue()
     print(stderr, file=sys_stderr)
-    clearOutputs
+    clearOutputs()
+
     try:
         while True:
@@ -249,7 +256,6 @@ def main():
             try:
                 content = json.loads(line)
             except ValueError:
-                # LOG.error('failed to parse message', exc_info=True)
                 continue
 
             if content['cmd'] == 'exit_worker':
@@ -265,5 +271,6 @@ def main():
     sys.stdout = sys_stdout
     sys.stderr = sys_stderr
 
+
 if __name__ == '__main__':
     sys.exit(main())
diff --git a/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py
index 4478060e9..8bbe6eb7c 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py
+++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py
@@ -15,29 +15,19 @@
 # limitations under the License.
 #
 
-import atexit
 import os
-import sys
-import signal
-import shlex
-import shutil
-import socket
-import platform
-import tempfile
-import time
-from subprocess import Popen, PIPE
-from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
 from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
+from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
 from pyspark.context import SparkContext
-from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
+from pyspark.serializers import read_int, UTF8Deserializer
 from pyspark.sql import SparkSession
 
-def connect_to_exist_gateway():
+def connect_to_exist_gateway() -> "JavaGateway":
     conn_info_file = os.environ.get("PYTHON_GATEWAY_CONNECTION_INFO")
     if conn_info_file is None:
-        raise SystemExit("the python gateway connection information file not found!")
+        raise SystemExit("the python gateway connection information file not found!")
     with open(conn_info_file, "rb") as info:
         gateway_port = read_int(info)
         gateway_secret = UTF8Deserializer().loads(info)
@@ -72,16 +62,17 @@
     return gateway
 
+
 def _get_exist_spark_context(self, jconf):
     """
     Initialize SparkContext in function to allow subclass specific initialization
     """
     return self._jvm.JavaSparkContext(self._jvm.org.apache.spark.SparkContext.getOrCreate(jconf))
 
-def get_spark():
+
+def get_spark_session() -> "SparkSession":
     SparkContext._initialize_context = _get_exist_spark_context
     gateway = connect_to_exist_gateway()
     SparkContext._ensure_initialized(gateway=gateway)
     spark = SparkSession.builder.master('local').appName('test').getOrCreate()
     return spark
-
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
index 3254f0e2c..4fb08105d 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecutePython.scala
@@ -98,14 +98,14 @@ case class SessionPythonWorker(
   private val stdout: BufferedReader =
     new BufferedReader(new InputStreamReader(workerProcess.getInputStream), 1)
 
-  def runCode(code: String): Option[PythonReponse] = {
+  def runCode(code: String): Option[PythonResponse] = {
     val input = ExecutePython.toJson(Map("code" -> code, "cmd" -> "run_code"))
     // scalastyle:off println
     stdin.println(input)
     // scalastyle:on
     stdin.flush()
     Option(stdout.readLine())
-      .map(ExecutePython.fromJson[PythonReponse](_))
+      .map(ExecutePython.fromJson[PythonResponse](_))
   }
 
   def close(): Unit = {
@@ -125,7 +125,7 @@ case class SessionPythonWorker(
 object ExecutePython extends Logging {
 
   private val isPythonGatewayStart = new AtomicBoolean(false)
-  val kyuubiPythonPath = Files.createTempDirectory("")
+  private val kyuubiPythonPath = Files.createTempDirectory("")
   def init(): Unit = {
     if (!isPythonGatewayStart.get()) {
       synchronized {
@@ -186,7 +186,7 @@ object ExecutePython extends Logging {
 
   private def startStderrSteamReader(process: Process): Thread = {
     val stderrThread = new Thread("process stderr thread") {
-      override def run() = {
+      override def run(): Unit = {
         val lines = scala.io.Source.fromInputStream(process.getErrorStream).getLines()
         lines.foreach(logger.error)
       }
@@ -198,7 +198,7 @@ object ExecutePython extends Logging {
   def startWatcher(process: Process): Thread = {
     val processWatcherThread = new Thread("process watcher thread") {
-      override def run() = {
+      override def run(): Unit = {
         val exitCode = process.waitFor()
         if (exitCode != 0) {
           logger.error(f"Process has died with $exitCode")
@@ -229,7 +229,7 @@ object ExecutePython extends Logging {
     file
   }
 
-  val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
+  val mapper: ObjectMapper = new ObjectMapper().registerModule(DefaultScalaModule)
   def toJson[T](obj: T): String = {
     mapper.writeValueAsString(obj)
   }
@@ -243,7 +243,7 @@ object ExecutePython extends Logging {
 
 }
 
-case class PythonReponse(
+case class PythonResponse(
     msg_type: String,
     content: PythonResponseContent)
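
For orientation, below is a minimal sketch (not part of the patch) of the line-delimited JSON protocol this change wires up between `SessionPythonWorker.runCode` on the Scala side and the worker loop in `execute_python.py`: each request is one JSON object per line on the worker's stdin, and each reply is one JSON line on stdout. Only the `cmd`/`code` request fields, the `exit_worker` command, the `execute_reply` message type, and the ok/error content fields (`data`, `ename`, `evalue`, `traceback`) come from the diff; the concrete payload values, the `text/plain` key, and the exact way `status` is folded into `content` are illustrative assumptions.

```python
import json

# Request shape built by SessionPythonWorker.runCode via
# ExecutePython.toJson(Map("code" -> code, "cmd" -> "run_code")).
run_request = json.dumps({"cmd": "run_code", "code": "spark.range(3).count()"})

# Reply shapes produced by execute_reply_ok()/execute_reply_error() in
# execute_python.py; the Scala side maps the line back into PythonResponse.
ok_reply = json.dumps({
    "msg_type": "execute_reply",
    "content": {"status": "ok", "data": {"text/plain": "3"}},  # payload is illustrative
})
error_reply = json.dumps({
    "msg_type": "execute_reply",
    "content": {
        "status": "error",
        "ename": "ZeroDivisionError",   # illustrative
        "evalue": "division by zero",   # illustrative
        "traceback": [],
    },
})

# Command that asks the worker loop in main() to terminate.
exit_request = json.dumps({"cmd": "exit_worker"})

for line in (run_request, ok_reply, error_reply, exit_request):
    print(line)
```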