[KYUUBI #3820] [Subtask] [PySpark] Skip missing MagicNode and code improvements

### _Why are the changes needed?_

To close #3820.

To improve PySpark script support:
1. skip the missing `MagicNode` implementation, since Jupyter and sparkmagic are not yet supported
2. add the missing `execute_reply_internal_error` method (see the sketch after this list)
3. fix the `clearOutputs` call before the main loop (it was referenced without parentheses, so it never ran)
4. indent lines and remove unused imports to conform to Python code style
5. check the Python major version and exit on Python 2.x
6. fix the `PythonReponse` name typo (now `PythonResponse`)
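
A minimal, self-contained sketch of the reply helpers around item 2, mirroring the worker protocol in the diff below. The middle of `execute_reply` is elided in the diff, so the `dict(content, status=status)` envelope shown here is an assumption modeled on Livy-style fake shells:

```python
import json


def execute_reply(status, content):
    # Serialize a reply line for the engine side to read; merging the
    # status into the content dict is an assumption, since the diff
    # below elides the middle of this function.
    msg = {
        'msg_type': 'execute_reply',
        'content': dict(content, status=status),
    }
    return json.dumps(msg)


def execute_reply_internal_error(message, exc_info=None):
    # The method this PR adds: report an internal error with an empty traceback.
    return execute_reply('error', {
        'ename': 'InternalError',
        'evalue': message,
        'traceback': [],
    })


# Example: the JSON line the engine would read back on an internal error.
print(execute_reply_internal_error("worker failed to initialize"))
```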

### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #3819 from bowenliang123/imrove-pyspark.

Closes #3820

473b9952 [liangbowen] add return type to `connect_to_existed_gateway`
66927821 [liangbowen] remove unnecessary comments for magic code
21e1d7a2 [liangbowen] move pyspark path preparing to the top of exeuction_python
9751e094 [liangbowen] revert to use SparkSessionBuilder for session creation
c4f3ef55 [liangbowen] use `SparkSession._create_shell_session()` to create spark session
c2f65630 [liangbowen] delay importing kyuubi_util
5ed893cc [liangbowen] adding Exception to except, to prevent PEP 8: E203
029361a9 [liangbowen] ast module adaptation for >=3.8
00c75fda [liangbowen] remove legacy code for importing unicode
9f56a4f4 [liangbowen] add todo
1da708ed [liangbowen] fix typo for PythonResponse, and minor declaration improvement
910c62fb [liangbowen] remove MagicNode implementation since Jupyter and sparkmagic are not yet supported
5f15c257 [liangbowen] exit on python 2.x
86ff7d06 [liangbowen] ident lines to conform python code style
5634c5e0 [liangbowen] rename get_spark to get_spark_session, and optimize unused imports in kyuubi_util.py
9d3e1d0c [liangbowen] add missing MagicNode implementation
0ade1dbe [liangbowen] add missing execute_reply_internal_error method
aee205a5 [liangbowen] import cStringIO for fix package resolving problem
acdd4b16 [liangbowen] fix by calling clearOutputs before loop

Authored-by: liangbowen <liangbowen@gf.com.cn>
Signed-off-by: Cheng Pan <chengpan@apache.org>
liangbowen 2022-11-22 10:16:52 +08:00 committed by Cheng Pan
parent 47e1cfdf08
commit 3fae1845e7
3 changed files with 83 additions and 85 deletions


@@ -15,28 +15,57 @@
 # limitations under the License.
 #
-from glob import glob
 import ast
-import sys
 import io
 import json
-import traceback
-import re
 import os
+import re
+import sys
+import traceback
+from glob import glob
+
+if sys.version_info[0] < 3:
+    sys.exit('Python < 3 is unsupported.')
+
+spark_home = os.environ.get("SPARK_HOME", "")
+os.environ["PYSPARK_PYTHON"] = os.environ.get("PYSPARK_PYTHON", sys.executable)
+
+# add pyspark to sys.path
+if "pyspark" not in sys.modules:
+    spark_python = os.path.join(spark_home, "python")
+    try:
+        py4j = glob(os.path.join(spark_python, "lib", "py4j-*.zip"))[0]
+    except IndexError:
+        raise Exception(
+            "Unable to find py4j in {}, your SPARK_HOME may not be configured correctly".format(
+                spark_python
+            )
+        )
+    sys.path[:0] = sys_path = [spark_python, py4j]
+else:
+    # already imported, no need to patch sys.path
+    sys_path = None
+
+# import kyuubi_util after preparing sys.path
+import kyuubi_util
+
 # ast api is changed after python 3.8, see https://github.com/ipython/ipython/pull/11593
-if sys.version_info > (3,8):
-  from ast import Module
-else :
-  # mock the new API, ignore second argument
-  # see https://github.com/ipython/ipython/issues/11590
-  from ast import Module as OriginalModule
-  Module = lambda nodelist, type_ignores: OriginalModule(nodelist)
+if sys.version_info >= (3, 8):
+    from ast import Module
+else:
+    # mock the new API, ignore second argument
+    # see https://github.com/ipython/ipython/issues/11590
+    from ast import Module as OriginalModule
+    Module = lambda nodelist, type_ignores: OriginalModule(nodelist)
 
 TOP_FRAME_REGEX = re.compile(r'\s*File "<stdin>".*in <module>')
 
 global_dict = {}
 
 class NormalNode(object):
     def __init__(self, code):
         self.code = compile(code, '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1)
@@ -54,21 +83,24 @@ class NormalNode(object):
                 mod = ast.Interactive([node])
                 code = compile(mod, '<stdin>', 'single')
                 exec(code, global_dict)
-        except:
+        except Exception:
             # We don't need to log the exception because we're just executing user
             # code and passing the error along.
             raise ExecutionError(sys.exc_info())
 
 class ExecutionError(Exception):
     def __init__(self, exc_info):
         self.exc_info = exc_info
 
 class UnicodeDecodingStringIO(io.StringIO):
     def write(self, s):
         if isinstance(s, bytes):
             s = s.decode("utf-8")
         super(UnicodeDecodingStringIO, self).write(s)
 
 def clearOutputs():
     sys.stdout.close()
     sys.stderr.close()
@@ -81,16 +113,6 @@ def parse_code_into_nodes(code):
     try:
         nodes.append(NormalNode(code))
     except SyntaxError:
-        # It's possible we hit a syntax error because of a magic command. Split the code groups
-        # of 'normal code', and code that starts with a '%'. possibly magic code
-        # lines, and see if any of the lines
-        # Remove lines until we find a node that parses, then check if the next line is a magic
-        # line
-        # .
         # Split the code into chunks of normal code, and possibly magic code, which starts with
         # a '%'.
         normal = []
         chunks = []
         for i, line in enumerate(code.rstrip().split('\n')):
@@ -108,13 +130,15 @@ def parse_code_into_nodes(code):
         # Convert the chunks into AST nodes. Let exceptions propagate.
         for chunk in chunks:
-            if chunk.startswith('%'):
-                nodes.append(MagicNode(chunk))
-            else:
-                nodes.append(NormalNode(chunk))
+            # TODO: look back here when Jupyter and sparkmagic are supported
+            # if chunk.startswith('%'):
+            #     nodes.append(MagicNode(chunk))
+            nodes.append(NormalNode(chunk))
 
     return nodes
 
 def execute_reply(status, content):
     msg = {
         'msg_type': 'execute_reply',
@@ -125,17 +149,15 @@ def execute_reply(status, content):
     }
     return json.dumps(msg)
 
 def execute_reply_ok(data):
     return execute_reply("ok", {
         "data": data,
     })
 
 def execute_reply_error(exc_type, exc_value, tb):
-    # LOG.error('execute_reply', exc_info=True)
-    if sys.version >= '3':
-        formatted_tb = traceback.format_exception(exc_type, exc_value, tb, chain=False)
-    else:
-        formatted_tb = traceback.format_exception(exc_type, exc_value, tb)
+    formatted_tb = traceback.format_exception(exc_type, exc_value, tb, chain=False)
     for i in range(len(formatted_tb)):
         if TOP_FRAME_REGEX.match(formatted_tb[i]):
             formatted_tb = formatted_tb[:1] + formatted_tb[i + 1:]
@@ -147,6 +169,15 @@ def execute_reply_error(exc_type, exc_value, tb):
         'traceback': formatted_tb,
     })
 
+def execute_reply_internal_error(message, exc_info=None):
+    return execute_reply('error', {
+        'ename': 'InternalError',
+        'evalue': message,
+        'traceback': [],
+    })
+
 def execute_request(content):
     try:
         code = content['code']
@@ -193,49 +224,25 @@ def execute_request(content):
     return execute_reply_ok(result)
 
-# import findspark
-# findspark.init()
-spark_home = os.environ.get("SPARK_HOME", "")
-os.environ["PYSPARK_PYTHON"] = os.environ.get("PYSPARK_PYTHON", sys.executable)
+# get or create spark session
+spark_session = kyuubi_util.get_spark_session()
+global_dict['spark'] = spark_session
 
-# add pyspark to sys.path
-if "pyspark" not in sys.modules:
-    spark_python = os.path.join(spark_home, "python")
-    try:
-        py4j = glob(os.path.join(spark_python, "lib", "py4j-*.zip"))[0]
-    except IndexError:
-        raise Exception(
-            "Unable to find py4j in {}, your SPARK_HOME may not be configured correctly".format(
-                spark_python
-            )
-        )
-    sys.path[:0] = sys_path = [spark_python, py4j]
-else:
-    # already imported, no need to patch sys.path
-    sys_path = None
-
-import kyuubi_util
-spark = kyuubi_util.get_spark()
-global_dict['spark'] = spark
 
 def main():
     sys_stdin = sys.stdin
     sys_stdout = sys.stdout
     sys_stderr = sys.stderr
-    if sys.version >= '3':
-        sys.stdin = io.StringIO()
-    else:
-        sys.stdin = cStringIO.StringIO()
+    sys.stdin = io.StringIO()
     sys.stdout = UnicodeDecodingStringIO()
     sys.stderr = UnicodeDecodingStringIO()
 
     stderr = sys.stderr.getvalue()
     print(stderr, file=sys_stderr)
-    clearOutputs
+    clearOutputs()
 
     try:
         while True:
@@ -249,7 +256,6 @@ def main():
             try:
                 content = json.loads(line)
             except ValueError:
-                # LOG.error('failed to parse message', exc_info=True)
                 continue
 
             if content['cmd'] == 'exit_worker':
@@ -265,5 +271,6 @@ def main():
         sys.stdout = sys_stdout
         sys.stderr = sys_stderr
 
+
 if __name__ == '__main__':
     sys.exit(main())


@@ -15,29 +15,19 @@
 # limitations under the License.
 #
-import atexit
 import os
-import sys
-import signal
-import shlex
-import shutil
-import socket
-import platform
-import tempfile
-import time
-from subprocess import Popen, PIPE
-from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
 from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
+from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
 from pyspark.context import SparkContext
-from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
+from pyspark.serializers import read_int, UTF8Deserializer
 from pyspark.sql import SparkSession
 
-def connect_to_exist_gateway():
+def connect_to_exist_gateway() -> "JavaGateway":
     conn_info_file = os.environ.get("PYTHON_GATEWAY_CONNECTION_INFO")
     if conn_info_file is None:
-      raise SystemExit("the python gateway connection information file not found!")
+        raise SystemExit("the python gateway connection information file not found!")
     with open(conn_info_file, "rb") as info:
         gateway_port = read_int(info)
         gateway_secret = UTF8Deserializer().loads(info)
@@ -72,16 +62,17 @@ def connect_to_exist_gateway():
     return gateway
 
 def _get_exist_spark_context(self, jconf):
     """
     Initialize SparkContext in function to allow subclass specific initialization
     """
     return self._jvm.JavaSparkContext(self._jvm.org.apache.spark.SparkContext.getOrCreate(jconf))
 
-def get_spark():
+def get_spark_session() -> "SparkSession":
+    SparkContext._initialize_context = _get_exist_spark_context
     gateway = connect_to_exist_gateway()
     SparkContext._ensure_initialized(gateway=gateway)
     spark = SparkSession.builder.master('local').appName('test').getOrCreate()
     return spark


@@ -98,14 +98,14 @@ case class SessionPythonWorker(
   private val stdout: BufferedReader =
     new BufferedReader(new InputStreamReader(workerProcess.getInputStream), 1)
 
-  def runCode(code: String): Option[PythonReponse] = {
+  def runCode(code: String): Option[PythonResponse] = {
     val input = ExecutePython.toJson(Map("code" -> code, "cmd" -> "run_code"))
     // scalastyle:off println
     stdin.println(input)
     // scalastyle:on
     stdin.flush()
     Option(stdout.readLine())
-      .map(ExecutePython.fromJson[PythonReponse](_))
+      .map(ExecutePython.fromJson[PythonResponse](_))
   }
 
   def close(): Unit = {
@@ -125,7 +125,7 @@ case class SessionPythonWorker(
 object ExecutePython extends Logging {
 
   private val isPythonGatewayStart = new AtomicBoolean(false)
-  val kyuubiPythonPath = Files.createTempDirectory("")
+  private val kyuubiPythonPath = Files.createTempDirectory("")
   def init(): Unit = {
     if (!isPythonGatewayStart.get()) {
       synchronized {
@@ -186,7 +186,7 @@ object ExecutePython extends Logging {
   private def startStderrSteamReader(process: Process): Thread = {
     val stderrThread = new Thread("process stderr thread") {
-      override def run() = {
+      override def run(): Unit = {
         val lines = scala.io.Source.fromInputStream(process.getErrorStream).getLines()
         lines.foreach(logger.error)
       }
@@ -198,7 +198,7 @@ object ExecutePython extends Logging {
   def startWatcher(process: Process): Thread = {
     val processWatcherThread = new Thread("process watcher thread") {
-      override def run() = {
+      override def run(): Unit = {
         val exitCode = process.waitFor()
         if (exitCode != 0) {
           logger.error(f"Process has died with $exitCode")
@@ -229,7 +229,7 @@ object ExecutePython extends Logging {
     file
   }
 
-  val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
+  val mapper: ObjectMapper = new ObjectMapper().registerModule(DefaultScalaModule)
   def toJson[T](obj: T): String = {
     mapper.writeValueAsString(obj)
   }
@@ -243,7 +243,7 @@ object ExecutePython extends Logging {
 }
 
-case class PythonReponse(
+case class PythonResponse(
     msg_type: String,
     content: PythonResponseContent)