[KYUUBI #4710] [ARROW] LocalTableScanExec should not trigger job
### _Why are the changes needed?_

Collecting Arrow-based results from a `LocalTableScanExec` should not trigger a Spark job: the rows already live on the driver, so they can be converted to Arrow batches directly instead of going through `toArrowBatchRdd(plan).collect()`.

Before this PR:

After this PR:

### _How was this patch tested?_

- [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #4710 from cfmcgrady/arrow-local-table-scan-exec.

Closes #4710

e4c2891d1 [Fu Chen] fix ci
1049200ea [Fu Chen] fix style
4d45fe8b7 [Fu Chen] add assert
b8bd5b5a7 [Fu Chen] LocalTableScanExec should not trigger job

Authored-by: Fu Chen <cfmcgrady@gmail.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
Parent: f4a56efec2
Commit: 9086b28bc3
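For reference, the behaviour the new test asserts can be reproduced outside the engine with a plain job-counting `SparkListener`. The sketch below is illustrative only (the object and app names are made up, and it exercises `Dataset.collect()` rather than Kyuubi's Arrow result path): a DataFrame backed by a local `Seq` plans to `LocalTableScanExec`, and collecting it submits no Spark job.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.LocalTableScanExec

object LocalScanJobCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("local-scan-check").getOrCreate()
    import spark.implicits._

    // Count every job submitted to the SparkContext.
    @volatile var numJobs = 0
    spark.sparkContext.addSparkListener(new SparkListener {
      override def onJobStart(jobStart: SparkListenerJobStart): Unit = numJobs += 1
    })

    // A DataFrame built from local data is planned as LocalTableScanExec ...
    val df = Seq((1, "a"), (2, "b")).toDF("c1", "c2")
    assert(df.queryExecution.executedPlan.isInstanceOf[LocalTableScanExec])

    // ... and collecting it keeps the rows on the driver, so no job is expected.
    df.collect()
    Thread.sleep(1000) // crude stand-in for draining the asynchronous listener bus
    println(s"jobs triggered: $numJobs") // expected: 0

    spark.stop()
  }
}
```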
@@ -203,7 +203,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
    * Different from [[org.apache.spark.sql.execution.arrow.ArrowConverters.toBatchIterator]],
    * each output arrow batch contains this batch row count.
    */
-  private def toBatchIterator(
+  def toBatchIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
       maxRecordsPerBatch: Long,
@@ -226,6 +226,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
    * with two key differences:
    * 1. there is no requirement to write the schema at the batch header
    * 2. iteration halts when `rowCount` equals `limit`
+   * Note that `limit < 0` means no limit, and return all rows the in the iterator.
    */
   private[sql] class ArrowBatchIterator(
       rowIter: Iterator[InternalRow],
@@ -255,7 +256,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
       }
     }

-    override def hasNext: Boolean = (rowIter.hasNext && rowCount < limit) || {
+    override def hasNext: Boolean = (rowIter.hasNext && (rowCount < limit || limit < 0)) || {
       root.close()
       allocator.close()
       false
@@ -283,7 +284,8 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging {
           // If the size of rows are 0 or negative, unlimit it.
           maxRecordsPerBatch <= 0 ||
           rowCountInLastBatch < maxRecordsPerBatch ||
-          rowCount < limit)) {
+          rowCount < limit ||
+          limit < 0)) {
         val row = rowIter.next()
         arrowWriter.write(row)
         estimatedBatchSize += (row match {
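The new `limit < 0` checks above adopt the convention that a negative limit means "no limit". A Spark-free sketch of the same batching loop shape (names are illustrative, no Arrow involved) may make the condition easier to read:

```scala
// Split rows into batches of at most `maxRecordsPerBatch`, stopping once `limit`
// rows have been emitted; `limit < 0` disables the cap, mirroring the condition
// change in ArrowBatchIterator above.
def batchRows[T](rows: Iterator[T], maxRecordsPerBatch: Long, limit: Long): Iterator[Seq[T]] =
  new Iterator[Seq[T]] {
    private var rowCount = 0L

    override def hasNext: Boolean = rows.hasNext && (rowCount < limit || limit < 0)

    override def next(): Seq[T] = {
      val batch = scala.collection.mutable.ArrayBuffer.empty[T]
      var rowCountInBatch = 0L
      while (rows.hasNext &&
        (maxRecordsPerBatch <= 0 || rowCountInBatch < maxRecordsPerBatch) &&
        (rowCount < limit || limit < 0)) {
        batch += rows.next()
        rowCountInBatch += 1
        rowCount += 1
      }
      batch.toSeq
    }
  }

// batchRows((1 to 10).iterator, maxRecordsPerBatch = 4, limit = -1).map(_.size).toList
//   => List(4, 4, 2)
// batchRows((1 to 10).iterator, maxRecordsPerBatch = 4, limit = 6).map(_.size).toList
//   => List(4, 2)
```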
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
-import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.{ArrowConverters, KyuubiArrowConverters}
 import org.apache.spark.sql.functions._
@@ -51,6 +51,8 @@ object SparkDatasetHelper extends Logging {
       doCollectLimit(collectLimit)
     case collectLimit: CollectLimitExec if collectLimit.limit < 0 =>
       executeArrowBatchCollect(collectLimit.child)
+    case localTableScan: LocalTableScanExec =>
+      doLocalTableScan(localTableScan)
     case plan: SparkPlan =>
       toArrowBatchRdd(plan).collect()
   }
@@ -175,6 +177,17 @@ object SparkDatasetHelper extends Logging {
     result.toArray
   }

+  def doLocalTableScan(localTableScan: LocalTableScanExec): Array[Array[Byte]] = {
+    localTableScan.longMetric("numOutputRows").add(localTableScan.rows.size)
+    KyuubiArrowConverters.toBatchIterator(
+      localTableScan.rows.iterator,
+      localTableScan.schema,
+      SparkSession.active.sessionState.conf.arrowMaxRecordsPerBatch,
+      maxBatchSize,
+      -1,
+      SparkSession.active.sessionState.conf.sessionLocalTimeZone).toArray
+  }
+
   /**
    * This method provides a reflection-based implementation of
    * [[AdaptiveSparkPlanExec.finalPhysicalPlan]] that enables us to adapt to the Spark runtime
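The `-1` passed as the limit here uses that same "negative means unlimited" convention, so every row of the local relation is emitted; batch boundaries then come from `spark.sql.execution.arrow.maxRecordsPerBatch` (default 10000) and the byte-size cap `maxBatchSize`. Assuming the byte cap is not hit, the resulting batch sizes look like this (illustrative snippet, not part of the patch):

```scala
// 25,000 local rows with the default maxRecordsPerBatch of 10,000 yield three batches.
val batchSizes = (1 to 25000).grouped(10000).map(_.size).toList
assert(batchSizes == List(10000, 10000, 5000))
```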
@@ -23,8 +23,8 @@ import java.util.{Set => JSet}
 import org.apache.spark.KyuubiSparkContextHelper
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.{QueryTest, Row, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.execution.{CollectLimitExec, QueryExecution, SparkPlan}
+import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, QueryExecution, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.exchange.Exchange
@@ -104,48 +104,29 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
   }

   test("assign a new execution id for arrow-based result") {
-    var plan: LogicalPlan = null
-
-    val listener = new QueryExecutionListener {
-      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        plan = qe.analyzed
-      }
-      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
-    }
+    val listener = new SQLMetricsListener
     withJdbcStatement() { statement =>
-      // since all the new sessions have their owner listener bus, we should register the listener
-      // in the current session.
-      registerListener(listener)
-
-      val result = statement.executeQuery("select 1 as c1")
-      assert(result.next())
-      assert(result.getInt("c1") == 1)
+      withSparkListener(listener) {
+        val result = statement.executeQuery("select 1 as c1")
+        assert(result.next())
+        assert(result.getInt("c1") == 1)
+      }
     }
-    KyuubiSparkContextHelper.waitListenerBus(spark)
-    unregisterListener(listener)
-    assert(plan.isInstanceOf[Project])
+    assert(listener.queryExecution.analyzed.isInstanceOf[Project])
   }

   test("arrow-based query metrics") {
-    var queryExecution: QueryExecution = null
-
-    val listener = new QueryExecutionListener {
-      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
-        queryExecution = qe
-      }
-      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
-    }
+    val listener = new SQLMetricsListener
     withJdbcStatement() { statement =>
-      registerListener(listener)
-      val result = statement.executeQuery("select 1 as c1")
-      assert(result.next())
-      assert(result.getInt("c1") == 1)
+      withSparkListener(listener) {
+        val result = statement.executeQuery("select 1 as c1")
+        assert(result.next())
+        assert(result.getInt("c1") == 1)
+      }
     }

-    KyuubiSparkContextHelper.waitListenerBus(spark)
-    unregisterListener(listener)
-
-    val metrics = queryExecution.executedPlan.collectLeaves().head.metrics
+    val metrics = listener.queryExecution.executedPlan.collectLeaves().head.metrics
     assert(metrics.contains("numOutputRows"))
     assert(metrics("numOutputRows").value === 1)
   }
@@ -273,7 +254,6 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
         withPartitionedTable("t_3") {
           statement.executeQuery("select * from t_3 limit 10 offset 10")
         }
-        KyuubiSparkContextHelper.waitListenerBus(spark)
       }
     }
     // the extra shuffle be introduced if the `offset` > 0
@@ -292,13 +272,49 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
         withPartitionedTable("t_3") {
           statement.executeQuery("select * from t_3 limit 1000")
         }
-        KyuubiSparkContextHelper.waitListenerBus(spark)
       }
     }
     // Should be only one stage since there is no shuffle.
     assert(numStages == 1)
   }

+  test("LocalTableScanExec should not trigger job") {
+    val listener = new JobCountListener
+    withJdbcStatement("view_1") { statement =>
+      withSparkListener(listener) {
+        withAllSessions { s =>
+          import s.implicits._
+          Seq((1, "a")).toDF("c1", "c2").createOrReplaceTempView("view_1")
+          val plan = s.sql("select * from view_1").queryExecution.executedPlan
+          assert(plan.isInstanceOf[LocalTableScanExec])
+        }
+        val resultSet = statement.executeQuery("select * from view_1")
+        assert(resultSet.next())
+        assert(!resultSet.next())
+      }
+    }
+    assert(listener.numJobs == 0)
+  }
+
+  test("LocalTableScanExec metrics") {
+    val listener = new SQLMetricsListener
+    withJdbcStatement("view_1") { statement =>
+      withSparkListener(listener) {
+        withAllSessions { s =>
+          import s.implicits._
+          Seq((1, "a")).toDF("c1", "c2").createOrReplaceTempView("view_1")
+        }
+        val result = statement.executeQuery("select * from view_1")
+        assert(result.next())
+        assert(!result.next())
+      }
+    }
+
+    val metrics = listener.queryExecution.executedPlan.collectLeaves().head.metrics
+    assert(metrics.contains("numOutputRows"))
+    assert(metrics("numOutputRows").value === 1)
+  }
+
   private def checkResultSetFormat(statement: Statement, expectFormat: String): Unit = {
     val query =
       s"""
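As a counterpoint to the `LocalTableScanExec` test above, a distributed leaf such as `RangeExec` does launch a job when collected, which is exactly what a job-counting listener would catch. A spark-shell style sketch (not part of the patch; `spark` is the shell's session):

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}

var numJobs = 0
val countingListener = new SparkListener {
  override def onJobStart(jobStart: SparkListenerJobStart): Unit = numJobs += 1
}
spark.sparkContext.addSparkListener(countingListener)

spark.range(10).collect() // distributed scan: submits at least one job
Thread.sleep(1000) // crude stand-in for draining the asynchronous listener bus
assert(numJobs >= 1)
spark.sparkContext.removeSparkListener(countingListener)
```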
@@ -321,32 +337,30 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
     assert(resultSet.getString("col") === expect)
   }

-  private def registerListener(listener: QueryExecutionListener): Unit = {
-    // since all the new sessions have their owner listener bus, we should register the listener
-    // in the current session.
-    SparkSQLEngine.currentEngine.get
-      .backendService
-      .sessionManager
-      .allSessions()
-      .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.register(listener))
-  }
-
-  private def unregisterListener(listener: QueryExecutionListener): Unit = {
-    SparkSQLEngine.currentEngine.get
-      .backendService
-      .sessionManager
-      .allSessions()
-      .foreach(_.asInstanceOf[SparkSessionImpl].spark.listenerManager.unregister(listener))
+  // since all the new sessions have their owner listener bus, we should register the listener
+  // in the current session.
+  private def withSparkListener[T](listener: QueryExecutionListener)(body: => T): T = {
+    withAllSessions(s => s.listenerManager.register(listener))
+    try {
+      val result = body
+      KyuubiSparkContextHelper.waitListenerBus(spark)
+      result
+    } finally {
+      withAllSessions(s => s.listenerManager.unregister(listener))
+    }
   }

   // since all the new sessions have their owner listener bus, we should register the listener
   // in the current session.
   private def withSparkListener[T](listener: SparkListener)(body: => T): T = {
     withAllSessions(s => s.sparkContext.addSparkListener(listener))
     try {
-      body
+      val result = body
+      KyuubiSparkContextHelper.waitListenerBus(spark)
+      result
     } finally {
       withAllSessions(s => s.sparkContext.removeSparkListener(listener))
     }

   }

   private def withPartitionedTable[T](viewName: String)(body: => T): T = {
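Both helpers follow the same loan pattern: register the listener, run the body, wait for the listener bus to drain, and always unregister. A generic standalone sketch of that shape (ignoring Kyuubi's per-session registration, and taking a caller-supplied wait function because `SparkContext`'s listener bus is not public API, which is why the suite goes through `KyuubiSparkContextHelper.waitListenerBus`):

```scala
import org.apache.spark.scheduler.SparkListener
import org.apache.spark.sql.SparkSession

// Loan pattern: the listener is guaranteed to be removed even if the body throws,
// and pending events are drained before the caller inspects what the listener saw.
def withListener[T](spark: SparkSession, listener: SparkListener)(waitForBus: () => Unit)(
    body: => T): T = {
  spark.sparkContext.addSparkListener(listener)
  try {
    val result = body
    waitForBus()
    result
  } finally {
    spark.sparkContext.removeSparkListener(listener)
  }
}
```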
@@ -432,6 +446,21 @@ class SparkArrowbasedOperationSuite extends WithSparkSQLEngine with SparkDataTyp
       .get()
     staticConfKeys.contains(key)
   }
+
+  class JobCountListener extends SparkListener {
+    var numJobs = 0
+    override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+      numJobs += 1
+    }
+  }
+
+  class SQLMetricsListener extends QueryExecutionListener {
+    var queryExecution: QueryExecution = null
+    override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+      queryExecution = qe
+    }
+    override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+  }
 }

 case class TestData(key: Int, value: String)