[KYUUBI #4392] [ARROW] Assign a new execution id for arrow-based result

### _Why are the changes needed?_

Assign a new execution id to arrow-based results so that arrow-based queries can be tracked on the SQL tab of the Spark UI. For example:

```sql
set kyuubi.operation.result.format=arrow;
select 1;
```

Before this PR:

![Screenshot 2023-02-21 17 23 08](https://user-images.githubusercontent.com/8537877/220303920-fbaf978b-ead7-4708-9094-bcc84e8fb47c.png)

![Screenshot 2023-02-21 17 23 19](https://user-images.githubusercontent.com/8537877/220303966-cb8dfeae-cd10-4c4f-add6-2650619fc5f9.png)

After this PR:
![Screenshot 2023-02-22 10 21 53](https://user-images.githubusercontent.com/8537877/220504608-f67a5f70-8c64-4e3b-89c2-c2ea54676217.png)

![Screenshot 2023-02-21 17 20 50](https://user-images.githubusercontent.com/8537877/220304021-9b845f44-96c3-41f2-a48a-a428f8c4823f.png)
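
Under the hood, the change routes the arrow collect through `SQLExecution.withNewExecutionId` (see `ArrowBasedExecuteStatement.collectAsIterator` in the diff below); that is what registers the resulting jobs under a fresh execution id so they show up on the SQL tab. Below is a minimal standalone sketch of the pattern against Spark 3.x, not the actual Kyuubi code: `queryExecution.toRdd` stands in for Kyuubi's `SparkDatasetHelper.toArrowBatchRdd`, which lives in a Spark-internal package and is not callable from user code.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution

// Sketch only: mirrors the pattern used by ArrowBasedExecuteStatement.collectAsIterator in this PR.
object NewExecutionIdSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("arrow-execution-id").getOrCreate()
    val df = spark.sql("select 1 as c1")

    // An RDD-level collect bypasses Dataset.collect(), so on its own it is reported as
    // anonymous jobs only. Wrapping it in withNewExecutionId ties those jobs to a new
    // execution id, which is what makes the query appear on the SQL tab.
    val rows = SQLExecution.withNewExecutionId(df.queryExecution, Some("collectAsArrow")) {
      df.queryExecution.executedPlan.resetMetrics()
      // stand-in for SparkDatasetHelper.toArrowBatchRdd(convertComplexType(df)).collect()
      df.queryExecution.toRdd.collect()
    }
    println(s"collected ${rows.length} row(s)")
    spark.stop()
  }
}
```

Without the `withNewExecutionId` wrapper, the same collect runs as plain jobs with no entry on the SQL tab, which is what the "Before this PR" screenshots show.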

### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run tests](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #4392 from cfmcgrady/arrow-execution-id-2.

Closes #4392

481118a4 [Fu Chen] enable ut
c90674ee [Fu Chen] address comment
6cc7af44 [Fu Chen] address comment
3f8a3ab8 [Fu Chen] fix ut
223a2469 [Fu Chen] add KyuubiSparkContextHelper
bb7b28f5 [Fu Chen] fix style
879a1502 [Fu Chen] unnecessary changes
a2b04f83 [Fu Chen] fix

Authored-by: Fu Chen <cfmcgrady@gmail.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
Fu Chen 2023-02-22 23:00:30 +08:00, committed by Cheng Pan
commit f0acff315c (parent 3191fa22fc)
8 changed files with 178 additions and 71 deletions

ExecuteStatement.scala

@@ -22,13 +22,14 @@ import java.util.concurrent.RejectedExecutionException
import scala.collection.JavaConverters._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.kyuubi.SparkDatasetHelper
import org.apache.spark.sql.types._
import org.apache.kyuubi.{KyuubiSQLException, Logging}
import org.apache.kyuubi.config.KyuubiConf.OPERATION_RESULT_MAX_ROWS
import org.apache.kyuubi.engine.spark.KyuubiSparkUtil._
import org.apache.kyuubi.operation.{ArrayFetchIterator, IterableFetchIterator, OperationState}
import org.apache.kyuubi.operation.{ArrayFetchIterator, FetchIterator, IterableFetchIterator, OperationState}
import org.apache.kyuubi.operation.log.OperationLog
import org.apache.kyuubi.session.Session
@@ -62,49 +63,49 @@ class ExecuteStatement(
OperationLog.removeCurrentOperationLog()
}
private def executeStatement(): Unit = withLocalProperties {
protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
resultDF.toLocalIterator().asScala
}
protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
resultDF.collect()
}
protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
resultDF.take(maxRows)
}
protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
.getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
if (incrementalCollect) {
if (resultMaxRows > 0) {
warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
}
info("Execute in incremental collect mode")
new IterableFetchIterator[Any](new Iterable[Any] {
override def iterator: Iterator[Any] = incrementalCollectResult(resultDF)
})
} else {
val internalArray = if (resultMaxRows <= 0) {
info("Execute in full collect mode")
fullCollectResult(resultDF)
} else {
info(s"Execute with max result rows[$resultMaxRows]")
takeResult(resultDF, resultMaxRows)
}
new ArrayFetchIterator(internalArray)
}
}
protected def executeStatement(): Unit = withLocalProperties {
try {
setState(OperationState.RUNNING)
info(diagnostics)
Thread.currentThread().setContextClassLoader(spark.sharedState.jarClassLoader)
addOperationListener()
result = spark.sql(statement)
val resultMaxRows = spark.conf.getOption(OPERATION_RESULT_MAX_ROWS.key).map(_.toInt)
.getOrElse(session.sessionManager.getConf.get(OPERATION_RESULT_MAX_ROWS))
iter = if (incrementalCollect) {
if (resultMaxRows > 0) {
warn(s"Ignore ${OPERATION_RESULT_MAX_ROWS.key} on incremental collect mode.")
}
info("Execute in incremental collect mode")
def internalIterator(): Iterator[Any] = if (arrowEnabled) {
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).toLocalIterator
} else {
result.toLocalIterator().asScala
}
new IterableFetchIterator[Any](new Iterable[Any] {
override def iterator: Iterator[Any] = internalIterator()
})
} else {
val internalArray = if (resultMaxRows <= 0) {
info("Execute in full collect mode")
if (arrowEnabled) {
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(result)).collect()
} else {
result.collect()
}
} else {
info(s"Execute with max result rows[$resultMaxRows]")
if (arrowEnabled) {
// this will introduce shuffle and hurt performance
val limitedResult = result.limit(resultMaxRows)
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
} else {
result.take(resultMaxRows)
}
}
new ArrayFetchIterator(internalArray)
}
iter = collectAsIterator(result)
setCompiledStateIfNeeded()
setState(OperationState.FINISHED)
} catch {
@@ -171,3 +172,40 @@ class ExecuteStatement(
s"__kyuubi_operation_result_format__=$resultFormat",
s"__kyuubi_operation_result_arrow_timestampAsString__=$timestampAsString")
}
class ArrowBasedExecuteStatement(
session: Session,
override val statement: String,
override val shouldRunAsync: Boolean,
queryTimeout: Long,
incrementalCollect: Boolean)
extends ExecuteStatement(session, statement, shouldRunAsync, queryTimeout, incrementalCollect) {
override protected def incrementalCollectResult(resultDF: DataFrame): Iterator[Any] = {
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).toLocalIterator
}
override protected def fullCollectResult(resultDF: DataFrame): Array[_] = {
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(resultDF)).collect()
}
override protected def takeResult(resultDF: DataFrame, maxRows: Int): Array[_] = {
// this will introduce shuffle and hurt performance
val limitedResult = resultDF.limit(maxRows)
SparkDatasetHelper.toArrowBatchRdd(convertComplexType(limitedResult)).collect()
}
/**
* assign a new execution id for arrow-based operation.
*/
override protected def collectAsIterator(resultDF: DataFrame): FetchIterator[_] = {
SQLExecution.withNewExecutionId(resultDF.queryExecution, Some("collectAsArrow")) {
resultDF.queryExecution.executedPlan.resetMetrics()
super.collectAsIterator(resultDF)
}
}
override protected def isArrowBasedOperation: Boolean = true
override val resultFormat = "arrow"
}

SparkOperation.scala

@@ -245,7 +245,7 @@ abstract class SparkOperation(session: Session)
case FETCH_FIRST => iter.fetchAbsolute(0);
}
resultRowSet =
if (arrowEnabled) {
if (isArrowBasedOperation) {
if (iter.hasNext) {
val taken = iter.next().asInstanceOf[Array[Byte]]
RowSet.toTRowSet(taken, getProtocolVersion)
@@ -257,8 +257,7 @@ abstract class SparkOperation(session: Session)
RowSet.toTRowSet(
taken.toSeq.asInstanceOf[Seq[Row]],
resultSchema,
getProtocolVersion,
timeZone)
getProtocolVersion)
}
resultRowSet.setStartRowOffset(iter.getPosition)
} catch onError(cancel = true)
@@ -268,16 +267,9 @@ abstract class SparkOperation(session: Session)
override def shouldRunAsync: Boolean = false
protected def arrowEnabled: Boolean = {
resultFormat.equalsIgnoreCase("arrow") &&
// TODO: (fchen) make all operation support arrow
getClass.getCanonicalName == classOf[ExecuteStatement].getCanonicalName
}
protected def isArrowBasedOperation: Boolean = false
protected def resultFormat: String = {
// TODO: respect the config of the operation ExecuteStatement, if it was set.
spark.conf.get("kyuubi.operation.result.format", "thrift")
}
protected def resultFormat: String = "thrift"
protected def timestampAsString: Boolean = {
spark.conf.get("kyuubi.operation.result.arrow.timestampAsString", "false").toBoolean

SparkSQLOperationManager.scala

@@ -82,7 +82,24 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(name)
case NoneMode =>
val incrementalCollect = spark.conf.getOption(OPERATION_INCREMENTAL_COLLECT.key)
.map(_.toBoolean).getOrElse(operationIncrementalCollectDefault)
new ExecuteStatement(session, statement, runAsync, queryTimeout, incrementalCollect)
// TODO: respect the config of the operation ExecuteStatement, if it was set.
val resultFormat = spark.conf.get("kyuubi.operation.result.format", "thrift")
resultFormat.toLowerCase match {
case "arrow" =>
new ArrowBasedExecuteStatement(
session,
statement,
runAsync,
queryTimeout,
incrementalCollect)
case _ =>
new ExecuteStatement(
session,
statement,
runAsync,
queryTimeout,
incrementalCollect)
}
case mode =>
new PlanOnlyStatement(session, statement, mode)
}

RowSet.scala

@@ -18,7 +18,6 @@
package org.apache.kyuubi.engine.spark.schema
import java.nio.ByteBuffer
import java.time.ZoneId
import scala.collection.JavaConverters._
@@ -61,16 +60,15 @@ object RowSet {
def toTRowSet(
rows: Seq[Row],
schema: StructType,
protocolVersion: TProtocolVersion,
timeZone: ZoneId): TRowSet = {
protocolVersion: TProtocolVersion): TRowSet = {
if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
toRowBasedSet(rows, schema, timeZone)
toRowBasedSet(rows, schema)
} else {
toColumnBasedSet(rows, schema, timeZone)
toColumnBasedSet(rows, schema)
}
}
def toRowBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
def toRowBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
val rowSize = rows.length
val tRows = new java.util.ArrayList[TRow](rowSize)
var i = 0
@@ -80,7 +78,7 @@ object RowSet {
var j = 0
val columnSize = row.length
while (j < columnSize) {
val columnValue = toTColumnValue(j, row, schema, timeZone)
val columnValue = toTColumnValue(j, row, schema)
tRow.addToColVals(columnValue)
j += 1
}
@@ -90,21 +88,21 @@ object RowSet {
new TRowSet(0, tRows)
}
def toColumnBasedSet(rows: Seq[Row], schema: StructType, timeZone: ZoneId): TRowSet = {
def toColumnBasedSet(rows: Seq[Row], schema: StructType): TRowSet = {
val rowSize = rows.length
val tRowSet = new TRowSet(0, new java.util.ArrayList[TRow](rowSize))
var i = 0
val columnSize = schema.length
while (i < columnSize) {
val field = schema(i)
val tColumn = toTColumn(rows, i, field.dataType, timeZone)
val tColumn = toTColumn(rows, i, field.dataType)
tRowSet.addToColumns(tColumn)
i += 1
}
tRowSet
}
private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType, timeZone: ZoneId): TColumn = {
private def toTColumn(rows: Seq[Row], ordinal: Int, typ: DataType): TColumn = {
val nulls = new java.util.BitSet()
typ match {
case BooleanType =>
@@ -186,8 +184,7 @@ object RowSet {
private def toTColumnValue(
ordinal: Int,
row: Row,
types: StructType,
timeZone: ZoneId): TColumnValue = {
types: StructType): TColumnValue = {
types(ordinal).dataType match {
case BooleanType =>
val boolValue = new TBoolValue

SparkArrowbasedOperationSuite.scala

@@ -19,8 +19,14 @@ package org.apache.kyuubi.engine.spark.operation
import java.sql.Statement
import org.apache.spark.KyuubiSparkContextHelper
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
import org.apache.kyuubi.engine.spark.{SparkSQLEngine, WithSparkSQLEngine}
import org.apache.kyuubi.engine.spark.session.SparkSessionImpl
import org.apache.kyuubi.operation.SparkDataTypeTests
class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests {
@@ -85,6 +91,34 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTypeTests
}
}
test("assign a new execution id for arrow-based result") {
var plan: LogicalPlan = null
val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
plan = qe.analyzed
}
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
}
withJdbcStatement() { statement =>
// since all the new sessions have their own listener bus, we should register the listener
// in the current session.
SparkSQLEngine.currentEngine.get
.backendService
.sessionManager
.allSessions()
.foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
val result = statement.executeQuery("select 1 as c1")
assert(result.next())
assert(result.getInt("c1") == 1)
}
KyuubiSparkContextHelper.waitListenerBus(spark)
spark.listenerManager.unregister(listener)
assert(plan.isInstanceOf[Project])
}
private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
val query =
s"""

RowSetSuite.scala

@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.schema
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate, ZoneId}
import java.time.{Instant, LocalDate}
import scala.collection.JavaConverters._
@@ -96,10 +96,9 @@ class RowSetSuite extends KyuubiFunSuite {
.add("q", "timestamp")
private val rows: Seq[Row] = (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(17)(null)))
private val zoneId: ZoneId = ZoneId.systemDefault()
test("column based set") {
val tRowSet = RowSet.toColumnBasedSet(rows, schema, zoneId)
val tRowSet = RowSet.toColumnBasedSet(rows, schema)
assert(tRowSet.getColumns.size() === schema.size)
assert(tRowSet.getRowsSize === 0)
@@ -204,7 +203,7 @@ class RowSetSuite extends KyuubiFunSuite {
}
test("row based set") {
val tRowSet = RowSet.toRowBasedSet(rows, schema, zoneId)
val tRowSet = RowSet.toRowBasedSet(rows, schema)
assert(tRowSet.getColumnCount === 0)
assert(tRowSet.getRowsSize === rows.size)
val iter = tRowSet.getRowsIterator
@@ -250,7 +249,7 @@ class RowSetSuite extends KyuubiFunSuite {
test("to row set") {
TProtocolVersion.values().foreach { proto =>
val set = RowSet.toTRowSet(rows, schema, proto, zoneId)
val set = RowSet.toTRowSet(rows, schema, proto)
if (proto.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
assert(!set.isSetColumns, proto.toString)
assert(set.isSetRows, proto.toString)

KyuubiSparkContextHelper.scala (new file)

@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import org.apache.spark.sql.SparkSession
/**
* A place to invoke non-public APIs of [[SparkContext]], for test only.
*/
object KyuubiSparkContextHelper {
def waitListenerBus(spark: SparkSession): Unit = {
spark.sparkContext.listenerBus.waitUntilEmpty()
}
}

SparkQueryTests.scala

@@ -433,13 +433,13 @@ trait SparkQueryTests extends SparkDataTypeTests with HiveJDBCTestHelper {
expectedFormat = "thrift")
checkStatusAndResultSetFormatHint(
sql = "set kyuubi.operation.result.format=arrow",
expectedFormat = "arrow")
expectedFormat = "thrift")
checkStatusAndResultSetFormatHint(
sql = "SELECT 1",
expectedFormat = "arrow")
checkStatusAndResultSetFormatHint(
sql = "set kyuubi.operation.result.format=thrift",
expectedFormat = "thrift")
expectedFormat = "arrow")
checkStatusAndResultSetFormatHint(
sql = "set kyuubi.operation.result.format",
expectedFormat = "thrift")