diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml index 85f67fdcd..38915b560 100644 --- a/.github/workflows/master.yml +++ b/.github/workflows/master.yml @@ -75,6 +75,10 @@ jobs: java-version: ${{ matrix.java }} cache: 'maven' check-latest: false + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: '3.9' - name: Build and test Kyuubi and Spark with maven w/o linters run: | TEST_MODULES="dev/kyuubi-codecov" diff --git a/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py new file mode 100644 index 000000000..8f8dfca96 --- /dev/null +++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/execute_python.py @@ -0,0 +1,260 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#

from glob import glob
import ast
import sys
import io
import json
import traceback
import re
import os

# Matches a formatted traceback entry for a frame of the executed snippet. Used by
# execute_reply_error() to cut the worker's own frames out of the reported traceback.
TOP_FRAME_REGEX = re.compile(r'\s*File "".*in ')

# Globals namespace shared by every snippet executed in this worker process.
global_dict = {}


class NormalNode(object):
    """A piece of plain Python source that is parsed eagerly and executed on demand."""

    def __init__(self, code):
        # PyCF_ONLY_AST: parse only. A SyntaxError is raised here, at construction time.
        self.code = compile(code, '', 'exec', ast.PyCF_ONLY_AST, 1)

    def execute(self):
        """Run the parsed statements in ``global_dict``.

        Every statement but the last is compiled in 'exec' mode; the last one is compiled in
        'single' mode so that a trailing expression is echoed the way the interactive
        interpreter would echo it.

        Raises:
            ExecutionError: wraps ``sys.exc_info()`` of whatever the user code raised.
        """
        to_run_exec, to_run_single = self.code.body[:-1], self.code.body[-1:]

        try:
            for node in to_run_exec:
                # ast.Module has a mandatory `type_ignores` field since Python 3.8; without
                # it compile() raises TypeError for every multi-statement snippet.
                mod = ast.Module(body=[node], type_ignores=[])
                code = compile(mod, '', 'exec')
                exec(code, global_dict)

            for node in to_run_single:
                mod = ast.Interactive(body=[node])
                code = compile(mod, '', 'single')
                exec(code, global_dict)
        except BaseException:
            # Deliberately broad (the original used a bare `except:`): we are executing
            # user code and only pass the error along, so nothing is logged here.
            raise ExecutionError(sys.exc_info())


class ExecutionError(Exception):
    """Carries the ``sys.exc_info()`` triple of an exception raised by user code."""

    def __init__(self, exc_info):
        self.exc_info = exc_info


class UnicodeDecodingStringIO(io.StringIO):
    """StringIO that also accepts ``bytes``, decoding them as UTF-8 before writing."""

    def write(self, s):
        if isinstance(s, bytes):
            s = s.decode("utf-8")
        super(UnicodeDecodingStringIO, self).write(s)


def clearOutputs():
    """Close the current capture buffers and install fresh ones on sys.stdout/sys.stderr."""
    sys.stdout.close()
    sys.stderr.close()
    sys.stdout = UnicodeDecodingStringIO()
    sys.stderr = UnicodeDecodingStringIO()


def parse_code_into_nodes(code):
    """Parse ``code`` into a list of executable nodes.

    The whole text is tried first. If that fails, the text is split into runs of normal
    lines and single lines starting with '%' (magic commands), and each run is parsed on
    its own.

    Raises:
        SyntaxError: if a chunk does not parse, or if a '%' magic line is present. Magic
            commands are not implemented in this worker; the original referenced an
            undefined ``MagicNode`` here, which surfaced as a NameError that nothing
            caught and therefore terminated the worker.
    """
    try:
        return [NormalNode(code)]
    except SyntaxError:
        # Possibly caused by a magic line; fall through and retry chunk by chunk.
        pass

    # Split the code into chunks of normal code, and possibly magic code, which starts
    # with a '%'.
    normal = []
    chunks = []
    for line in code.rstrip().split('\n'):
        if line.startswith('%'):
            if normal:
                chunks.append('\n'.join(normal))
                normal = []
            chunks.append(line)
        else:
            normal.append(line)

    if normal:
        chunks.append('\n'.join(normal))

    # Convert the chunks into AST nodes. Let exceptions propagate.
    nodes = []
    for chunk in chunks:
        if chunk.startswith('%'):
            raise SyntaxError('magic command is not supported: {}'.format(chunk))
        nodes.append(NormalNode(chunk))

    return nodes


def execute_reply(status, content):
    """Serialize an 'execute_reply' message: ``content`` plus a ``status`` entry, as JSON."""
    msg = {
        'msg_type': 'execute_reply',
        'content': dict(
            content,
            status=status,
        )
    }
    return json.dumps(msg)


def execute_reply_ok(data):
    """Build a successful reply carrying ``data`` (a mimetype -> text mapping)."""
    return execute_reply("ok", {
        "data": data,
    })


def execute_reply_error(exc_type, exc_value, tb):
    """Build an error reply from an exception triple.

    Traceback entries from index 1 up to and including the first one matching
    TOP_FRAME_REGEX are removed, so the reported traceback starts at the user's code.
    """
    # This module is Python 3 only (it uses print(..., file=...) without a __future__
    # import), so the former `sys.version >= '3'` string comparison is not needed.
    formatted_tb = traceback.format_exception(exc_type, exc_value, tb, chain=False)
    for i, entry in enumerate(formatted_tb):
        if TOP_FRAME_REGEX.match(entry):
            formatted_tb = formatted_tb[:1] + formatted_tb[i + 1:]
            break

    return execute_reply('error', {
        'ename': str(exc_type.__name__),
        'evalue': str(exc_value),
        'traceback': formatted_tb,
    })


def execute_reply_internal_error(message, exc_info=None):
    """Build an error reply for a problem in the request itself, not in user code.

    ``exc_info`` is accepted for the caller's convenience and is currently unused.
    """
    return execute_reply('error', {
        'ename': 'InternalError',
        'evalue': message,
        'traceback': [],
    })


def execute_request(content):
    """Execute ``content['code']`` and return the JSON reply string.

    Captured stdout and stderr are appended to the 'text/plain' entry of the result.
    """
    try:
        code = content['code']
    except KeyError:
        return execute_reply_internal_error(
            'Malformed message: content object missing "code"', sys.exc_info()
        )

    try:
        nodes = parse_code_into_nodes(code)
    except SyntaxError:
        exc_type, exc_value, _ = sys.exc_info()
        # No traceback for parse errors: its frames would all belong to this worker.
        return execute_reply_error(exc_type, exc_value, None)

    result = None

    try:
        for node in nodes:
            result = node.execute()
    except ExecutionError as e:
        return execute_reply_error(*e.exc_info)

    if result is None:
        result = {}

    stdout = sys.stdout.getvalue()
    stderr = sys.stderr.getvalue()

    clearOutputs()

    output = result.pop('text/plain', '')

    if stdout:
        output += stdout

    if stderr:
        output += stderr

    output = output.rstrip()

    # Only add the output if it exists, or if there are no other mimetypes in the result.
    if output or not result:
        result['text/plain'] = output

    return execute_reply_ok(result)


spark_home = os.environ.get("SPARK_HOME", "")
os.environ["PYSPARK_PYTHON"] = os.environ.get("PYSPARK_PYTHON", sys.executable)

# add pyspark to sys.path

if "pyspark" not in sys.modules:
    spark_python = os.path.join(spark_home, "python")
    try:
        py4j = glob(os.path.join(spark_python, "lib", "py4j-*.zip"))[0]
    except IndexError:
        raise Exception(
            "Unable to find py4j in {}, your SPARK_HOME may not be configured correctly".format(
                spark_python
            )
        )
    sys.path[:0] = sys_path = [spark_python, py4j]
else:
    # already imported, no need to patch sys.path
    sys_path = None

# Imported only after sys.path has been patched above.
import kyuubi_util
spark = kyuubi_util.get_spark()
global_dict['spark'] = spark


def main():
    """Serve requests: one JSON message per stdin line, one JSON reply per stdout line.

    The real stdin/stdout/stderr are kept for the protocol; sys.stdin/stdout/stderr are
    replaced with in-memory buffers so output of user code is captured. The loop ends on
    EOF or on a message whose 'cmd' is 'exit_worker'.
    """
    sys_stdin = sys.stdin
    sys_stdout = sys.stdout
    sys_stderr = sys.stderr

    sys.stdin = io.StringIO()
    sys.stdout = UnicodeDecodingStringIO()
    sys.stderr = UnicodeDecodingStringIO()

    # Forward anything captured so far to the real stderr and start with clean buffers.
    # The original referenced `clearOutputs` without calling it, which was a no-op.
    captured = sys.stderr.getvalue()
    if captured:
        print(captured, file=sys_stderr)
    clearOutputs()

    try:
        while True:
            line = sys_stdin.readline()

            if line == '':
                # EOF: the parent closed our stdin.
                break
            elif line == '\n':
                continue

            try:
                content = json.loads(line)
            except ValueError:
                # Unparsable message: skipped without a reply, as before.
                continue

            if not isinstance(content, dict):
                continue

            # .get(): a message without 'cmd' must not kill the worker with a KeyError.
            if content.get('cmd') == 'exit_worker':
                break

            result = execute_request(content)
            print(result, file=sys_stdout)
            sys_stdout.flush()
            clearOutputs()
    finally:
        print("python worker exit", file=sys_stderr)
        sys.stdin = sys_stdin
        sys.stdout = sys_stdout
        sys.stderr = sys_stderr


if __name__ == '__main__':
    sys.exit(main())
a/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py b/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py new file mode 100644 index 000000000..4478060e9 --- /dev/null +++ b/externals/kyuubi-spark-sql-engine/src/main/resources/python/kyuubi_util.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#

import atexit
import os
import sys
import signal
import shlex
import shutil
import socket
import platform
import tempfile
import time
from subprocess import Popen, PIPE

from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
from pyspark.context import SparkContext
from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
from pyspark.sql import SparkSession


def connect_to_exist_gateway():
    """Connect to an already running Py4J gateway and return the gateway object.

    The file named by the PYTHON_GATEWAY_CONNECTION_INFO environment variable is read
    for the gateway port (an int) followed by the auth secret (a UTF-8 string).
    PYSPARK_PIN_THREAD (default "true") selects ClientServer over JavaGateway.

    Raises:
        SystemExit: if PYTHON_GATEWAY_CONNECTION_INFO is not set.
    """
    info_path = os.environ.get("PYTHON_GATEWAY_CONNECTION_INFO")
    if info_path is None:
        raise SystemExit("the python gateway connection information file not found!")

    with open(info_path, "rb") as stream:
        port = read_int(stream)
        secret = UTF8Deserializer().loads(stream)

    pin_thread = os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true"
    if pin_thread:
        java_side = JavaParameters(port=port, auth_token=secret, auto_convert=True)
        python_side = PythonParameters(port=0, eager_load=False)
        gateway = ClientServer(java_parameters=java_side, python_parameters=python_side)
    else:
        params = GatewayParameters(port=port, auth_token=secret, auto_convert=True)
        gateway = JavaGateway(gateway_parameters=params)

    # Import the classes used by PySpark
    pyspark_jvm_imports = (
        "org.apache.spark.SparkConf",
        "org.apache.spark.api.java.*",
        "org.apache.spark.api.python.*",
        "org.apache.spark.ml.python.*",
        "org.apache.spark.mllib.api.python.*",
        "org.apache.spark.resource.*",
        "org.apache.spark.sql.*",
        "org.apache.spark.sql.api.python.*",
        "org.apache.spark.sql.hive.*",
        "scala.Tuple2",
    )
    for jvm_name in pyspark_jvm_imports:
        java_import(gateway.jvm, jvm_name)

    return gateway


def _get_exist_spark_context(self, jconf):
    """Return a JavaSparkContext around the JVM's ``SparkContext.getOrCreate(jconf)``.

    get_spark() installs this function as ``SparkContext._initialize_context``.
    """
    jvm_context = self._jvm.org.apache.spark.SparkContext.getOrCreate(jconf)
    return self._jvm.JavaSparkContext(jvm_context)


def get_spark():
    """Patch SparkContext to attach to the existing gateway and return a SparkSession."""
    SparkContext._initialize_context = _get_exist_spark_context
    SparkContext._ensure_initialized(gateway=connect_to_exist_gateway())
    builder = SparkSession.builder.master('local').appName('test')
    return builder.getOrCreate()
+ */ + +package org.apache.kyuubi.engine.spark.operation + +import java.io.{BufferedReader, File, FilenameFilter, FileOutputStream, InputStreamReader, PrintWriter} +import java.lang.ProcessBuilder.Redirect +import java.nio.file.{Files, Path, Paths} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.apache.spark.api.python.KyuubiPythonGatewayServer +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType + +import org.apache.kyuubi.Logging +import org.apache.kyuubi.operation.ArrayFetchIterator +import org.apache.kyuubi.session.Session + +class ExecutePython( + session: Session, + override val statement: String, + worker: SessionPythonWorker) extends SparkOperation(session) { + + override protected def resultSchema: StructType = { + if (result == null || result.schema.isEmpty) { + new StructType().add("output", "string") + .add("status", "string") + .add("ename", "string") + .add("evalue", "string") + .add("traceback", "array") + } else { + result.schema + } + } + + override protected def runInternal(): Unit = { + val response = worker.runCode(statement) + val output = response.map(_.content.getOutput()).getOrElse("") + val status = response.map(_.content.status).getOrElse("UNKNOWN_STATUS") + val ename = response.map(_.content.getEname()).getOrElse("") + val evalue = response.map(_.content.getEvalue()).getOrElse("") + val traceback = response.map(_.content.getTraceback()).getOrElse(Array.empty) + iter = + new ArrayFetchIterator[Row](Array(Row(output, status, ename, evalue, Row(traceback: _*)))) + } + +} + +case class SessionPythonWorker( + errorReader: Thread, + pythonWorkerMonitor: Thread, + workerProcess: Process) { + private val stdin: PrintWriter = new PrintWriter(workerProcess.getOutputStream) + private val stdout: BufferedReader = + new BufferedReader(new 
InputStreamReader(workerProcess.getInputStream), 1) + + def runCode(code: String): Option[PythonReponse] = { + val input = ExecutePython.toJson(Map("code" -> code, "cmd" -> "run_code")) + // scalastyle:off println + stdin.println(input) + // scalastyle:on + stdin.flush() + Option(stdout.readLine()) + .map(ExecutePython.fromJson[PythonReponse](_)) + } + + def close(): Unit = { + val exitCmd = ExecutePython.toJson(Map("cmd" -> "exit_worker")) + // scalastyle:off println + stdin.println(exitCmd) + // scalastyle:on + stdin.flush() + stdin.close() + stdout.close() + errorReader.interrupt() + pythonWorkerMonitor.interrupt() + workerProcess.destroy() + } +} + +object ExecutePython extends Logging { + + // TODO:(fchen) get from conf + val pythonExec = + sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) + private val isPythonGatewayStart = new AtomicBoolean(false) + val kyuubiPythonPath = Files.createTempDirectory("") + def init(): Unit = { + if (!isPythonGatewayStart.get()) { + synchronized { + if (!isPythonGatewayStart.get()) { + KyuubiPythonGatewayServer.start() + writeTempPyFile(kyuubiPythonPath, "execute_python.py") + writeTempPyFile(kyuubiPythonPath, "kyuubi_util.py") + isPythonGatewayStart.set(true) + } + } + } + } + + def createSessionPythonWorker(): SessionPythonWorker = { + val builder = new ProcessBuilder(Seq( + pythonExec, + s"${ExecutePython.kyuubiPythonPath}/execute_python.py").asJava) + val env = builder.environment() + val pythonPath = sys.env.getOrElse("PYTHONPATH", "") + .split(File.pathSeparator) + .++(ExecutePython.kyuubiPythonPath.toString) + env.put("PYTHONPATH", pythonPath.mkString(File.pathSeparator)) + env.put("SPARK_HOME", sys.env.getOrElse("SPARK_HOME", defaultSparkHome())) + env.put("PYTHON_GATEWAY_CONNECTION_INFO", KyuubiPythonGatewayServer.CONNECTION_FILE_PATH) + logger.info( + s""" + |launch python worker command: ${builder.command().asScala.mkString(" ")} + |environment: + 
|${builder.environment().asScala.map(kv => kv._1 + "=" + kv._2).mkString("\n")} + |""".stripMargin) + builder.redirectError(Redirect.PIPE) + val process = builder.start() + SessionPythonWorker(startStderrSteamReader(process), startWatcher(process), process) + } + + // for test + def defaultSparkHome(): String = { + val homeDirFilter: FilenameFilter = (dir: File, name: String) => + dir.isDirectory && name.contains("spark-") && !name.contains("-engine") + // get from kyuubi-server/../externals/kyuubi-download/target + new File(getClass.getProtectionDomain.getCodeSource.getLocation.toURI).getPath + .split("kyuubi-spark-sql-engine").flatMap { cwd => + val candidates = Paths.get(cwd, "kyuubi-download", "target") + .toFile.listFiles(homeDirFilter) + if (candidates == null) None else candidates.map(_.toPath).headOption + }.find(Files.exists(_)).map(_.toAbsolutePath.toFile.getCanonicalPath) + .getOrElse { + throw new IllegalStateException("SPARK_HOME not found!") + } + } + + private def startStderrSteamReader(process: Process): Thread = { + val stderrThread = new Thread("process stderr thread") { + override def run() = { + val lines = scala.io.Source.fromInputStream(process.getErrorStream).getLines() + lines.foreach(logger.error) + } + } + stderrThread.setDaemon(true) + stderrThread.start() + stderrThread + } + + def startWatcher(process: Process): Thread = { + val processWatcherThread = new Thread("process watcher thread") { + override def run() = { + val exitCode = process.waitFor() + if (exitCode != 0) { + logger.error(f"Process has died with $exitCode") + } + } + } + processWatcherThread.setDaemon(true) + processWatcherThread.start() + processWatcherThread + } + + private def writeTempPyFile(pythonPath: Path, pyfile: String): File = { + val source = getClass.getClassLoader.getResourceAsStream(s"python/$pyfile") + + val file = new File(pythonPath.toFile, pyfile) + file.deleteOnExit() + + val sink = new FileOutputStream(file) + val buf = new Array[Byte](1024) + var n = 
source.read(buf) + + while (n > 0) { + sink.write(buf, 0, n) + n = source.read(buf) + } + source.close() + sink.close() + file + } + + val mapper = new ObjectMapper().registerModule(DefaultScalaModule) + def toJson[T](obj: T): String = { + mapper.writeValueAsString(obj) + } + def fromJson[T](json: String, clz: Class[T]): T = { + mapper.readValue(json, clz) + } + + def fromJson[T](json: String)(implicit m: Manifest[T]): T = { + mapper.readValue(json, m.runtimeClass).asInstanceOf[T] + } + +} + +case class PythonReponse( + msg_type: String, + content: PythonResponseContent) + +case class PythonResponseContent( + data: Map[String, String], + ename: String, + evalue: String, + traceback: Array[String], + status: String) { + def getOutput(): String = { + Option(data) + .map(_.getOrElse("text/plain", "")) + .getOrElse("") + } + def getEname(): String = { + Option(ename).getOrElse("") + } + def getEvalue(): String = { + Option(evalue).getOrElse("") + } + def getTraceback(): Array[String] = { + Option(traceback).getOrElse(Array.empty) + } +} diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala index 7a3f25eaa..4166d4902 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala @@ -40,12 +40,19 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n getConf.get(ENGINE_OPERATION_CONVERT_CATALOG_DATABASE_ENABLED) private val sessionToRepl = new ConcurrentHashMap[SessionHandle, KyuubiSparkILoop]().asScala + private val sessionToPythonProcess = + new ConcurrentHashMap[SessionHandle, SessionPythonWorker]().asScala def closeILoop(session: SessionHandle): 
Unit = { val maybeRepl = sessionToRepl.remove(session) maybeRepl.foreach(_.close()) } + def closePythonProcess(session: SessionHandle): Unit = { + val maybeProcess = sessionToPythonProcess.remove(session) + maybeProcess.foreach(_.close) + } + override def newExecuteStatementOperation( session: Session, statement: String, @@ -82,6 +89,12 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n case OperationLanguages.SCALA => val repl = sessionToRepl.getOrElseUpdate(session.handle, KyuubiSparkILoop(spark)) new ExecuteScala(session, repl, statement) + case OperationLanguages.PYTHON => + ExecutePython.init() + val worker = sessionToPythonProcess.getOrElseUpdate( + session.handle, + ExecutePython.createSessionPythonWorker()) + new ExecutePython(session, statement, worker) case OperationLanguages.UNKNOWN => spark.conf.unset(OPERATION_LANGUAGE.key) throw KyuubiSQLException(s"The operation language $lang" + diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/session/SparkSessionImpl.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/session/SparkSessionImpl.scala index eb4c84e24..5bf1ec084 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/session/SparkSessionImpl.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/session/SparkSessionImpl.scala @@ -97,5 +97,7 @@ class SparkSessionImpl( super.close() spark.sessionState.catalog.getTempViewNames().foreach(spark.catalog.uncacheTable) sessionManager.operationManager.asInstanceOf[SparkSQLOperationManager].closeILoop(handle) + sessionManager.operationManager.asInstanceOf[SparkSQLOperationManager].closePythonProcess( + handle) } } diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/api/python/KyuubiPythonGatewayServer.scala 
b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/api/python/KyuubiPythonGatewayServer.scala new file mode 100644 index 000000000..ff3ba089f --- /dev/null +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/api/python/KyuubiPythonGatewayServer.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 */

package org.apache.spark.api.python

import java.io.{DataOutputStream, File, FileOutputStream}
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Files

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging

/**
 * Starts a Py4JServer inside this JVM and publishes its port and secret in a file so that
 * a python process can connect to it.
 */
object KyuubiPythonGatewayServer extends Logging {

  // Path of the connection file, inside a temp directory created when this object is
  // initialized.
  val CONNECTION_FILE_PATH = Files.createTempDirectory("") + "/connection.info"

  /**
   * Start the gateway server and write the connection file.
   *
   * Calls System.exit(1) if the server fails to bind or the file cannot be moved into place.
   */
  def start(): Unit = {

    val sparkConf = new SparkConf()
    val gatewayServer: Py4JServer = new Py4JServer(sparkConf)

    gatewayServer.start()
    val boundPort: Int = gatewayServer.getListeningPort
    if (boundPort == -1) {
      logError(s"${gatewayServer.server.getClass} failed to bind; exiting")
      System.exit(1)
    } else {
      logDebug(s"Started PythonGatewayServer on port $boundPort")
    }

    // Communicate the connection information back to the python process by writing the
    // information in the requested file. This needs to match the read side in java_gateway.py.
    val connectionInfoPath = new File(CONNECTION_FILE_PATH)
    // Written to a temp file in the same directory first and renamed below, so a reader
    // never sees a partially written connection file.
    val tmpPath = Files.createTempFile(
      connectionInfoPath.getParentFile().toPath(),
      "connection",
      ".info").toFile()

    // Layout: port (int), secret length (int), secret bytes (UTF-8).
    // NOTE(review): `dos` is not closed if one of the writes throws.
    val dos = new DataOutputStream(new FileOutputStream(tmpPath))
    dos.writeInt(boundPort)

    val secretBytes = gatewayServer.secret.getBytes(UTF_8)
    dos.writeInt(secretBytes.length)
    dos.write(secretBytes, 0, secretBytes.length)
    dos.close()

    if (!tmpPath.renameTo(connectionInfoPath)) {
      logError(s"Unable to write connection information to $connectionInfoPath.")
      System.exit(1)
    }
  }
}
the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.engine.spark.operation

import java.io.PrintWriter
import java.nio.file.Files

import scala.sys.process._

import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
import org.apache.kyuubi.operation.HiveJDBCTestHelper

/**
 * Tests that run python statements through a JDBC connection after switching
 * `kyuubi.operation.language` to python, and check the `output` and `status` columns.
 */
trait PySparkTests extends WithSparkSQLEngine with HiveJDBCTestHelper {

  test("pyspark support") {
    val code = "print(1)"
    val output = "1"
    runPySparkTest(code, output)
  }

  test("pyspark support - multi-line") {
    // NOTE(review): whitespace inside the string literals of this trait may have been
    // collapsed in this copy of the file; confirm the indentation against the repository.
    val code =
      """
        |for i in [1, 2, 3]:
        | print(i)
        |""".stripMargin
    val output = "1\n2\n3"
    runPySparkTest(code, output)
  }

  test("pyspark support - call spark.sql") {
    val code =
      """
        |spark.sql("select 1").show()
        |""".stripMargin
    // NOTE(review): the column padding in this expected table may have been collapsed as
    // well; confirm against what `show()` really prints.
    val output =
      """|+---+
         || 1|
         |+---+
         || 1|
         |+---+""".stripMargin
    runPySparkTest(code, output)
  }

  // Runs `pyCode` as a python statement and asserts on the first row of the result set.
  private def runPySparkTest(
      pyCode: String,
      output: String): Unit = {
    checkPythonRuntimeAndVersion()
    withMultipleConnectionJdbcStatement()({ statement =>
      statement.executeQuery("SET kyuubi.operation.language=python")
      val resultSet = statement.executeQuery(pyCode)
      assert(resultSet.next())
      assert(resultSet.getString("output") === output)
      assert(resultSet.getString("status") === "ok")
    })
  }

  // Runs `python3` on a small script that prints "<major>.<minor>" and asserts on it.
  private def checkPythonRuntimeAndVersion(): Unit = {
    val code =
      """
        |import sys
        |print(".".join(map(str, sys.version_info[:2])))
        |""".stripMargin
    withTempPyFile(code) {
      pyfile: String =>
        // NOTE(review): the version is compared as a Double, so "3.10" is read as 3.1.
        val pythonVersion = s"python3 $pyfile".!!.toDouble
        assert(pythonVersion > 3.0, "required python version > 3.0")
    }
  }

  // Writes `code` to a temp .py file, passes its path to `op`, and always deletes the file.
  private def withTempPyFile(code: String)(op: (String) => Unit): Unit = {
    val tempPyFile = Files.createTempFile("", ".py").toFile
    try {
      // The anonymous-subclass body runs as part of construction: write, then close.
      new PrintWriter(tempPyFile) {
        write(code)
        close
      }
      op(tempPyFile.getPath)
    } finally {
      Files.delete(tempPyFile.toPath)
    }
  }
}
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/config/KyuubiConf.scala @@ -1867,9 +1867,10 @@ object KyuubiConf { object OperationLanguages extends Enumeration with Logging { type OperationLanguage = Value - val SQL, SCALA, UNKNOWN = Value + val PYTHON, SQL, SCALA, UNKNOWN = Value def apply(language: String): OperationLanguage = { language.toUpperCase(Locale.ROOT) match { + case "PYTHON" => PYTHON case "SQL" => SQL case "SCALA" => SCALA case other => diff --git a/kyuubi-hive-beeline/src/main/java/org/apache/hive/beeline/KyuubiCommands.java b/kyuubi-hive-beeline/src/main/java/org/apache/hive/beeline/KyuubiCommands.java index e3a4b295e..3ffe7aee7 100644 --- a/kyuubi-hive-beeline/src/main/java/org/apache/hive/beeline/KyuubiCommands.java +++ b/kyuubi-hive-beeline/src/main/java/org/apache/hive/beeline/KyuubiCommands.java @@ -21,7 +21,6 @@ import java.io.*; import java.sql.*; import java.util.*; import org.apache.hive.beeline.logs.KyuubiBeelineInPlaceUpdateStream; -import org.apache.hive.common.util.HiveStringUtils; import org.apache.kyuubi.jdbc.hive.JdbcConnectionParams; import org.apache.kyuubi.jdbc.hive.KyuubiStatement; import org.apache.kyuubi.jdbc.hive.Utils; @@ -45,7 +44,7 @@ public class KyuubiCommands extends Commands { /** Extract and clean up the first command in the input. 
*/ private String getFirstCmd(String cmd, int length) { - return cmd.substring(length).trim(); + return cmd.substring(length); } private String[] tokenizeCmd(String cmd) { @@ -97,7 +96,6 @@ public class KyuubiCommands extends Commands { } String[] cmds = lines.split(";"); for (String c : cmds) { - c = c.trim(); if (!executeInternal(c, false)) { return false; } @@ -261,10 +259,9 @@ public class KyuubiCommands extends Commands { beeLine.handleException(e); } - line = line.trim(); List cmdList = getCmdList(line, entireLineAsCommand); for (int i = 0; i < cmdList.size(); i++) { - String sql = cmdList.get(i).trim(); + String sql = cmdList.get(i); if (sql.length() != 0) { if (!executeInternal(sql, call)) { return false; @@ -511,7 +508,6 @@ public class KyuubiCommands extends Commands { @Override public String handleMultiLineCmd(String line) throws IOException { int[] startQuote = {-1}; - line = HiveStringUtils.removeComments(line, startQuote); Character mask = (System.getProperty("jline.terminal", "").equals("jline.UnsupportedTerminal")) ? null @@ -542,7 +538,6 @@ public class KyuubiCommands extends Commands { if (extra == null) { // it happens when using -f and the line of cmds does not end with ; break; } - extra = HiveStringUtils.removeComments(extra, startQuote); if (!extra.isEmpty()) { line += "\n" + extra; } @@ -554,13 +549,12 @@ public class KyuubiCommands extends Commands { // console. 
Used in handleMultiLineCmd method assumes line would never be null when this method is // called private boolean isMultiLine(String line) { - line = line.trim(); if (line.endsWith(beeLine.getOpts().getDelimiter()) || beeLine.isComment(line)) { return false; } // handles the case like line = show tables; --test comment List cmds = getCmdList(line, false); - return cmds.isEmpty() || !cmds.get(cmds.size() - 1).trim().startsWith("--"); + return cmds.isEmpty() || !cmds.get(cmds.size() - 1).startsWith("--"); } static class KyuubiLogRunnable implements Runnable { diff --git a/pom.xml b/pom.xml index 398468506..d709f278e 100644 --- a/pom.xml +++ b/pom.xml @@ -159,6 +159,7 @@ 4.1.73.Final 1.10.1 0.16.0 + 0.10.7 2.3.0 3.2.9.0 3.2.9 @@ -1621,6 +1622,11 @@ kudu-client ${kudu.version} + + net.sf.py4j + py4j + ${py4j.version} +