[KYUUBI #4710] [ARROW] LocalTableScanExec should not trigger job

### _Why are the changes needed?_

Before this PR:

![截屏2023-04-14 下午5 19 52](https://user-images.githubusercontent.com/8537877/232003579-95c56f56-1fd7-4c8a-a13f-58d4bc16fef1.png)

After this PR:

![截屏2023-04-14 下午5 18 16](https://user-images.githubusercontent.com/8537877/232003652-77b38d08-c741-4977-bf69-6eb70f6d991a.png)

### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before make a pull request

Closes #4710 from cfmcgrady/arrow-local-table-scan-exec.

Closes #4710

e4c2891d1 [Fu Chen] fix ci
1049200ea [Fu Chen] fix style
4d45fe8b7 [Fu Chen] add assert
b8bd5b5a7 [Fu Chen] LocalTableScanExec should not trigger job

Authored-by: Fu Chen <cfmcgrady@gmail.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
This commit is contained in:
Fu Chen 2023-04-20 17:58:27 +08:00 committed by Cheng Pan
parent f4a56efec2
commit 9086b28bc3
No known key found for this signature in database
GPG Key ID: 8001952629BCC75D
3 changed files with 104 additions and 60 deletions

View File

@ -203,7 +203,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
* Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]],
* each output arrow batch contains this batch row count.
*/
private def toBatchIterator(
def toBatchIterator(
rowIter: Iterator[InternalRow],
schema: StructType,
maxRecordsPerBatch: Long,
@ -226,6 +226,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
* with two key differences:
* 1. there is no requirement to write the schema at the batch header
* 2. iteration halts when `rowCount` equals `limit`
* Note that `limit < 0` means no limit, and return all rows the in the iterator.
*/
private[sql] class ArrowBatchIterator(
rowIter: Iterator[InternalRow],
@ -255,7 +256,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
}
}
override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || {
override def hasNext: Boolean = (rowIter.hasNext && (rowCount < limit || limit < 0)) || {
root.close()
allocator.close()
false
@ -283,7 +284,8 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
// If the size of rows are 0 or negative, unlimit it.
maxRecordsPerBatch <= 0 ||
rowCountInLastBatch < maxRecordsPerBatch ||
rowCount < limit)) {
rowCount < limit ||
limit < 0)) {
val row = rowIter.next()
arrowWriter.write(row)
estimatedBatchSize += (row match {

View File

@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.network.util.{ByteUnit, JavaUtils}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.arrow.{ArrowConverters, KyuubiArrowConverters}
import org.apache.spark.sql.functions._
@ -51,6 +51,8 @@ object SparkDatasetHelper extends Logging {
doCollectLimit(collectLimit)
case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
executeArrowBatchCollect(collectLimit.child)
case localTableScan: LocalTableScanExec =>
doLocalTableScan(localTableScan)
case plan: SparkPlan =>
toArrowBatchRdd(plan).collect()
}
@ -175,6 +177,17 @@ object SparkDatasetHelper extends Logging {
result.toArray
}
def doLocalTableScan(localTableScan: LocalTableScanExec): Array[Array[Byte]] = {
localTableScan.longMetric("numOutputRows").add(localTableScan.rows.size)
KyuubiArrowConverters.toBatchIterator(
localTableScan.rows.iterator,
localTableScan.schema,
SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch,
maxBatchSize,
-1,
SparkSession.active.sessionState.conf.sessionLocalTimeZone).toArray
}
/**
* This method provides a reflection-based implementation of
* [[AdaptiveSparkPlanExec.finalPhysicalPlan]] that enables us to adapt to the Spark runtime

View File

@ -23,8 +23,8 @@ import java.util.{Set => JSet}
import org.apache.spark.KyuubiSparkContextHelper
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.{QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution, SparkPlan}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, QueryExecution, SparkPlan}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
import org.apache.spark.sql.execution.exchange.Exchange
@ -104,48 +104,29 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
}
test("assign a new execution id for arrow-based result") {
var plan: LogicalPlan = null
val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
plan = qe.analyzed
}
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
}
val listener = new SQLMetricsListener
withJdbcStatement() { statement =>
// since all the new sessions have their owner listener bus, we should register the listener
// in the current session.
registerListener(listener)
val result = statement.executeQuery("select 1 as c1")
assert(result.next())
assert(result.getInt("c1") == 1)
withSparkListener(listener) {
val result = statement.executeQuery("select 1 as c1")
assert(result.next())
assert(result.getInt("c1") == 1)
}
}
KyuubiSparkContextHelper.waitListenerBus(spark)
unregisterListener(listener)
assert(plan.isInstanceOf[Project])
assert(listener.queryExecution.analyzed.isInstanceOf[Project])
}
test("arrow-based query metrics") {
var queryExecution: QueryExecution = null
val listener = new QueryExecutionListener {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
queryExecution = qe
}
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
}
val listener = new SQLMetricsListener
withJdbcStatement() { statement =>
registerListener(listener)
val result = statement.executeQuery("select 1 as c1")
assert(result.next())
assert(result.getInt("c1") == 1)
withSparkListener(listener) {
val result = statement.executeQuery("select 1 as c1")
assert(result.next())
assert(result.getInt("c1") == 1)
}
}
KyuubiSparkContextHelper.waitListenerBus(spark)
unregisterListener(listener)
val metrics = queryExecution.executedPlan.collectLeaves().head.metrics
val metrics = listener.queryExecution.executedPlan.collectLeaves().head.metrics
assert(metrics.contains("numOutputRows"))
assert(metrics("numOutputRows").value === 1)
}
@ -273,7 +254,6 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
withPartitionedTable("t_3") {
statement.executeQuery("select * from t_3 limit 10 offset 10")
}
KyuubiSparkContextHelper.waitListenerBus(spark)
}
}
// the extra shuffle be introduced if the `offset` > 0
@ -292,13 +272,49 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
withPartitionedTable("t_3") {
statement.executeQuery("select * from t_3 limit 1000")
}
KyuubiSparkContextHelper.waitListenerBus(spark)
}
}
// Should be only one stage since there is no shuffle.
assert(numStages == 1)
}
test("LocalTableScanExec should not trigger job") {
val listener = new JobCountListener
withJdbcStatement("view_1") { statement =>
withSparkListener(listener) {
withAllSessions { s =>
import s.implicits._
Seq((1, "a")).toDF("c1", "c2").createOrReplaceTempView("view_1")
val plan = s.sql("select * from view_1").queryExecution.executedPlan
assert(plan.isInstanceOf[LocalTableScanExec])
}
val resultSet = statement.executeQuery("select * from view_1")
assert(resultSet.next())
assert(!resultSet.next())
}
}
assert(listener.numJobs == 0)
}
test("LocalTableScanExec metrics") {
val listener = new SQLMetricsListener
withJdbcStatement("view_1") { statement =>
withSparkListener(listener) {
withAllSessions { s =>
import s.implicits._
Seq((1, "a")).toDF("c1", "c2").createOrReplaceTempView("view_1")
}
val result = statement.executeQuery("select * from view_1")
assert(result.next())
assert(!result.next())
}
}
val metrics = listener.queryExecution.executedPlan.collectLeaves().head.metrics
assert(metrics.contains("numOutputRows"))
assert(metrics("numOutputRows").value === 1)
}
private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
val query =
s"""
@ -321,32 +337,30 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
assert(resultSet.getString("col") === expect)
}
private def registerListener(listener: QueryExecutionListener): Unit = {
// since all the new sessions have their owner listener bus, we should register the listener
// in the current session.
SparkSQLEngine.currentEngine.get
.backendService
.sessionManager
.allSessions()
.foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
}
private def unregisterListener(listener: QueryExecutionListener): Unit = {
SparkSQLEngine.currentEngine.get
.backendService
.sessionManager
.allSessions()
.foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
// since all the new sessions have their owner listener bus, we should register the listener
// in the current session.
private def withSparkListener[T](listener: QueryExecutionListener)(body: => T): T = {
withAllSessions(s => s.listenerManager.register(listener))
try {
val result = body
KyuubiSparkContextHelper.waitListenerBus(spark)
result
} finally {
withAllSessions(s => s.listenerManager.unregister(listener))
}
}
// since all the new sessions have their owner listener bus, we should register the listener
// in the current session.
private def withSparkListener[T](listener: SparkListener)(body: => T): T = {
withAllSessions(s => s.sparkContext.addSparkListener(listener))
try {
body
val result = body
KyuubiSparkContextHelper.waitListenerBus(spark)
result
} finally {
withAllSessions(s => s.sparkContext.removeSparkListener(listener))
}
}
private def withPartitionedTable[T](viewName: String)(body: => T): T = {
@ -432,6 +446,21 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
.get()
staticConfKeys.contains(key)
}
class JobCountListener extends SparkListener {
var numJobs = 0
override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
numJobs += 1
}
}
class SQLMetricsListener extends QueryExecutionListener {
var queryExecution: QueryExecution = null
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
queryExecution = qe
}
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
}
}
case class TestData(key: Int, value: String)