[KYUUBI #1832] Fixed: forcedMaxOutputRows extension for subquery

<!--
Thanks for sending a pull request!

Here are some tips for you:
  1. If this is your first time, please read our contributor guidelines: https://kyuubi.readthedocs.io/en/latest/community/contributions.html
  2. If the PR is related to an issue in https://github.com/apache/incubator-kyuubi/issues, add '[KYUUBI #XXXX]' in your PR title, e.g., '[KYUUBI #XXXX] Your PR title ...'.
  3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][KYUUBI #XXXX] Your PR title ...'.
-->

### _Why are the changes needed?_
<!--
Please clarify why the changes are needed. For instance,
  1. If you add a feature, you can talk about the use case of it.
  2. If you fix a bug, you can clarify why it is a bug.
-->
Fixed #1832

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [ ] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before make a pull request

Closes #1833 from iodone/1832-bug.

Closes #1832

52769746 [odone] [KYUUBI #1832] Added: Add some test case
61d733dc [odone] [KYUUBI #1832] fixed: spotless
d8ee657a [odone] [KYUUBI #1832] fixed: forcedMaxOutputRows extension for subquery
efd87a75 [odone] [KYUUBI #1832] fixed: forcedMaxOutputRows extension for subquery

Authored-by: odone <odone.zhang@gmail.com>
Signed-off-by: ulysses-you <ulyssesyou@apache.org>
This commit is contained in:
odone 2022-01-26 09:32:54 +08:00 committed by ulysses-you
parent 0337da6701
commit 5ea0964e26
No known key found for this signature in database
GPG Key ID: 4C500BC62D576766
6 changed files with 65 additions and 78 deletions

View File

@ -20,7 +20,7 @@ package org.apache.kyuubi.sql
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.kyuubi.sql.sqlclassification.KyuubiSqlClassification
import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrderRule, MaxPartitionStrategy}
import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxPartitionStrategy}
// scalastyle:off line.size.limit
/**
@ -39,9 +39,7 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
extensions.injectPostHocResolutionRule(DropIgnoreNonexistent)
// watchdog extension
// a help rule for ForcedMaxOutputRowsRule
extensions.injectResolutionRule(MarkAggregateOrderRule)
extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
extensions.injectOptimizerRule(ForcedMaxOutputRowsRule)
extensions.injectPlannerStrategy(MaxPartitionStrategy)
}
}

View File

@ -19,60 +19,7 @@ package org.apache.kyuubi.sql.watchdog
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.kyuubi.sql.KyuubiSQLConf
object ForcedMaxOutputRowsConstraint {
val CHILD_AGGREGATE: TreeNodeTag[String] = TreeNodeTag[String]("__kyuubi_child_agg__")
val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__"
}
case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase {
override protected def isChildAggregate(a: Aggregate): Boolean =
a.aggregateExpressions.exists(p =>
p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE)
.contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG))
}
/**
* After SPARK-35712, we don't need mark child aggregate for spark 3.2.x or higher version,
* for more detail, please see https://github.com/apache/spark/pull/32470
*/
case class MarkAggregateOrderRule(sparkSession: SparkSession) extends Rule[LogicalPlan] {
private def markChildAggregate(a: Aggregate): Unit = {
// mark child aggregate
a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue(
ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE,
ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG))
}
protected def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match {
/*
* The case mainly process order not aggregate column but grouping column as below
* SELECT c1, COUNT(*) as cnt
* FROM t1
* GROUP BY c1
* ORDER BY c1
* */
case a: Aggregate
if a.aggregateExpressions
.exists(x => x.resolved && x.name.equals("aggOrder")) =>
markChildAggregate(a)
plan
case _ =>
plan.children.foreach(_.foreach {
case agg: Aggregate => markChildAggregate(agg)
case _ => Unit
})
plan
}
override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf(
KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) match {
case Some(_) => findAndMarkChildAggregate(plan)
case _ => plan
}
override protected def isChildAggregate(a: Aggregate): Boolean = false
}

View File

@ -37,7 +37,7 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
extensions.injectPostHocResolutionRule(DropIgnoreNonexistent)
// watchdog extension
extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
extensions.injectOptimizerRule(ForcedMaxOutputRowsRule)
extensions.injectPlannerStrategy(MaxPartitionStrategy)
}
}

View File

@ -18,7 +18,8 @@
package org.apache.kyuubi.sql.watchdog
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, WithCTE}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CommandResult, LogicalPlan, Union, WithCTE}
import org.apache.spark.sql.execution.command.DataWritingCommand
case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase {
@ -26,7 +27,14 @@ case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMax
override protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
case WithCTE(plan, _) => this.canInsertLimitInner(plan)
case plan: LogicalPlan => super.canInsertLimitInner(plan)
case plan: LogicalPlan => plan match {
case Union(children, _, _) => !children.exists {
case _: DataWritingCommand => true
case p: CommandResult if p.commandLogicalPlan.isInstanceOf[DataWritingCommand] => true
case _ => false
}
case _ => super.canInsertLimitInner(plan)
}
}
override protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = {

View File

@ -17,7 +17,7 @@
package org.apache.kyuubi.sql.watchdog
import org.apache.spark.sql.catalyst.analysis.AnalysisContext
import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, MultiInstanceRelation}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical._
@ -69,6 +69,7 @@ trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] {
} else {
true
}
case _: MultiInstanceRelation => true
case _ => false
}

View File

@ -17,7 +17,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit
import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LogicalPlan}
import org.apache.kyuubi.sql.KyuubiSQLConf
import org.apache.kyuubi.sql.watchdog.MaxPartitionExceedException
@ -29,6 +29,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
}
case class LimitAndExpected(limit: Int, expected: Int)
val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10))
private def checkMaxPartition: Unit = {
@ -100,7 +101,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|SELECT $distinct *
|FROM t1
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
|""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
}
}
@ -113,7 +114,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|FROM t1
|$sort
|LIMIT $limit
|""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
|""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected))
}
}
}
@ -125,7 +126,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
assert(!sql("SELECT count(*) FROM t1")
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
val havingConditions = List("", "HAVING cnt > 1")
@ -139,7 +140,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|GROUP BY c1
|$having
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
|""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
}
}
@ -154,7 +155,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|$having
|$sort
|LIMIT $limit
|""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
|""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected))
}
}
}
@ -191,7 +192,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|$withQuery
|SELECT * FROM t2
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
|""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
}
}
@ -242,7 +243,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|GROUP BY c1
|$having
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
|""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
}
}
@ -281,7 +282,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|UNION $x
|SELECT c1, c2 FROM t3
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
}
val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
@ -308,7 +309,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|$having
|$sort
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
}
}
}
@ -323,7 +324,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|SELECT c1, c2 FROM t3
|LIMIT $limit
|""".stripMargin)
.queryExecution.analyzed.maxRows.contains(expected))
.queryExecution.optimizedPlan.maxRows.contains(expected))
}
}
}
@ -344,14 +345,14 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|SELECT * FROM
|tmp_table
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
assert(sql(
s"""
|SELECT * FROM
|tmp_view
|""".stripMargin)
.queryExecution.analyzed.maxRows.contains(3))
.queryExecution.optimizedPlan.maxRows.contains(3))
assert(sql(
s"""
@ -359,7 +360,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|tmp_view
|limit 11
|""".stripMargin)
.queryExecution.analyzed.maxRows.contains(3))
.queryExecution.optimizedPlan.maxRows.contains(3))
assert(sql(
s"""
@ -394,7 +395,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|insert into $multiInsertTableName1 select * limit 2
|insert into $multiInsertTableName2 select *
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
}
}
}
@ -411,8 +412,40 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
|FROM tmp_table
|DISTRIBUTE BY a
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
.queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
}
}
}
test("test watchdog: Subquery for forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "1") {
withTable("tmp_table1") {
sql("CREATE TABLE spark_catalog.`default`.tmp_table1(KEY INT, VALUE STRING) USING PARQUET")
sql("INSERT INTO TABLE spark_catalog.`default`.tmp_table1 " +
"VALUES (1, 'aa'),(2,'bb'),(3, 'cc'),(4,'aa'),(5,'cc'),(6, 'aa')")
assert(
sql("select * from tmp_table1").queryExecution.optimizedPlan.isInstanceOf[GlobalLimit])
val testSqlText =
"""
|select count(*)
|from tmp_table1
|where tmp_table1.key in (
|select distinct tmp_table1.key
|from tmp_table1
|where tmp_table1.value = "aa"
|)
|""".stripMargin
val plan = sql(testSqlText).queryExecution.optimizedPlan
assert(!findGlobalLimit(plan))
checkAnswer(sql(testSqlText), Row(3) :: Nil)
}
def findGlobalLimit(plan: LogicalPlan): Boolean = plan match {
case _: GlobalLimit => true
case p if p.children.isEmpty => false
case p => p.children.exists(findGlobalLimit)
}
}
}
}