[KYUUBI #1832] Fixed: forcedMaxOutputRows extension for subquery
<!--
Thanks for sending a pull request! Here are some tips for you:
1. If this is your first time, please read our contributor guidelines: https://kyuubi.readthedocs.io/en/latest/community/contributions.html
2. If the PR is related to an issue in https://github.com/apache/incubator-kyuubi/issues, add '[KYUUBI #XXXX]' in your PR title, e.g., '[KYUUBI #XXXX] Your PR title ...'.
3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][KYUUBI #XXXX] Your PR title ...'.
-->

### _Why are the changes needed?_
<!--
Please clarify why the changes are needed. For instance,
1. If you add a feature, you can talk about its use case.
2. If you fix a bug, you can clarify why it is a bug.
-->
Fixes #1832

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly, including negative and positive cases if possible
- [ ] Add screenshots for manual tests if appropriate
- [ ] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #1833 from iodone/1832-bug.

Closes #1832

52769746 [odone] [KYUUBI #1832] Added: Add some test case
61d733dc [odone] [KYUUBI #1832] fixed: spotless
d8ee657a [odone] [KYUUBI #1832] fixed: forcedMaxOutputRows extension for subquery
efd87a75 [odone] [KYUUBI #1832] fixed: forcedMaxOutputRows extension for subquery

Authored-by: odone <odone.zhang@gmail.com>
Signed-off-by: ulysses-you <ulyssesyou@apache.org>
This commit is contained in:
parent
0337da6701
commit
5ea0964e26
@ -20,7 +20,7 @@ package org.apache.kyuubi.sql
|
||||
import org.apache.spark.sql.SparkSessionExtensions
|
||||
|
||||
import org.apache.kyuubi.sql.sqlclassification.KyuubiSqlClassification
|
||||
import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrderRule, MaxPartitionStrategy}
|
||||
import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxPartitionStrategy}
|
||||
|
||||
// scalastyle:off line.size.limit
|
||||
/**
|
||||
@ -39,9 +39,7 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
|
||||
extensions.injectPostHocResolutionRule(DropIgnoreNonexistent)
|
||||
|
||||
// watchdog extension
|
||||
// a helper rule for ForcedMaxOutputRowsRule
|
||||
extensions.injectResolutionRule(MarkAggregateOrderRule)
|
||||
extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
|
||||
extensions.injectOptimizerRule(ForcedMaxOutputRowsRule)
|
||||
extensions.injectPlannerStrategy(MaxPartitionStrategy)
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,60 +19,7 @@ package org.apache.kyuubi.sql.watchdog
|
||||
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules.Rule
|
||||
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
|
||||
|
||||
import org.apache.kyuubi.sql.KyuubiSQLConf
|
||||
|
||||
object ForcedMaxOutputRowsConstraint {
|
||||
val CHILD_AGGREGATE: TreeNodeTag[String] = TreeNodeTag[String]("__kyuubi_child_agg__")
|
||||
val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__"
|
||||
}
|
||||
|
||||
case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase {
|
||||
override protected def isChildAggregate(a: Aggregate): Boolean =
|
||||
a.aggregateExpressions.exists(p =>
|
||||
p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE)
|
||||
.contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG))
|
||||
}
|
||||
|
||||
/**
|
||||
* After SPARK-35712, we no longer need to mark child aggregates on Spark 3.2.x or higher,
|
||||
* for more details, please see https://github.com/apache/spark/pull/32470
|
||||
*/
|
||||
case class MarkAggregateOrderRule(sparkSession: SparkSession) extends Rule[LogicalPlan] {
|
||||
|
||||
private def markChildAggregate(a: Aggregate): Unit = {
|
||||
// mark child aggregate
|
||||
a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue(
|
||||
ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE,
|
||||
ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG))
|
||||
}
|
||||
|
||||
protected def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match {
|
||||
/*
|
||||
* This case mainly handles ordering by a grouping column rather than an aggregate column, as below
|
||||
* SELECT c1, COUNT(*) as cnt
|
||||
* FROM t1
|
||||
* GROUP BY c1
|
||||
* ORDER BY c1
|
||||
* */
|
||||
case a: Aggregate
|
||||
if a.aggregateExpressions
|
||||
.exists(x => x.resolved && x.name.equals("aggOrder")) =>
|
||||
markChildAggregate(a)
|
||||
plan
|
||||
case _ =>
|
||||
plan.children.foreach(_.foreach {
|
||||
case agg: Aggregate => markChildAggregate(agg)
|
||||
case _ => Unit
|
||||
})
|
||||
plan
|
||||
}
|
||||
|
||||
override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf(
|
||||
KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) match {
|
||||
case Some(_) => findAndMarkChildAggregate(plan)
|
||||
case _ => plan
|
||||
}
|
||||
override protected def isChildAggregate(a: Aggregate): Boolean = false
|
||||
}
|
||||
|
||||
@ -37,7 +37,7 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
|
||||
extensions.injectPostHocResolutionRule(DropIgnoreNonexistent)
|
||||
|
||||
// watchdog extension
|
||||
extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
|
||||
extensions.injectOptimizerRule(ForcedMaxOutputRowsRule)
|
||||
extensions.injectPlannerStrategy(MaxPartitionStrategy)
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,7 +18,8 @@
|
||||
package org.apache.kyuubi.sql.watchdog
|
||||
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, WithCTE}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CommandResult, LogicalPlan, Union, WithCTE}
|
||||
import org.apache.spark.sql.execution.command.DataWritingCommand
|
||||
|
||||
case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase {
|
||||
|
||||
@ -26,7 +27,14 @@ case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMax
|
||||
|
||||
override protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
|
||||
case WithCTE(plan, _) => this.canInsertLimitInner(plan)
|
||||
case plan: LogicalPlan => super.canInsertLimitInner(plan)
|
||||
case plan: LogicalPlan => plan match {
|
||||
case Union(children, _, _) => !children.exists {
|
||||
case _: DataWritingCommand => true
|
||||
case p: CommandResult if p.commandLogicalPlan.isInstanceOf[DataWritingCommand] => true
|
||||
case _ => false
|
||||
}
|
||||
case _ => super.canInsertLimitInner(plan)
|
||||
}
|
||||
}
|
||||
|
||||
override protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = {
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
|
||||
package org.apache.kyuubi.sql.watchdog
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.AnalysisContext
|
||||
import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, MultiInstanceRelation}
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.expressions.Alias
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
@ -69,6 +69,7 @@ trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] {
|
||||
} else {
|
||||
true
|
||||
}
|
||||
case _: MultiInstanceRelation => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LogicalPlan}
|
||||
|
||||
import org.apache.kyuubi.sql.KyuubiSQLConf
|
||||
import org.apache.kyuubi.sql.watchdog.MaxPartitionExceedException
|
||||
@ -29,6 +29,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
}
|
||||
|
||||
case class LimitAndExpected(limit: Int, expected: Int)
|
||||
|
||||
val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10))
|
||||
|
||||
private def checkMaxPartition: Unit = {
|
||||
@ -100,7 +101,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|SELECT $distinct *
|
||||
|FROM t1
|
||||
|$sort
|
||||
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
|
||||
|""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
}
|
||||
}
|
||||
|
||||
@ -113,7 +114,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|FROM t1
|
||||
|$sort
|
||||
|LIMIT $limit
|
||||
|""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
|
||||
|""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -125,7 +126,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
|
||||
|
||||
assert(!sql("SELECT count(*) FROM t1")
|
||||
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
|
||||
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
|
||||
val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
|
||||
val havingConditions = List("", "HAVING cnt > 1")
|
||||
@ -139,7 +140,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|GROUP BY c1
|
||||
|$having
|
||||
|$sort
|
||||
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
|
||||
|""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
}
|
||||
}
|
||||
|
||||
@ -154,7 +155,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|$having
|
||||
|$sort
|
||||
|LIMIT $limit
|
||||
|""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
|
||||
|""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -191,7 +192,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|$withQuery
|
||||
|SELECT * FROM t2
|
||||
|$sort
|
||||
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
|
||||
|""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
}
|
||||
}
|
||||
|
||||
@ -242,7 +243,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|GROUP BY c1
|
||||
|$having
|
||||
|$sort
|
||||
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
|
||||
|""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
}
|
||||
}
|
||||
|
||||
@ -281,7 +282,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|UNION $x
|
||||
|SELECT c1, c2 FROM t3
|
||||
|""".stripMargin)
|
||||
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
|
||||
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
}
|
||||
|
||||
val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
|
||||
@ -308,7 +309,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|$having
|
||||
|$sort
|
||||
|""".stripMargin)
|
||||
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
|
||||
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -323,7 +324,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|SELECT c1, c2 FROM t3
|
||||
|LIMIT $limit
|
||||
|""".stripMargin)
|
||||
.queryExecution.analyzed.maxRows.contains(expected))
|
||||
.queryExecution.optimizedPlan.maxRows.contains(expected))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -344,14 +345,14 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|SELECT * FROM
|
||||
|tmp_table
|
||||
|""".stripMargin)
|
||||
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
|
||||
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
|
||||
assert(sql(
|
||||
s"""
|
||||
|SELECT * FROM
|
||||
|tmp_view
|
||||
|""".stripMargin)
|
||||
.queryExecution.analyzed.maxRows.contains(3))
|
||||
.queryExecution.optimizedPlan.maxRows.contains(3))
|
||||
|
||||
assert(sql(
|
||||
s"""
|
||||
@ -359,7 +360,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|tmp_view
|
||||
|limit 11
|
||||
|""".stripMargin)
|
||||
.queryExecution.analyzed.maxRows.contains(3))
|
||||
.queryExecution.optimizedPlan.maxRows.contains(3))
|
||||
|
||||
assert(sql(
|
||||
s"""
|
||||
@ -394,7 +395,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|insert into $multiInsertTableName1 select * limit 2
|
||||
|insert into $multiInsertTableName2 select *
|
||||
|""".stripMargin)
|
||||
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
|
||||
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -411,8 +412,40 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|
||||
|FROM tmp_table
|
||||
|DISTRIBUTE BY a
|
||||
|""".stripMargin)
|
||||
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
|
||||
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("test watchdog: Subquery for forceMaxOutputRows") {
|
||||
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "1") {
|
||||
withTable("tmp_table1") {
|
||||
sql("CREATE TABLE spark_catalog.`default`.tmp_table1(KEY INT, VALUE STRING) USING PARQUET")
|
||||
sql("INSERT INTO TABLE spark_catalog.`default`.tmp_table1 " +
|
||||
"VALUES (1, 'aa'),(2,'bb'),(3, 'cc'),(4,'aa'),(5,'cc'),(6, 'aa')")
|
||||
assert(
|
||||
sql("select * from tmp_table1").queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
|
||||
val testSqlText =
|
||||
"""
|
||||
|select count(*)
|
||||
|from tmp_table1
|
||||
|where tmp_table1.key in (
|
||||
|select distinct tmp_table1.key
|
||||
|from tmp_table1
|
||||
|where tmp_table1.value = "aa"
|
||||
|)
|
||||
|""".stripMargin
|
||||
val plan = sql(testSqlText).queryExecution.optimizedPlan
|
||||
assert(!findGlobalLimit(plan))
|
||||
checkAnswer(sql(testSqlText), Row(3) :: Nil)
|
||||
}
|
||||
|
||||
def findGlobalLimit(plan: LogicalPlan): Boolean = plan match {
|
||||
case _: GlobalLimit => true
|
||||
case p if p.children.isEmpty => false
|
||||
case p => p.children.exists(findGlobalLimit)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user