From 9086b28bc392a928e7cd68aa03b69f3e60d73643 Mon Sep 17 00:00:00 2001
From: Fu Chen
Date: Thu, 20 Apr 2023 17:58:27 +0800
Subject: [PATCH] [KYUUBI #4710] [ARROW] LocalTableScanExec should not trigger job
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### _Why are the changes needed?_

When the physical plan of a query is a `LocalTableScanExec`, all result rows already reside on the driver, yet the Arrow-based result collection still went through `toArrowBatchRdd(plan).collect()` and submitted a Spark job. This change converts the local rows to Arrow batches directly on the driver, so fetching such a result no longer triggers a job (a standalone sketch of this behavior against plain Spark follows the patch).

Before this PR:

![Screenshot 2023-04-14 at 5.19.52 PM](https://user-images.githubusercontent.com/8537877/232003579-95c56f56-1fd7-4c8a-a13f-58d4bc16fef1.png)

After this PR:

![Screenshot 2023-04-14 at 5.18.16 PM](https://user-images.githubusercontent.com/8537877/232003652-77b38d08-c741-4977-bf69-6eb70f6d991a.png)

### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #4710 from cfmcgrady/arrow-local-table-scan-exec.

Closes #4710

e4c2891d1 [Fu Chen] fix ci
1049200ea [Fu Chen] fix style
4d45fe8b7 [Fu Chen] add assert
b8bd5b5a7 [Fu Chen] LocalTableScanExec should not trigger job

Authored-by: Fu Chen
Signed-off-by: Cheng Pan
---
 .../arrow/KyuubiArrowConverters.scala         |   8 +-
 .../spark/sql/kyuubi/SparkDatasetHelper.scala |  15 +-
 .../SparkArrowbasedOperationSuite.scala       | 141 +++++++++++-------
 3 files changed, 104 insertions(+), 60 deletions(-)

diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
index 2feadbced..8a34943cc 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala
@@ -203,7 +203,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
    * Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]],
    * each output arrow batch contains this batch row count.
    */
-  private def toBatchIterator(
+  def toBatchIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
       maxRecordsPerBatch: Long,
@@ -226,6 +226,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
    * with two key differences:
    *   1. there is no requirement to write the schema at the batch header
    *   2. iteration halts when `rowCount` equals `limit`
+   * Note that `limit < 0` means no limit: all rows in the iterator are returned.
    */
   private[sql] class ArrowBatchIterator(
       rowIter: Iterator[InternalRow],
@@ -255,7 +256,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
       }
     }

-    override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || {
+    override def hasNext: Boolean = (rowIter.hasNext && (rowCount < limit || limit < 0)) || {
       root.close()
       allocator.close()
       false
@@ -283,7 +284,8 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
           // If the size of rows are 0 or negative, unlimit it.
           maxRecordsPerBatch <= 0 || rowCountInLastBatch < maxRecordsPerBatch ||
-            rowCount < limit)) {
+            rowCount < limit ||
+            limit < 0)) {
         val row = rowIter.next()
         arrowWriter.write(row)
         estimatedBatchSize += (row match {
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
index 1c8d32c48..10b178324 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
-import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.{ArrowConverters, KyuubiArrowConverters}
 import org.apache.spark.sql.functions._
@@ -51,6 +51,8 @@ object SparkDatasetHelper extends Logging {
       doCollectLimit(collectLimit)
     case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
       executeArrowBatchCollect(collectLimit.child)
+    case localTableScan: LocalTableScanExec =>
+      doLocalTableScan(localTableScan)
     case plan: SparkPlan =>
       toArrowBatchRdd(plan).collect()
   }
@@ -175,6 +177,17 @@ object SparkDatasetHelper extends Logging {
     result.toArray
   }

+  def doLocalTableScan(localTableScan: LocalTableScanExec): Array[Array[Byte]] = {
+    localTableScan.longMetric("numOutputRows").add(localTableScan.rows.size)
+    KyuubiArrowConverters.toBatchIterator(
+      localTableScan.rows.iterator,
+      localTableScan.schema,
+      SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch,
+      maxBatchSize,
+      -1,
+      SparkSession.active.sessionState.conf.sessionLocalTimeZone).toArray
+  }
+
   /**
    * This method provides a reflection-based implementation of
    * [[AdaptiveSparkPlanExec.finalPhysicalPlan]] that enables us to adapt to the Spark runtime
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
index 2ef29b398..27310992f 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkArrowbasedOperationSuite.scala
@@ -23,8 +23,8 @@ import java.util.{Set => JSet}
 import org.apache.spark.KyuubiSparkContextHelper
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.{QueryTest, Row, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution, SparkPlan}
+import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.exchange.Exchange
@@ -104,48 +104,29 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
   }

   test("assign a new execution id for arrow-based result") {
-    var plan: LogicalPlan = null
-
-    val listener = new QueryExecutionListener {
-      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        plan = qe.analyzed
-      }
-      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
-    }
+    val listener = new SQLMetricsListener
     withJdbcStatement() { statement =>
-      // since all the new sessions have their owner listener bus, we should register the listener
-      // in the current session.
-      registerListener(listener)
-
-      val result = statement.executeQuery("select 1 as c1")
-      assert(result.next())
-      assert(result.getInt("c1") == 1)
+      withSparkListener(listener) {
+        val result = statement.executeQuery("select 1 as c1")
+        assert(result.next())
+        assert(result.getInt("c1") == 1)
+      }
     }
-    KyuubiSparkContextHelper.waitListenerBus(spark)
-    unregisterListener(listener)
-    assert(plan.isInstanceOf[Project])
+
+    assert(listener.queryExecution.analyzed.isInstanceOf[Project])
   }

   test("arrow-based query metrics") {
-    var queryExecution: QueryExecution = null
-
-    val listener = new QueryExecutionListener {
-      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        queryExecution = qe
-      }
-      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
-    }
+    val listener = new SQLMetricsListener
     withJdbcStatement() { statement =>
-      registerListener(listener)
-      val result = statement.executeQuery("select 1 as c1")
-      assert(result.next())
-      assert(result.getInt("c1") == 1)
+      withSparkListener(listener) {
+        val result = statement.executeQuery("select 1 as c1")
+        assert(result.next())
+        assert(result.getInt("c1") == 1)
+      }
     }
-    KyuubiSparkContextHelper.waitListenerBus(spark)
-    unregisterListener(listener)
-
-    val metrics = queryExecution.executedPlan.collectLeaves().head.metrics
+    val metrics = listener.queryExecution.executedPlan.collectLeaves().head.metrics
     assert(metrics.contains("numOutputRows"))
     assert(metrics("numOutputRows").value === 1)
   }
@@ -273,7 +254,6 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
       withPartitionedTable("t_3") {
         statement.executeQuery("select * from t_3 limit 10 offset 10")
       }
-      KyuubiSparkContextHelper.waitListenerBus(spark)
     }
   }
   // the extra shuffle be introduced if the `offset` > 0
@@ -292,13 +272,49 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
       withPartitionedTable("t_3") {
         statement.executeQuery("select * from t_3 limit 1000")
       }
-      KyuubiSparkContextHelper.waitListenerBus(spark)
     }
   }
   // Should be only one stage since there is no shuffle.
     assert(numStages == 1)
   }

+  test("LocalTableScanExec should not trigger job") {
+    val listener = new JobCountListener
+    withJdbcStatement("view_1") { statement =>
+      withSparkListener(listener) {
+        withAllSessions { s =>
+          import s.implicits._
+          Seq((1, "a")).toDF("c1", "c2").createOrReplaceTempView("view_1")
+          val plan = s.sql("select * from view_1").queryExecution.executedPlan
+          assert(plan.isInstanceOf[LocalTableScanExec])
+        }
+        val resultSet = statement.executeQuery("select * from view_1")
+        assert(resultSet.next())
+        assert(!resultSet.next())
+      }
+    }
+    assert(listener.numJobs == 0)
+  }
+
+  test("LocalTableScanExec metrics") {
+    val listener = new SQLMetricsListener
+    withJdbcStatement("view_1") { statement =>
+      withSparkListener(listener) {
+        withAllSessions { s =>
+          import s.implicits._
+          Seq((1, "a")).toDF("c1", "c2").createOrReplaceTempView("view_1")
+        }
+        val result = statement.executeQuery("select * from view_1")
+        assert(result.next())
+        assert(!result.next())
+      }
+    }
+
+    val metrics = listener.queryExecution.executedPlan.collectLeaves().head.metrics
+    assert(metrics.contains("numOutputRows"))
+    assert(metrics("numOutputRows").value === 1)
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
@@ -321,32 +337,30 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     assert(resultSet.getString("col") === expect)
   }

-  private def registerListener(listener: QueryExecutionListener): Unit = {
-    // since all the new sessions have their owner listener bus, we should register the listener
-    // in the current session.
-    SparkSQLEngine.currentEngine.get
-      .backendService
-      .sessionManager
-      .allSessions()
-      .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
-  }
-
-  private def unregisterListener(listener: QueryExecutionListener): Unit = {
-    SparkSQLEngine.currentEngine.get
-      .backendService
-      .sessionManager
-      .allSessions()
-      .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
+  // since each new session has its own listener bus, register the listener with
+  // all active sessions
+  private def withSparkListener[T](listener: QueryExecutionListener)(body: => T): T = {
+    withAllSessions(s => s.listenerManager.register(listener))
+    try {
+      val result = body
+      KyuubiSparkContextHelper.waitListenerBus(spark)
+      result
+    } finally {
+      withAllSessions(s => s.listenerManager.unregister(listener))
+    }
   }

+  // since each new session has its own listener bus, register the listener with
+  // all active sessions
   private def withSparkListener[T](listener: SparkListener)(body: => T): T = {
     withAllSessions(s => s.sparkContext.addSparkListener(listener))
     try {
-      body
+      val result = body
+      KyuubiSparkContextHelper.waitListenerBus(spark)
+      result
     } finally {
       withAllSessions(s => s.sparkContext.removeSparkListener(listener))
     }
-
   }

   private def withPartitionedTable[T](viewName: String)(body: => T): T = {
@@ -432,6 +446,21 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
       .get()
     staticConfKeys.contains(key)
   }
+
+  class JobCountListener extends SparkListener {
+    var numJobs = 0
+    override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+      numJobs += 1
+    }
+  }
+
+  class SQLMetricsListener extends QueryExecutionListener {
+    var queryExecution: QueryExecution = null
+    override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+      queryExecution = qe
+    }
+    override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+  }
 }

 case class TestData(key: Int, value: String)
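
For readers who want to observe the behavior referenced in the description above, here is a minimal, self-contained sketch (not part of the patch) that mirrors the new `JobCountListener` test against plain Spark. It assumes a local Spark 3.x environment; the object name `LocalScanJobCountSketch`, the `JobCounter` helper, the `local[2]` master, and the `Thread.sleep` stand-in for `KyuubiSparkContextHelper.waitListenerBus` are illustrative only. Spark's own `collect()` over a `LocalTableScanExec`-backed plan already returns rows from the driver without submitting a job, which is the same property this patch establishes for Kyuubi's Arrow-based result path.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.LocalTableScanExec

object LocalScanJobCountSketch {

  // Counts every job submitted on the SparkContext, like the patch's JobCountListener.
  class JobCounter extends SparkListener {
    @volatile var numJobs = 0
    override def onJobStart(jobStart: SparkListenerJobStart): Unit = numJobs += 1
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]") // illustrative; any master works
      .appName("local-table-scan-sketch")
      .getOrCreate()
    import spark.implicits._

    val counter = new JobCounter
    spark.sparkContext.addSparkListener(counter)

    // A temp view over a local Seq compiles to LocalTableScanExec,
    // exactly as asserted in the new test.
    Seq((1, "a")).toDF("c1", "c2").createOrReplaceTempView("view_1")
    val df = spark.sql("select * from view_1")
    assert(df.queryExecution.executedPlan.isInstanceOf[LocalTableScanExec])

    // The rows already live on the driver, so collect() submits no job.
    assert(df.collect().length == 1)

    // The listener bus is asynchronous; sleeping briefly is a crude stand-in
    // for the KyuubiSparkContextHelper.waitListenerBus(spark) used in the suite.
    Thread.sleep(1000)
    println(s"jobs triggered: ${counter.numJobs}") // expected: 0

    spark.stop()
  }
}
```

Before this patch, the equivalent fetch through Kyuubi's Arrow path still showed one job in the SQL tab (first screenshot), because `executeArrowBatchCollect` fell through to `toArrowBatchRdd(plan).collect()`; the new `doLocalTableScan` branch keeps the conversion on the driver.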