[KYUUBI #4392] [ARROW] Assign a new execution id for arrow-based result
### _Why are the changes needed?_

Assign a new execution id for the arrow-based result, so that we can track arrow-based queries on the Spark UI SQL tab.

```sql
set kyuubi.operation.result.format=arrow;
select 1;
```

Before this PR:

After this PR:

### _How was this patch tested?_

- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #4392 from cfmcgrady/arrow-execution-id-2.

Closes #4392

481118a4 [Fu Chen] enable ut
c90674ee [Fu Chen] address comment
6cc7af44 [Fu Chen] address comment
3f8a3ab8 [Fu Chen] fix ut
223a2469 [Fu Chen] add KyuubiSparkContextHelper
bb7b28f5 [Fu Chen] fix style
879a1502 [Fu Chen] unnecessary changes
a2b04f83 [Fu Chen] fix

Authored-by: Fu Chen <cfmcgrady@gmail.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
parent 3191fa22fc
commit f0acff315c
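For context on the mechanism: Spark only draws a query on the SQL tab when a `SparkListenerSQLExecutionStart`/`End` pair is posted, which `SQLExecution.withNewExecutionId` does around a body of work. The arrow path collects through a raw RDD action, which bypasses that bookkeeping — hence the explicit wrapping in this PR. A minimal standalone sketch of the idea (illustrative names, not the Kyuubi classes):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

object ExecutionIdSketch extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  val df = spark.sql("select 1")

  // Wrapping the collect in withNewExecutionId posts the
  // SparkListenerSQLExecutionStart/End events that back the SQL tab;
  // a bare RDD action against the plan would otherwise stay invisible there.
  val rows = SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
    df.queryExecution.executedPlan.executeCollect()
  }
  spark.stop()
}
```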
```diff
@@ -22,13 +22,14 @@ import java.util.concurrent.RejectedExecutionException
 import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.kyuubi.SparkDatasetHelper
 import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.{KyuubiSQLException, Logging}
 import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS
 import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
-import org.apache.kyuubi.operation.{ArrayFetchIterator, IterableFetchIterator, OperationState}
+import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationState}
 import org.apache.kyuubi.operation.log.OperationLog
 import org.apache.kyuubi.session.Session
 
```
```diff
@@ -62,49 +63,49 @@ class ExecuteStatement(
     OperationLog.removeCurrentOperationLog()
   }
 
-  private def executeStatement(): Unit = withLocalProperties {
+  protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
+    resultDF.toLocalIterator().asScala
+  }
+
+  protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
+    resultDF.collect()
+  }
+
+  protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
+    resultDF.take(maxRows)
+  }
+
+  protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
+    val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
+      .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
+    if (incrementalCollect) {
+      if (resultMaxRows > 0) {
+        warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
+      }
+      info("Execute in incremental collect mode")
+      new IterableFetchIterator[Any](new Iterable[Any] {
+        override def iterator: Iterator[Any] = incrementalCollectResult(resultDF)
+      })
+    } else {
+      val internalArray = if (resultMaxRows <= 0) {
+        info("Execute in full collect mode")
+        fullCollectResult(resultDF)
+      } else {
+        info(s"Execute with max result rows[$resultMaxRows]")
+        takeResult(resultDF, resultMaxRows)
+      }
+      new ArrayFetchIterator(internalArray)
+    }
+  }
+
+  protected def executeStatement(): Unit = withLocalProperties {
     try {
       setState(OperationState.RUNNING)
       info(diagnostics)
       Thread.currentThread().setContextClassLoader(spark.sharedState.jarClassLoader)
       addOperationListener()
       result = spark.sql(statement)
-
-      val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
-        .getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
-      iter = if (incrementalCollect) {
-        if (resultMaxRows > 0) {
-          warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
-        }
-        info("Execute in incremental collect mode")
-        def internalIterator(): Iterator[Any] = if (arrowEnabled) {
-          SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).toLocalIterator
-        } else {
-          result.toLocalIterator().asScala
-        }
-        new IterableFetchIterator[Any](new Iterable[Any] {
-          override def iterator: Iterator[Any] = internalIterator()
-        })
-      } else {
-        val internalArray = if (resultMaxRows <= 0) {
-          info("Execute in full collect mode")
-          if (arrowEnabled) {
-            SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).collect()
-          } else {
-            result.collect()
-          }
-        } else {
-          info(s"Execute with max result rows[$resultMaxRows]")
-          if (arrowEnabled) {
-            // this will introduce shuffle and hurt performance
-            val limitedResult = result.limit(resultMaxRows)
-            SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
-          } else {
-            result.take(resultMaxRows)
-          }
-        }
-        new ArrayFetchIterator(internalArray)
-      }
+      iter = collectAsIterator(result)
       setCompiledStateIfNeeded()
       setState(OperationState.FINISHED)
     } catch {
```
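The hunk above is the enabling refactor: `executeStatement` no longer branches on `arrowEnabled`; it always delegates to `collectAsIterator`, and the three `protected` hooks capture exactly what a subclass needs to swap. A reduced sketch of that shape (hypothetical class names; `SparkDatasetHelper` is the helper this PR already uses):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.kyuubi.SparkDatasetHelper

// Reduced sketch of the template-method split above (hypothetical names):
// the base class owns the full-vs-limited dispatch, a subclass overrides
// only how rows are materialized.
abstract class ResultCollector {
  protected def fullCollect(df: DataFrame): Array[_] = df.collect()
  protected def take(df: DataFrame, n: Int): Array[_] = df.take(n)

  def collect(df: DataFrame, maxRows: Int): Array[_] =
    if (maxRows <= 0) fullCollect(df) else take(df, maxRows)
}

class ArrowResultCollector extends ResultCollector {
  // swap the materialization, keep the dispatch
  override protected def fullCollect(df: DataFrame): Array[_] =
    SparkDatasetHelper.toArrowBatchRdd(df).collect()
}
```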
```diff
@@ -171,3 +172,40 @@ class ExecuteStatement(
       s"__kyuubi_operation_result_format__=$resultFormat",
       s"__kyuubi_operation_result_arrow_timestampAsString__=$timestampAsString")
   }
+
+class ArrowBasedExecuteStatement(
+    session: Session,
+    override val statement: String,
+    override val shouldRunAsync: Boolean,
+    queryTimeout: Long,
+    incrementalCollect: Boolean)
+  extends ExecuteStatement(session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {
+
+  override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).toLocalIterator
+  }
+
+  override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).collect()
+  }
+
+  override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
+    // this will introduce shuffle and hurt performance
+    val limitedResult = resultDF.limit(maxRows)
+    SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
+  }
+
+  /**
+   * assign a new execution id for arrow-based operation.
+   */
+  override protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
+    SQLExecution.withNewExecutionId(resultDF.queryExecution, Some("collectAsArrow")) {
+      resultDF.queryExecution.executedPlan.resetMetrics()
+      super.collectAsIterator(resultDF)
+    }
+  }
+
+  override protected def isArrowBasedOperation: Boolean = true
+
+  override val resultFormat = "arrow"
+}
```
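One caveat worth spelling out from the `takeResult` override: `take(n)` fetches partitions incrementally from the driver and stops early, while the arrow variant has to run `limit(n)` through the RDD of arrow batches, where the global limit typically forces an exchange to a single partition first — that is what the inherited `// this will introduce shuffle` comment refers to. An illustrative fragment (assumes a local session; not a benchmark):

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.kyuubi.SparkDatasetHelper

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val df: DataFrame = spark.range(1000000L).toDF("id")

// take(n): driver-side incremental scan, stops after enough rows.
val head: Array[Row] = df.take(10)

// limit(n) consumed as an RDD of arrow batches: no driver-side shortcut,
// so the plan generally shuffles down to one partition before collecting.
val batches: Array[Array[Byte]] = SparkDatasetHelper.toArrowBatchRdd(df.limit(10)).collect()
```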
```diff
@@ -245,7 +245,7 @@ abstract class SparkOperation(session: Session)
         case FETCH_FIRST => iter.fetchAbsolute(0);
       }
       resultRowSet =
-        if (arrowEnabled) {
+        if (isArrowBasedOperation) {
           if (iter.hasNext) {
             val taken = iter.next().asInstanceOf[Array[Byte]]
             RowSet.toTRowSet(taken, getProtocolVersion)
@@ -257,8 +257,7 @@ abstract class SparkOperation(session: Session)
           RowSet.toTRowSet(
             taken.toSeq.asInstanceOf[Seq[Row]],
             resultSchema,
-            getProtocolVersion,
-            timeZone)
+            getProtocolVersion)
         }
       resultRowSet.setStartRowOffset(iter.getPosition)
     } catch onError(cancel = true)
@@ -268,16 +267,9 @@ abstract class SparkOperation(session: Session)
 
   override def shouldRunAsync: Boolean = false
 
-  protected def arrowEnabled: Boolean = {
-    resultFormat.equalsIgnoreCase("arrow") &&
-      // TODO: (fchen) make all operation support arrow
-      getClass.getCanonicalName == classOf[ExecuteStatement].getCanonicalName
-  }
+  protected def isArrowBasedOperation: Boolean = false
 
-  protected def resultFormat: String = {
-    // TODO: respect the config of the operation ExecuteStatement, if it was set.
-    spark.conf.get("kyuubi.operation.result.format", "thrift")
-  }
+  protected def resultFormat: String = "thrift"
 
   protected def timestampAsString: Boolean = {
     spark.conf.get("kyuubi.operation.result.arrow.timestampAsString", "false").toBoolean
```
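A small design note on this hunk: the old `arrowEnabled` guard compared `getClass.getCanonicalName` against `ExecuteStatement` to restrict arrow to one operation type, which is brittle (any subclass silently fails the check). Introducing `ArrowBasedExecuteStatement` lets a plain overridable flag carry that information instead. Sketch of the pattern in isolation (hypothetical names):

```scala
// Capability declared by the type that has it, instead of inferred from
// the runtime class name (the pattern this hunk moves to).
trait Operation {
  protected def isArrowBasedOperation: Boolean = false
}
class ThriftOperation extends Operation
class ArrowOperation extends Operation {
  override protected def isArrowBasedOperation: Boolean = true
}
```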
```diff
@@ -82,7 +82,24 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
       case NoneMode =>
         val incrementalCollect = spark.conf.getOption(OPERATION_INCREMENTAL_COLLECT.key)
           .map(_.toBoolean).getOrElse(operationIncrementalCollectDefault)
-        new ExecuteStatement(session, statement, runAsync, queryTimeout, incrementalCollect)
+        // TODO: respect the config of the operation ExecuteStatement, if it was set.
+        val resultFormat = spark.conf.get("kyuubi.operation.result.format", "thrift")
+        resultFormat.toLowerCase match {
+          case "arrow" =>
+            new ArrowBasedExecuteStatement(
+              session,
+              statement,
+              runAsync,
+              queryTimeout,
+              incrementalCollect)
+          case _ =>
+            new ExecuteStatement(
+              session,
+              statement,
+              runAsync,
+              queryTimeout,
+              incrementalCollect)
+        }
       case mode =>
         new PlanOnlyStatement(session, statement, mode)
     }
```
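Because the result format is now resolved here, once, when the operation is created, a `SET` statement that changes the format only affects operations submitted after it; the `SET` operation itself still answers in the previously active format. That is exactly the expectation flip in the `SparkQueryTests` hunk near the end of this diff:

```sql
set kyuubi.operation.result.format=arrow;  -- this operation itself: thrift
select 1;                                  -- arrow
set kyuubi.operation.result.format=thrift; -- this operation itself: arrow
select 1;                                  -- thrift
```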
```diff
@@ -18,7 +18,6 @@
 package org.apache.kyuubi.engine.spark.schema
 
 import java.nio.ByteBuffer
-import java.time.ZoneId
 
 import scala.collection.JavaConverters._
 
@@ -61,16 +60,15 @@ object RowSet {
   def toTRowSet(
       rows: Seq[Row],
       schema: StructType,
-      protocolVersion: TProtocolVersion,
-      timeZone: ZoneId): TRowSet = {
+      protocolVersion: TProtocolVersion): TRowSet = {
     if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
-      toRowBasedSet(rows, schema, timeZone)
+      toRowBasedSet(rows, schema)
     } else {
-      toColumnBasedSet(rows, schema, timeZone)
+      toColumnBasedSet(rows, schema)
     }
   }
 
-  def toRowBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
+  def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRows = new java.util.ArrayList[TRow](rowSize)
     var i = 0
@@ -80,7 +78,7 @@ object RowSet {
       var j = 0
       val columnSize = row.length
       while (j < columnSize) {
-        val columnValue = toTColumnValue(j, row, schema, timeZone)
+        val columnValue = toTColumnValue(j, row, schema)
         tRow.addToColVals(columnValue)
         j += 1
       }
@@ -90,21 +88,21 @@ object RowSet {
     new TRowSet(0, tRows)
   }
 
-  def toColumnBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
+  def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
     val rowSize = rows.length
     val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
     var i = 0
     val columnSize = schema.length
     while (i < columnSize) {
       val field = schema(i)
-      val tColumn = toTColumn(rows, i, field.dataType, timeZone)
+      val tColumn = toTColumn(rows, i, field.dataType)
       tRowSet.addToColumns(tColumn)
       i += 1
     }
     tRowSet
   }
 
-  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType, timeZone: ZoneId): TColumn = {
+  private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = {
     val nulls = new java.util.BitSet()
     typ match {
       case BooleanType =>
@@ -186,8 +184,7 @@ object RowSet {
   private def toTColumnValue(
       ordinal: Int,
       row: Row,
-      types: StructType,
-      timeZone: ZoneId): TColumnValue = {
+      types: StructType): TColumnValue = {
     types(ordinal).dataType match {
       case BooleanType =>
         val boolValue = new TBoolValue
```
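The `RowSet` hunk is a mechanical cleanup alongside the core change: the `ZoneId` parameter threaded through `toTRowSet`, `toRowBasedSet`, `toColumnBasedSet`, `toTColumn`, and `toTColumnValue` is dropped, so every call site loses one argument. A sketch of a call after the change (assuming `rows: Seq[Row]`, `schema: StructType`, and a `protocolVersion` are in scope, as in `SparkOperation`):

```scala
// After this change the conversion no longer takes a ZoneId (sketch):
val tRowSet = RowSet.toTRowSet(rows, schema, protocolVersion)
```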
```diff
@@ -19,8 +19,14 @@ package org.apache.kyuubi.engine.spark.operation
 
 import java.sql.Statement
 
+import org.apache.spark.KyuubiSparkContextHelper
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
+
 import org.apache.kyuubi.config.KyuubiConf
-import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
+import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
+import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
 import org.apache.kyuubi.operation.SparkDataTypeTests
 
 class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
```
```diff
@@ -85,6 +91,34 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     }
   }
 
+  test("assign a new execution id for arrow-based result") {
+    var plan: LogicalPlan = null
+
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        plan = qe.analyzed
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+    }
+    withJdbcStatement() { statement =>
+      // since all the new sessions have their own listener bus, we should register the listener
+      // in the current session.
+      SparkSQLEngine.currentEngine.get
+        .backendService
+        .sessionManager
+        .allSessions()
+        .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
+
+      val result = statement.executeQuery("select 1 as c1")
+      assert(result.next())
+      assert(result.getInt("c1") == 1)
+    }
+
+    KyuubiSparkContextHelper.waitListenerBus(spark)
+    spark.listenerManager.unregister(listener)
+    assert(plan.isInstanceOf[Project])
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
```
```diff
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.schema
 import java.nio.ByteBuffer
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
-import java.time.{Instant, LocalDate, ZoneId}
+import java.time.{Instant, LocalDate}
 
 import scala.collection.JavaConverters._
 
@@ -96,10 +96,9 @@ class RowSetSuite extends KyuubiFunSuite {
     .add("q", "timestamp")
 
   private val rows: Seq[Row] = (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(17)(null)))
-  private val zoneId: ZoneId = ZoneId.systemDefault()
 
   test("column based set") {
-    val tRowSet = RowSet.toColumnBasedSet(rows, schema, zoneId)
+    val tRowSet = RowSet.toColumnBasedSet(rows, schema)
     assert(tRowSet.getColumns.size() === schema.size)
     assert(tRowSet.getRowsSize === 0)
 
@@ -204,7 +203,7 @@ class RowSetSuite extends KyuubiFunSuite {
   }
 
   test("row based set") {
-    val tRowSet = RowSet.toRowBasedSet(rows, schema, zoneId)
+    val tRowSet = RowSet.toRowBasedSet(rows, schema)
    assert(tRowSet.getColumnCount === 0)
    assert(tRowSet.getRowsSize === rows.size)
    val iter = tRowSet.getRowsIterator
@@ -250,7 +249,7 @@ class RowSetSuite extends KyuubiFunSuite {
 
   test("to row set") {
     TProtocolVersion.values().foreach { proto =>
-      val set = RowSet.toTRowSet(rows, schema, proto, zoneId)
+      val set = RowSet.toTRowSet(rows, schema, proto)
       if (proto.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
         assert(!set.isSetColumns, proto.toString)
         assert(set.isSetRows, proto.toString)
```
externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/KyuubiSparkContextHelper.scala (new file, 30 lines)

```diff
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * A place to invoke non-public APIs of [[SparkContext]], for test only.
+ */
+object KyuubiSparkContextHelper {
+
+  def waitListenerBus(spark: SparkSession): Unit = {
+    spark.sparkContext.listenerBus.waitUntilEmpty()
+  }
+}
```
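Two details of this helper are easy to miss: it is declared in the `org.apache.spark` package because `SparkContext.listenerBus` is `private[spark]` (the "non-public APIs" the scaladoc mentions), and the test needs it because `QueryExecutionListener` callbacks are delivered asynchronously on the listener bus, so assertions on listener-captured state must wait for the bus to drain. Usage fragment mirroring the suite above (`listener` and `plan` as defined there):

```scala
// After the JDBC query returns, the onSuccess callback may still be in
// flight on the listener bus; drain it before asserting on `plan`.
KyuubiSparkContextHelper.waitListenerBus(spark)
spark.listenerManager.unregister(listener)
assert(plan.isInstanceOf[Project])
```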
```diff
@@ -433,13 +433,13 @@ trait SparkQueryTests extends SparkDataTypeTests with HiveJDBCTestHelper
       expectedFormat = "thrift")
     checkStatusAndResultSetFormatHint(
       sql = "set kyuubi.operation.result.format=arrow",
-      expectedFormat = "arrow")
+      expectedFormat = "thrift")
     checkStatusAndResultSetFormatHint(
       sql = "SELECT 1",
       expectedFormat = "arrow")
     checkStatusAndResultSetFormatHint(
       sql = "set kyuubi.operation.result.format=thrift",
-      expectedFormat = "thrift")
+      expectedFormat = "arrow")
     checkStatusAndResultSetFormatHint(
       sql = "set kyuubi.operation.result.format",
       expectedFormat = "thrift")
```