[KYUUBI #1499] Introduce DataFrameHolder for cli result fetching

<!--
Thanks for sending a pull request!

Here are some tips for you:
  1. If this is your first time, please read our contributor guidelines: https://kyuubi.readthedocs.io/en/latest/community/contributions.html
  2. If the PR is related to an issue in https://github.com/apache/incubator-kyuubi/issues, add '[KYUUBI #XXXX]' in your PR title, e.g., '[KYUUBI #XXXX] Your PR title ...'.
  3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][KYUUBI #XXXX] Your PR title ...'.
-->

### _Why are the changes needed?_
<!--
Please clarify why the changes are needed. For instance,
  1. If you add a feature, you can talk about the use case of it.
  2. If you fix a bug, you can clarify why it is a bug.
-->
Replace the mutable ArrayBuffer of results with DataFrameHolder

### _How was this patch tested?_
- [X] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [X] [Run test](https://kyuubi.readthedocs.io/en/latest/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #1507 from yaooqinn/1499.

Closes #1499

7fd2b0ef [Kent Yao] fix it
6278009b [Kent Yao] root
de3c601d [Kent Yao] fi
358d7a68 [Kent Yao] refine
8373b9b3 [Kent Yao] refine
ab95f7dc [Kent Yao] refine
86d90b80 [Kent Yao] loader
a07117c4 [Kent Yao] nit
90a7dd4f [Kent Yao] [KYUUBI #1499] Introduce DataFrameHolder for cli result fetching
8d97b51d [Kent Yao] [KYUUBI #1499] Introduce DataFrameHolder for cli result fetching

Authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Kent Yao <yao@apache.org>
This commit is contained in:
Kent Yao 2021-12-17 10:45:05 +08:00
parent 6e10eec4ca
commit b2e679d5c7
No known key found for this signature in database
GPG Key ID: F7051850A0AF904D
12 changed files with 153 additions and 60 deletions

View File

@ -26,6 +26,9 @@ import org.apache.kyuubi.Utils
object KyuubiSparkUtil {
final val SPARK_SCHEDULER_POOL_KEY = "spark.scheduler.pool"
final val SPARK_SQL_EXECUTION_ID_KEY = "spark.sql.execution.id"
def globalSparkContext: SparkContext = SparkSession.active.sparkContext
def engineId: String =

View File

@ -32,8 +32,8 @@ import org.apache.kyuubi.session.Session
* Support executing Scala Script with or without common Spark APIs, only support running in sync
* mode, as an operation may [[Incomplete]] and wait for others to make [[Success]].
*
* [[KyuubiSparkILoop.results]] is exposed as a [[org.apache.spark.sql.DataFrame]] to users in repl
* to transfer result they wanted to client side.
* [[KyuubiSparkILoop.result]] is exposed as a [[org.apache.spark.sql.DataFrame]] holder to users
* in the repl so they can transfer the results they want to the client side.
*
* @param session parent session
* @param repl Scala Interpreter
@ -56,20 +56,26 @@ class ExecuteScala(
}
}
override protected def runInternal(): Unit = {
override protected def runInternal(): Unit = withLocalProperties {
try {
OperationLog.setCurrentOperationLog(operationLog)
spark.sparkContext.setJobGroup(statementId, statement)
Thread.currentThread().setContextClassLoader(spark.sharedState.jarClassLoader)
val legacyOutput = repl.getOutput
if (legacyOutput.nonEmpty) {
warn(s"Clearing legacy output from last interpreting:\n $legacyOutput")
}
repl.interpretWithRedirectOutError(statement) match {
case Success =>
iter =
if (repl.results.nonEmpty) {
result = repl.results.remove(0)
iter = {
result = repl.getResult(statementId)
if (result != null) {
new ArrayFetchIterator[Row](result.collect())
} else {
// TODO (#1498): Maybe we shall pass the output through operation log
// but some clients may not support operation log
new ArrayFetchIterator[Row](Array(Row(repl.getOutput)))
}
}
case Error =>
throw KyuubiSQLException(s"Interpret error:\n$statement\n ${repl.getOutput}")
case Incomplete =>
@ -78,7 +84,7 @@ class ExecuteScala(
} catch {
onError(cancel = true)
} finally {
spark.sparkContext.clearJobGroup()
repl.clearResult(statementId)
}
}
}

View File

@ -26,8 +26,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.kyuubi.{KyuubiSQLException, Logging}
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
import org.apache.kyuubi.engine.spark.events.{EventLoggingService, SparkStatementEvent}
import org.apache.kyuubi.operation.{ArrayFetchIterator, IterableFetchIterator, OperationState, OperationType}
import org.apache.kyuubi.operation.OperationState.OperationState
@ -43,15 +42,6 @@ class ExecuteStatement(
incrementalCollect: Boolean)
extends SparkOperation(OperationType.EXECUTE_STATEMENT, session) with Logging {
import org.apache.kyuubi.KyuubiSparkUtils._
private val forceCancel =
session.sessionManager.getConf.get(KyuubiConf.OPERATION_FORCE_CANCEL)
private val schedulerPool =
spark.conf.getOption(KyuubiConf.OPERATION_SCHEDULER_POOL.key).orElse(
session.sessionManager.getConf.get(KyuubiConf.OPERATION_SCHEDULER_POOL))
private var statementTimeoutCleaner: Option[ScheduledExecutorService] = None
private val operationLog: OperationLog = OperationLog.createOperationLog(session, getHandle)
@ -91,7 +81,7 @@ class ExecuteStatement(
private def executeStatement(): Unit = withLocalProperties {
try {
setState(OperationState.RUNNING)
info(KyuubiSparkUtil.diagnostics)
info(diagnostics)
Thread.currentThread().setContextClassLoader(spark.sharedState.jarClassLoader)
// TODO: Make it configurable
spark.sparkContext.addSparkListener(operationListener)
@ -143,26 +133,6 @@ class ExecuteStatement(
}
}
private def withLocalProperties[T](f: => T): T = {
try {
spark.sparkContext.setJobGroup(statementId, statement, forceCancel)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, session.user)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, statementId)
schedulerPool match {
case Some(pool) =>
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, pool)
case None =>
}
f
} finally {
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, null)
spark.sparkContext.clearJobGroup()
}
}
private def addTimeoutMonitor(): Unit = {
if (queryTimeout > 0) {
val timeoutExecutor =

View File

@ -26,6 +26,9 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.StructType
import org.apache.kyuubi.{KyuubiSQLException, Utils}
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.config.KyuubiReservedKeys.{KYUUBI_SESSION_USER_KEY, KYUUBI_STATEMENT_ID_KEY}
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil.SPARK_SCHEDULER_POOL_KEY
import org.apache.kyuubi.engine.spark.operation.SparkOperation.TIMEZONE_KEY
import org.apache.kyuubi.engine.spark.schema.RowSet
import org.apache.kyuubi.engine.spark.schema.SchemaHelper
@ -72,7 +75,7 @@ abstract class SparkOperation(opType: OperationType, session: Session)
* @param input the SQL pattern to convert
* @return the equivalent Java regular expression of the pattern
*/
def toJavaRegex(input: String): String = {
protected def toJavaRegex(input: String): String = {
val res =
if (StringUtils.isEmpty(input) || input == "*") {
"%"
@ -85,6 +88,33 @@ abstract class SparkOperation(opType: OperationType, session: Session)
.replaceAll("([^\\\\])_", "$1.").replaceAll("\\\\_", "_").replaceAll("^_", ".")
}
// Whether cancelling the statement's job group should force-kill its running Spark jobs
// (from kyuubi.operation.force.cancel).
private val forceCancel =
session.sessionManager.getConf.get(KyuubiConf.OPERATION_FORCE_CANCEL)
// Scheduler pool for this operation's jobs: the session-level Spark conf value wins,
// falling back to the server-side Kyuubi conf default.
private val schedulerPool =
spark.conf.getOption(KyuubiConf.OPERATION_SCHEDULER_POOL.key).orElse(
session.sessionManager.getConf.get(KyuubiConf.OPERATION_SCHEDULER_POOL))
/**
 * Runs `f` with this operation's Spark job group and thread-local properties installed,
 * and guarantees they are cleared afterwards.
 *
 * Sets the job group (statementId / statement / forceCancel), the Kyuubi session user and
 * statement id local properties, and, when configured, the scheduler pool — then restores
 * all of them in the `finally` block so the pooled worker thread is left clean for the
 * next operation. The order of the cleanup calls mirrors the setup; do not reorder.
 *
 * @param f the action to run with the local properties set
 * @return the result of evaluating `f`
 */
protected def withLocalProperties[T](f: => T): T = {
try {
spark.sparkContext.setJobGroup(statementId, statement, forceCancel)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, session.user)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, statementId)
schedulerPool match {
case Some(pool) =>
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, pool)
case None =>
}
f
} finally {
// Setting a local property to null removes it from the thread's property map.
spark.sparkContext.setLocalProperty(SPARK_SCHEDULER_POOL_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_SESSION_USER_KEY, null)
spark.sparkContext.setLocalProperty(KYUUBI_STATEMENT_ID_KEY, null)
spark.sparkContext.clearJobGroup()
}
}
protected def onError(cancel: Boolean = false): PartialFunction[Throwable, Unit] = {
// We should use Throwable instead of Exception since `java.lang.NoClassDefFoundError`
// could be thrown.

View File

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.engine.spark.repl
import java.util.HashMap
import org.apache.spark.kyuubi.SparkContextHelper
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
 * Holds [[DataFrame]] results keyed by Kyuubi statement id, so that user code executed in
 * the Scala repl can hand a result set back to the operation serving the client.
 *
 * NOTE(review): backed by a plain java.util.HashMap, which is not thread-safe — this assumes
 * set/get/unset for different statements never race on this map; confirm with the engine's
 * statement execution model, or consider ConcurrentHashMap.
 *
 * @since 1.5.0
 */
class DataFrameHolder(spark: SparkSession) {
// statementId -> DataFrame registered by user code via set()
private val results = new HashMap[String, DataFrame]()
// Resolves the statement id of the calling thread from the SparkContext local property;
// null if no statement is currently active on this thread.
private def currentId: String = {
SparkContextHelper.getCurrentStatementId(spark.sparkContext)
}
/**
 * Registers `df` as the result of the statement currently running on this thread.
 *
 * @param df a [[DataFrame]] used to pass the result to clients
 */
def set(df: DataFrame): Unit = {
results.put(currentId, df)
}
/**
 * Returns the [[DataFrame]] registered for `statementId`, or null if none was set.
 *
 * @param statementId Kyuubi statement id
 */
def get(statementId: String): DataFrame = {
results.get(statementId)
}
/**
 * Removes any result registered for `statementId`; invoked when the operation finishes
 * so entries do not accumulate across statements.
 *
 * @param statementId Kyuubi statement id
 */
def unset(statementId: String): Unit = {
results.remove(statementId)
}
}

View File

@ -19,13 +19,13 @@ package org.apache.kyuubi.engine.spark.repl
import java.io.{ByteArrayOutputStream, File}
import scala.collection.mutable.ArrayBuffer
import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter.IR
import scala.tools.nsc.interpreter.JPrintWriter
import org.apache.spark.SparkContext
import org.apache.spark.repl.{Main, SparkILoop}
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.util.MutableURLClassLoader
private[spark] case class KyuubiSparkILoop private (
@ -33,8 +33,7 @@ private[spark] case class KyuubiSparkILoop private (
output: ByteArrayOutputStream)
extends SparkILoop(None, new JPrintWriter(output)) {
// TODO: this is a little hacky
val results = new ArrayBuffer[Dataset[Row]]()
val result = new DataFrameHolder(spark)
private def initialize(): Unit = {
settings = new Settings
@ -51,7 +50,7 @@ private[spark] case class KyuubiSparkILoop private (
try {
this.compilerClasspath
this.ensureClassLoader()
var classLoader = Thread.currentThread().getContextClassLoader
var classLoader: ClassLoader = Thread.currentThread().getContextClassLoader
while (classLoader != null) {
classLoader match {
case loader: MutableURLClassLoader =>
@ -66,6 +65,9 @@ private[spark] case class KyuubiSparkILoop private (
classLoader = classLoader.getParent
}
}
this.addUrlsToClassPath(
classOf[DataFrameHolder].getProtectionDomain.getCodeSource.getLocation)
} finally {
Thread.currentThread().setContextClassLoader(currentClassLoader)
}
@ -86,14 +88,17 @@ private[spark] case class KyuubiSparkILoop private (
// for feeding results to client, e.g. beeline
this.bind(
"results",
"scala.collection.mutable.ArrayBuffer[" +
"org.apache.spark.sql.Dataset[org.apache.spark.sql.Row]]",
results)
"result",
classOf[DataFrameHolder].getCanonicalName,
result)
}
}
def interpretWithRedirectOutError(statement: String): scala.tools.nsc.interpreter.IR.Result = {
def getResult(statementId: String): DataFrame = result.get(statementId)
def clearResult(statementId: String): Unit = result.unset(statementId)
def interpretWithRedirectOutError(statement: String): IR.Result = {
Console.withOut(output) {
Console.withErr(output) {
this.interpret(statement)

View File

@ -25,7 +25,7 @@ import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.udf
import org.apache.kyuubi.{KYUUBI_VERSION, Utils}
import org.apache.kyuubi.KyuubiSparkUtils.KYUUBI_SESSION_USER_KEY
import org.apache.kyuubi.config.KyuubiReservedKeys.KYUUBI_SESSION_USER_KEY
object KDFRegistry {

View File

@ -23,8 +23,9 @@ import org.apache.spark.scheduler._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
import org.apache.kyuubi.KyuubiSparkUtils._
import org.apache.kyuubi.Logging
import org.apache.kyuubi.config.KyuubiReservedKeys.KYUUBI_STATEMENT_ID_KEY
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil.SPARK_SQL_EXECUTION_ID_KEY
import org.apache.kyuubi.operation.Operation
import org.apache.kyuubi.operation.log.OperationLog

View File

@ -26,6 +26,7 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.scheduler.local.LocalSchedulerBackend
import org.apache.kyuubi.Logging
import org.apache.kyuubi.config.KyuubiReservedKeys.KYUUBI_STATEMENT_ID_KEY
import org.apache.kyuubi.engine.spark.events.KyuubiSparkEvent
import org.apache.kyuubi.events.EventLogger
@ -52,6 +53,24 @@ object SparkContextHelper extends Logging {
}
}
/**
 * Reads a thread-local property from the given [[SparkContext]].
 *
 * @param sc the SparkContext to query
 * @param propertyKey the property name, as set via `SparkContext.setLocalProperty`
 * @return the property value bound to the current thread, or null when unset
 */
private def getLocalProperty(sc: SparkContext, propertyKey: String): String =
sc.getLocalProperty(propertyKey)
/**
 * Looks up the Kyuubi statement id bound to the current thread via the
 * `KYUUBI_STATEMENT_ID_KEY` local property.
 *
 * @param sc an active SparkContext
 * @return the current statement id, or null when no statement is active on this thread
 */
def getCurrentStatementId(sc: SparkContext): String =
getLocalProperty(sc, KYUUBI_STATEMENT_ID_KEY)
}
/**

View File

@ -26,10 +26,10 @@ import scala.annotation.tailrec
import org.apache.spark.SparkException
import org.apache.spark.scheduler._
import org.apache.kyuubi.KyuubiSparkUtils.KYUUBI_STATEMENT_ID_KEY
import org.apache.kyuubi.Logging
import org.apache.kyuubi.Utils.stringifyException
import org.apache.kyuubi.config.KyuubiConf._
import org.apache.kyuubi.config.KyuubiReservedKeys.KYUUBI_STATEMENT_ID_KEY
import org.apache.kyuubi.engine.spark.events.{EngineEventsStore, SessionEvent, SparkStatementEvent}
import org.apache.kyuubi.service.{Serverable, ServiceState}

View File

@ -15,11 +15,9 @@
* limitations under the License.
*/
package org.apache.kyuubi
package org.apache.kyuubi.config
object KyuubiSparkUtils {
object KyuubiReservedKeys {
final val KYUUBI_SESSION_USER_KEY = "kyuubi.session.user"
final val KYUUBI_STATEMENT_ID_KEY = "kyuubi.statement.id"
final val SPARK_SCHEDULER_POOL_KEY = "spark.scheduler.pool"
final val SPARK_SQL_EXECUTION_ID_KEY = "spark.sql.execution.id"
}

View File

@ -444,7 +444,7 @@ trait SparkQueryTests extends HiveJDBCTestHelper {
assert(rs2.getString(1).endsWith("5"))
// continue
val rs3 = statement.executeQuery("results += df")
val rs3 = statement.executeQuery("result.set(df)")
for (i <- Range(0, 10, 2)) {
assert(rs3.next)
assert(rs3.getInt(1) === i)
@ -480,7 +480,7 @@ trait SparkQueryTests extends HiveJDBCTestHelper {
assert(rs5.getString(1) startsWith "df: org.apache.spark.sql.DataFrame")
// re-assign
val rs6 = statement.executeQuery("results += df")
val rs6 = statement.executeQuery("result.set(df)")
for (i <- Range(0, 10, 2)) {
assert(rs6.next)
assert(rs6.getInt(2) === i + 1)