[KYUUBI #3820] [Subtask] [PySpark] Skip missing MagicNode and code improvements
### _Why are the changes needed?_

To close #3820. To improve PySpark script support:

1. Skip the missing MagicNode implementation, since Jupyter and sparkmagic are not yet supported.
2. Add the missing `execute_reply_internal_error` method.
3. Fix output capture by calling `clearOutputs` before the main loop.
4. Re-indent lines and remove unused imports to conform to Python code style.
5. Check the Python major version, and exit on Python 2.x.
6. Fix the name typo of `PythonResponse`.

A sketch of the worker's JSON wire protocol, which several of these changes touch, follows the commit metadata below.

### _How was this patch tested?_

- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [x] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #3819 from bowenliang123/imrove-pyspark.

Closes #3820

473b9952 [liangbowen] add return type to `connect_to_existed_gateway`
66927821 [liangbowen] remove unnecessary comments for magic code
21e1d7a2 [liangbowen] move pyspark path preparing to the top of exeuction_python
9751e094 [liangbowen] revert to use SparkSessionBuilder for session creation
c4f3ef55 [liangbowen] use `SparkSession._create_shell_session()` to create spark session
c2f65630 [liangbowen] delay importing kyuubi_util
5ed893cc [liangbowen] adding Exception to except, to prevent PEP 8: E203
029361a9 [liangbowen] ast module adaptation for >=3.8
00c75fda [liangbowen] remove legacy code for importing unicode
9f56a4f4 [liangbowen] add todo
1da708ed [liangbowen] fix typo for PythonResponse, and minor declaration improvement
910c62fb [liangbowen] remove MagicNode implementation since Jupyter and sparkmagic are not yet supported
5f15c257 [liangbowen] exit on python 2.x
86ff7d06 [liangbowen] ident lines to conform python code style
5634c5e0 [liangbowen] rename get_spark to get_spark_session, and optimize unused imports in kyuubi_util.py
9d3e1d0c [liangbowen] add missing MagicNode implementation
0ade1dbe [liangbowen] add missing execute_reply_internal_error method
aee205a5 [liangbowen] import cStringIO for fix package resolving problem
acdd4b16 [liangbowen] fix by calling clearOutputs before loop

Authored-by: liangbowen <liangbowen@gf.com.cn>
Signed-off-by: Cheng Pan <chengpan@apache.org>
parent 47e1cfdf08 · commit 3fae1845e7
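For orientation before the hunks (which cover the Python worker script, kyuubi_util.py, and the Scala `SessionPythonWorker`/`ExecutePython` side): the worker speaks a line-oriented JSON protocol over stdin/stdout. Below is a minimal sketch of the driver's half of that exchange, mirroring `SessionPythonWorker.runCode` further down; the standalone launch and script path are hypothetical, since in Kyuubi the engine starts the worker itself.

```python
import json
import subprocess

# Hypothetical standalone launch; in Kyuubi the engine owns this process.
worker = subprocess.Popen(
    ["python3", "execute_python.py"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)

def run_code(code: str) -> dict:
    # One JSON object per line in, one 'execute_reply' JSON line back,
    # the same framing SessionPythonWorker.runCode uses.
    worker.stdin.write(json.dumps({"cmd": "run_code", "code": code}) + "\n")
    worker.stdin.flush()
    return json.loads(worker.stdout.readline())

reply = run_code("1 + 1")
assert reply["msg_type"] == "execute_reply"

# 'exit_worker' asks the loop in main() to terminate.
worker.stdin.write(json.dumps({"cmd": "exit_worker"}) + "\n")
worker.stdin.flush()
```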
@@ -15,28 +15,57 @@
 # limitations under the License.
 #
 
-from glob import glob
 import ast
-import sys
 import io
 import json
-import traceback
-import re
+import os
+import re
+import sys
+import traceback
+from glob import glob
+
+if sys.version_info[0] < 3:
+    sys.exit('Python < 3 is unsupported.')
+
+spark_home = os.environ.get("SPARK_HOME", "")
+os.environ["PYSPARK_PYTHON"] = os.environ.get("PYSPARK_PYTHON", sys.executable)
+
+# add pyspark to sys.path
+
+if "pyspark" not in sys.modules:
+    spark_python = os.path.join(spark_home, "python")
+    try:
+        py4j = glob(os.path.join(spark_python, "lib", "py4j-*.zip"))[0]
+    except IndexError:
+        raise Exception(
+            "Unable to find py4j in {}, your SPARK_HOME may not be configured correctly".format(
+                spark_python
+            )
+        )
+    sys.path[:0] = sys_path = [spark_python, py4j]
+else:
+    # already imported, no need to patch sys.path
+    sys_path = None
+
+# import kyuubi_util after preparing sys.path
+import kyuubi_util
 
 # ast api is changed after python 3.8, see https://github.com/ipython/ipython/pull/11593
-if sys.version_info > (3,8):
+if sys.version_info >= (3, 8):
     from ast import Module
-else :
+else:
     # mock the new API, ignore second argument
     # see https://github.com/ipython/ipython/issues/11590
     from ast import Module as OriginalModule
+
     Module = lambda nodelist, type_ignores: OriginalModule(nodelist)
 
 TOP_FRAME_REGEX = re.compile(r'\s*File "<stdin>".*in <module>')
 
 global_dict = {}
 
 
 class NormalNode(object):
     def __init__(self, code):
         self.code = compile(code, '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1)
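The version-gated `Module` shim above exists because Python 3.8 added a `type_ignores` argument to `ast.Module` (see the linked IPython issue). A self-contained sketch of the compile path it enables, in the spirit of `NormalNode`:

```python
import ast
import sys

# The same shim as above: ast.Module takes (body, type_ignores) on 3.8+,
# only (body) before that.
if sys.version_info >= (3, 8):
    from ast import Module
else:
    from ast import Module as OriginalModule
    Module = lambda nodelist, type_ignores: OriginalModule(nodelist)

tree = ast.parse("x = 40\nx + 2")
*to_exec, last = tree.body
scope = {}

# Run everything but the last statement silently...
exec(compile(Module(to_exec, []), "<stdin>", "exec"), scope)
# ...then compile the last one in 'single' mode so its value is echoed,
# the same trick the worker uses with ast.Interactive.
exec(compile(ast.Interactive([last]), "<stdin>", "single"), scope)  # prints 42
```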
@@ -54,21 +83,24 @@ class NormalNode(object):
                 mod = ast.Interactive([node])
                 code = compile(mod, '<stdin>', 'single')
                 exec(code, global_dict)
-        except:
+        except Exception:
             # We don't need to log the exception because we're just executing user
             # code and passing the error along.
             raise ExecutionError(sys.exc_info())
 
 
 class ExecutionError(Exception):
     def __init__(self, exc_info):
         self.exc_info = exc_info
 
 
 class UnicodeDecodingStringIO(io.StringIO):
     def write(self, s):
         if isinstance(s, bytes):
             s = s.decode("utf-8")
         super(UnicodeDecodingStringIO, self).write(s)
 
 
 def clearOutputs():
     sys.stdout.close()
     sys.stderr.close()
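`UnicodeDecodingStringIO` keeps the captured stdout/stderr uniformly `str` even when user code writes raw bytes. A quick self-contained check of that contract (the class is repeated here so the snippet runs on its own):

```python
import io

class UnicodeDecodingStringIO(io.StringIO):  # repeated from the hunk above
    def write(self, s):
        if isinstance(s, bytes):
            s = s.decode("utf-8")
        super(UnicodeDecodingStringIO, self).write(s)

buf = UnicodeDecodingStringIO()
buf.write("str output, ")
buf.write("bytes output".encode("utf-8"))  # bytes are decoded on the way in
assert buf.getvalue() == "str output, bytes output"
```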
@@ -81,16 +113,6 @@ def parse_code_into_nodes(code):
     try:
         nodes.append(NormalNode(code))
     except SyntaxError:
-        # It's possible we hit a syntax error because of a magic command. Split the code groups
-        # of 'normal code', and code that starts with a '%'. possibly magic code
-        # lines, and see if any of the lines
-        # Remove lines until we find a node that parses, then check if the next line is a magic
-        # line
-        # .
-
         # Split the code into chunks of normal code, and possibly magic code, which starts with
         # a '%'.
-
         normal = []
         chunks = []
         for i, line in enumerate(code.rstrip().split('\n')):
@@ -108,13 +130,15 @@ def parse_code_into_nodes(code):
 
     # Convert the chunks into AST nodes. Let exceptions propagate.
     for chunk in chunks:
-        if chunk.startswith('%'):
-            nodes.append(MagicNode(chunk))
-        else:
-            nodes.append(NormalNode(chunk))
+        # TODO: look back here when Jupyter and sparkmagic are supported
+        # if chunk.startswith('%'):
+        #     nodes.append(MagicNode(chunk))
+
+        nodes.append(NormalNode(chunk))
 
     return nodes
 
 
 def execute_reply(status, content):
     msg = {
         'msg_type': 'execute_reply',
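With `MagicNode` gone, a `%`-prefixed chunk now reaches `NormalNode`, whose `compile(..., ast.PyCF_ONLY_AST, ...)` rejects it up front, so the `SyntaxError` propagates to the caller as the comment above says. A small standalone illustration (the two-line class is repeated so this runs on its own):

```python
import ast

class NormalNode(object):  # repeated from the first hunk
    def __init__(self, code):
        self.code = compile(code, '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1)

try:
    NormalNode("%matplotlib inline")
except SyntaxError as e:
    print("magic syntax now surfaces as a plain SyntaxError:", e)
```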
@@ -125,17 +149,15 @@ def execute_reply(status, content):
     }
     return json.dumps(msg)
 
 
 def execute_reply_ok(data):
     return execute_reply("ok", {
         "data": data,
     })
 
 
 def execute_reply_error(exc_type, exc_value, tb):
     # LOG.error('execute_reply', exc_info=True)
-    if sys.version >= '3':
-        formatted_tb = traceback.format_exception(exc_type, exc_value, tb, chain=False)
-    else:
-        formatted_tb = traceback.format_exception(exc_type, exc_value, tb)
+    formatted_tb = traceback.format_exception(exc_type, exc_value, tb, chain=False)
     for i in range(len(formatted_tb)):
         if TOP_FRAME_REGEX.match(formatted_tb[i]):
             formatted_tb = formatted_tb[:1] + formatted_tb[i + 1:]
@@ -147,6 +169,15 @@ def execute_reply_error(exc_type, exc_value, tb):
         'traceback': formatted_tb,
     })
 
 
+def execute_reply_internal_error(message, exc_info=None):
+    return execute_reply('error', {
+        'ename': 'InternalError',
+        'evalue': message,
+        'traceback': [],
+    })
+
+
 def execute_request(content):
     try:
         code = content['code']
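The new `execute_reply_internal_error` rides on `execute_reply`, whose body is truncated in the hunk above; the sketch below assumes the conventional shape where `status` is folded into `content`, so treat the exact layout as an assumption:

```python
import json

# Illustrative re-implementation: execute_reply's real body is not fully
# visible above, so folding status into content like this is an assumption.
def execute_reply(status, content):
    return json.dumps({
        'msg_type': 'execute_reply',
        'content': dict(content, status=status),
    })

def execute_reply_internal_error(message, exc_info=None):
    return execute_reply('error', {
        'ename': 'InternalError',
        'evalue': message,
        'traceback': [],
    })

print(execute_reply_internal_error("worker crashed"))
# {"msg_type": "execute_reply", "content": {"ename": "InternalError",
#  "evalue": "worker crashed", "traceback": [], "status": "error"}}
```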
@@ -193,49 +224,25 @@ def execute_request(content):
 
     return execute_reply_ok(result)
 
-# import findspark
-# findspark.init()
-
-spark_home = os.environ.get("SPARK_HOME", "")
-os.environ["PYSPARK_PYTHON"] = os.environ.get("PYSPARK_PYTHON", sys.executable)
-
-# add pyspark to sys.path
-
-if "pyspark" not in sys.modules:
-    spark_python = os.path.join(spark_home, "python")
-    try:
-        py4j = glob(os.path.join(spark_python, "lib", "py4j-*.zip"))[0]
-    except IndexError:
-        raise Exception(
-            "Unable to find py4j in {}, your SPARK_HOME may not be configured correctly".format(
-                spark_python
-            )
-        )
-    sys.path[:0] = sys_path = [spark_python, py4j]
-else:
-    # already imported, no need to patch sys.path
-    sys_path = None
-
-import kyuubi_util
-spark = kyuubi_util.get_spark()
-global_dict['spark'] = spark
+# get or create spark session
+spark_session = kyuubi_util.get_spark_session()
+global_dict['spark'] = spark_session
 
 
 def main():
     sys_stdin = sys.stdin
     sys_stdout = sys.stdout
     sys_stderr = sys.stderr
 
-    if sys.version >= '3':
-        sys.stdin = io.StringIO()
-    else:
-        sys.stdin = cStringIO.StringIO()
-
+    sys.stdin = io.StringIO()
     sys.stdout = UnicodeDecodingStringIO()
     sys.stderr = UnicodeDecodingStringIO()
 
     stderr = sys.stderr.getvalue()
     print(stderr, file=sys_stderr)
-    clearOutputs
+    clearOutputs()
 
     try:
 
         while True:
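The `clearOutputs()` fix matters because stdout/stderr have already been swapped for in-memory buffers at this point: anything printed while the Spark session spins up would otherwise be prepended to the first request's output. A rough sketch of the idea, using plain `StringIO` swaps rather than the script's close-and-recreate helper:

```python
import io
import sys

real_stdout = sys.stdout
sys.stdout = io.StringIO()

print("startup noise, e.g. Spark banner lines")

# Without a reset here, the noise above would leak into the first reply.
sys.stdout = io.StringIO()  # stands in for clearOutputs() plus re-creation

print("output of the first user statement")
captured = sys.stdout.getvalue()

sys.stdout = real_stdout
print(repr(captured))  # only the user's output survives
```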
@@ -249,7 +256,6 @@ def main():
             try:
                 content = json.loads(line)
             except ValueError:
-                # LOG.error('failed to parse message', exc_info=True)
                 continue
 
             if content['cmd'] == 'exit_worker':
@@ -265,5 +271,6 @@ def main():
     sys.stdout = sys_stdout
     sys.stderr = sys_stderr
 
+
 if __name__ == '__main__':
     sys.exit(main())
@@ -15,29 +15,19 @@
 # limitations under the License.
 #
 
-import atexit
 import os
-import sys
-import signal
-import shlex
-import shutil
-import socket
-import platform
-import tempfile
-import time
-from subprocess import Popen, PIPE
 
-from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
 from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
+from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
 from pyspark.context import SparkContext
-from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
+from pyspark.serializers import read_int, UTF8Deserializer
 from pyspark.sql import SparkSession
 
 
-def connect_to_exist_gateway():
+def connect_to_exist_gateway() -> "JavaGateway":
     conn_info_file = os.environ.get("PYTHON_GATEWAY_CONNECTION_INFO")
     if conn_info_file is None:
         raise SystemExit("the python gateway connection information file not found!")
     with open(conn_info_file, "rb") as info:
         gateway_port = read_int(info)
         gateway_secret = UTF8Deserializer().loads(info)
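`connect_to_exist_gateway` reads the py4j port and secret from the file named by `PYTHON_GATEWAY_CONNECTION_INFO`, using pyspark's big-endian, length-prefixed framing. A sketch that fakes such a file and reads it back the same way (path and values are made up):

```python
import struct

from pyspark.serializers import read_int, UTF8Deserializer

def write_int(f, v):
    # big-endian 4-byte int, matching pyspark's read_int
    f.write(struct.pack("!i", v))

with open("/tmp/fake_gateway_info", "wb") as f:
    write_int(f, 25333)            # gateway port
    secret = b"0123abcd"
    write_int(f, len(secret))      # UTF8Deserializer expects a length prefix
    f.write(secret)                # gateway secret

with open("/tmp/fake_gateway_info", "rb") as f:
    port = read_int(f)                    # 25333
    secret = UTF8Deserializer().loads(f)  # "0123abcd"
```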
@@ -72,16 +62,17 @@ def connect_to_exist_gateway():
 
     return gateway
 
 
 def _get_exist_spark_context(self, jconf):
     """
     Initialize SparkContext in function to allow subclass specific initialization
     """
     return self._jvm.JavaSparkContext(self._jvm.org.apache.spark.SparkContext.getOrCreate(jconf))
 
-def get_spark():
+
+def get_spark_session() -> "SparkSession":
     SparkContext._initialize_context = _get_exist_spark_context
     gateway = connect_to_exist_gateway()
     SparkContext._ensure_initialized(gateway=gateway)
     spark = SparkSession.builder.master('local').appName('test').getOrCreate()
     return spark
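The monkey-patched `_initialize_context` is what makes `get_spark_session` attach to the JVM's already-running `SparkContext` (via `getOrCreate`) instead of launching a new one, so the `master('local')`/`appName('test')` builder settings are effectively placeholders. Usage inside the worker is then just the two lines below; note this only works where `PYTHON_GATEWAY_CONNECTION_INFO` points at a live engine gateway:

```python
import kyuubi_util

# Outside a running Kyuubi engine this raises SystemExit, since the
# gateway connection info file will be missing.
spark = kyuubi_util.get_spark_session()
print(spark.sparkContext.applicationId)  # the engine's application, not a new one
```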
@@ -98,14 +98,14 @@ case class SessionPythonWorker(
   private val stdout: BufferedReader =
     new BufferedReader(new InputStreamReader(workerProcess.getInputStream), 1)
 
-  def runCode(code: String): Option[PythonReponse] = {
+  def runCode(code: String): Option[PythonResponse] = {
     val input = ExecutePython.toJson(Map("code" -> code, "cmd" -> "run_code"))
     // scalastyle:off println
     stdin.println(input)
     // scalastyle:on
     stdin.flush()
     Option(stdout.readLine())
-      .map(ExecutePython.fromJson[PythonReponse](_))
+      .map(ExecutePython.fromJson[PythonResponse](_))
   }
 
   def close(): Unit = {
@@ -125,7 +125,7 @@ case class SessionPythonWorker(
 object ExecutePython extends Logging {
 
   private val isPythonGatewayStart = new AtomicBoolean(false)
-  val kyuubiPythonPath = Files.createTempDirectory("")
+  private val kyuubiPythonPath = Files.createTempDirectory("")
   def init(): Unit = {
     if (!isPythonGatewayStart.get()) {
       synchronized {
@@ -186,7 +186,7 @@ object ExecutePython extends Logging {
 
   private def startStderrSteamReader(process: Process): Thread = {
     val stderrThread = new Thread("process stderr thread") {
-      override def run() = {
+      override def run(): Unit = {
         val lines = scala.io.Source.fromInputStream(process.getErrorStream).getLines()
         lines.foreach(logger.error)
       }
@@ -198,7 +198,7 @@ object ExecutePython extends Logging {
 
   def startWatcher(process: Process): Thread = {
     val processWatcherThread = new Thread("process watcher thread") {
-      override def run() = {
+      override def run(): Unit = {
         val exitCode = process.waitFor()
         if (exitCode != 0) {
           logger.error(f"Process has died with $exitCode")
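The two `run(): Unit` fixes touch a standard pair of helper threads: one drains the child's stderr into the log, one waits on the process and reports a non-zero exit. The same pattern in Python, as a runnable illustration of what these threads do:

```python
import logging
import subprocess
import threading

logging.basicConfig(level=logging.ERROR)
proc = subprocess.Popen(
    ["python3", "-c", "import sys; print('oops', file=sys.stderr); sys.exit(3)"],
    stderr=subprocess.PIPE,
    text=True,
)

def drain_stderr():
    # counterpart of startStderrSteamReader: forward stderr lines to the log
    for line in proc.stderr:
        logging.error(line.rstrip())

def watch():
    # counterpart of startWatcher: report a non-zero exit code
    exit_code = proc.wait()
    if exit_code != 0:
        logging.error("Process has died with %d", exit_code)

threading.Thread(target=drain_stderr, daemon=True).start()
watcher = threading.Thread(target=watch, daemon=True)
watcher.start()
watcher.join()
```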
@@ -229,7 +229,7 @@ object ExecutePython extends Logging {
     file
   }
 
-  val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
+  val mapper: ObjectMapper = new ObjectMapper().registerModule(DefaultScalaModule)
   def toJson[T](obj: T): String = {
     mapper.writeValueAsString(obj)
   }
@@ -243,7 +243,7 @@ object ExecutePython extends Logging {
 
 }
 
-case class PythonReponse(
+case class PythonResponse(
     msg_type: String,
     content: PythonResponseContent)
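After the rename, Jackson (via `DefaultScalaModule`) binds each reply line onto `PythonResponse` by field name, which is why the snake_case `msg_type` must match the Python side exactly. A reply produced by `execute_reply_ok` looks roughly like this; the `data` payload shape is an assumption, since `execute_request` is truncated above:

```python
import json

reply_line = json.dumps({
    "msg_type": "execute_reply",        # -> PythonResponse.msg_type
    "content": {                        # -> PythonResponse.content
        "status": "ok",
        "data": {"text/plain": "2"},    # assumed mimetype-keyed payload
    },
})
print(reply_line)
```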