[KYUUBI #1085][FOLLOWUP] Fix-Enforce maxOutputRows for aggregate with having statement


### _Why are the changes needed?_
Support the `UNION` case, as below:
```
SELECT * FROM t1
UNION [ALL]
SELECT * FROM t2
```

Support the `DISTINCT` case, as below:
```
SELECT DISTINCT * FROM t1
```
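Previously the rule only recognized `Project`, `Aggregate`, and `Limit` nodes (see the old `canInsertLimit` in the diff below), so plans rooted at `Union` or `Distinct` were never capped. For illustration only (sketched by hand, not output of this patch), the analyzed plan of the `DISTINCT` query under a forced limit of 10 should now look roughly like:
```
GlobalLimit 10
+- LocalLimit 10
   +- Distinct
      +- Project [...]
         +- Relation t1
```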

Fix a bug in the watchdog's `maxOutputRows` enforcement that is triggered by the following kind of query, where an aggregate carries a `HAVING` clause, optionally followed by a sort:
```
SELECT c1, COUNT(c2) AS cnt
FROM t1
GROUP BY c1
HAVING cnt > 0
[ORDER BY c1, [c2 ...]]
```

It throws an exception:
```
org.apache.spark.sql.catalyst.plans.logical.GlobalLimit cannot be cast to org.apache.spark.sql.catalyst.plans.logical.Aggregate
java.lang.ClassCastException: org.apache.spark.sql.catalyst.plans.logical.GlobalLimit cannot be cast to org.apache.spark.sql.catalyst.plans.logical.Aggregate
	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$.resolvedAggregateFilter$1(Analyzer.scala:2451)
	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$.resolveFilterCondInAggregate(Analyzer.scala:2460)
	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$.resolveHaving(Analyzer.scala:2496)
	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$$anonfun$apply$21.applyOrElse(Analyzer.scala:2353)
	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$$anonfun$apply$21.applyOrElse(Analyzer.scala:2345)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUp$3(AnalysisHelper.scala:90)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:74)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.$anonfun$resolveOperatorsUp$1(AnalysisHelper.scala:90)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.allowInvokingTransformsInAnalyzer(AnalysisHelper.scala:221)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp(AnalysisHelper.scala:86)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.resolveOperatorsUp$(AnalysisHelper.scala:84)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveOperatorsUp(LogicalPlan.scala:29)
	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$.apply(Analyzer.scala:2345)
	at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions$.apply(Analyzer.scala:2344)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$2(RuleExecutor.scala:216)
	at scala.collection.LinearSeqOptimized.foldLeft(LinearSeqOptimized.scala:126)
	at scala.collection.LinearSeqOptimized.foldLeft$(LinearSeqOptimized.scala:122)
	at scala.collection.immutable.List.foldLeft(List.scala:91)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:213)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:205)
	at scala.collection.immutable.List.foreach(List.scala:431)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:205)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.org$apache$spark$sql$catalyst$analysis$Analyzer$$executeSameContext(Analyzer.scala:196)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:190)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.execute(Analyzer.scala:155)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$executeAndTrack$1(RuleExecutor.scala:183)
	at org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:88)
	at org.apache.spark.sql.catalyst.rules.RuleExecutor.executeAndTrack(RuleExecutor.scala:183)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.$anonfun$executeAndCheck$1(Analyzer.scala:174)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper$.markInAnalyzer(AnalysisHelper.scala:228)
	at org.apache.spark.sql.catalyst.analysis.Analyzer.executeAndCheck(Analyzer.scala:173)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$analyzed$1(QueryExecution.scala:73)
	at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111)
	at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:143)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:143)
	at org.apache.spark.sql.execution.QueryExecution.analyzed$lzycompute(QueryExecution.scala:73)
	at org.apache.spark.sql.execution.QueryExecution.analyzed(QueryExecution.scala:71)
	at org.apache.spark.sql.execution.QueryExecution.assertAnalyzed(QueryExecution.scala:63)
	at org.apache.spark.sql.Dataset$.$anonfun$ofRows$2(Dataset.scala:98)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.Dataset$.ofRows(Dataset.scala:96)
	at org.apache.spark.sql.SparkSession.$anonfun$sql$1(SparkSession.scala:618)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.SparkSession.sql(SparkSession.scala:613)
	at org.apache.spark.sql.test.SQLTestUtilsBase.$anonfun$sql$1(SQLTestUtils.scala:231)
	at org.apache.spark.sql.KyuubiExtensionSuite.$anonfun$new$55(KyuubiExtensionSuite.scala:1331)
	at org.apache.spark.sql.catalyst.plans.SQLHelper.withSQLConf(SQLHelper.scala:54)
	at org.apache.spark.sql.catalyst.plans.SQLHelper.withSQLConf$(SQLHelper.scala:38)
	at
```
Related issue for reference: https://issues.apache.org/jira/browse/SPARK-31519

The Spark SQL Analyzer transforms an aggregate with a `HAVING` clause into:
```
Filter
+- Aggregate
```
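`ResolveAggregateFunctions` resolves the `HAVING` condition by casting the `Filter`'s child directly to `Aggregate`. The old watchdog rule matched that inner `Aggregate` and inserted the forced limit between the two nodes, so the child the analyzer saw was a `GlobalLimit`, which is exactly the `ClassCastException` above. A hand-drawn sketch of the broken intermediate plan (not actual analyzer output):

```
Filter ('cnt > 0)
+- GlobalLimit 10          <-- inserted by ForcedMaxOutputRowsRule too early
   +- LocalLimit 10
      +- Aggregate [c1], [c1, count(c2) AS cnt]
```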

Solution:

1. Skip the `Aggregate` whose only output is the synthetic `havingCondition` alias, so that no limit is ever inserted between the `Filter` and its `Aggregate` child.
2. Match the `Filter` node instead when adding the limit, so the limit lands above the whole `HAVING` construct (see the sketch after this list).
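For orientation, this is the essence of the new check, abridged from the full `ForcedMaxOutputRowsRule` diff below (see that diff for the complete, unabridged list of cases):

```scala
// Abridged from the diff below: may a limit be inserted directly above this node?
private def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
  // never wrap the synthetic HAVING aggregate itself ...
  case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false
  case agg: Aggregate => !isChildAggregate(agg)
  // ... but do wrap the Filter that the analyzer places above it
  case _: Filter => true
  case _: Distinct | _: Project | _: Sort | _: Union => true
  case Limit(_, _) => true
  case _ => false
}
```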

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [x] Add screenshots for manual tests if appropriate
<img width="1440" alt="Screenshot 2021-09-21 at 8.35.16 PM" src="https://user-images.githubusercontent.com/635169/134171308-2842f0d4-acfa-4817-a03c-a7ef5e38df12.png">

- [x] [Run test](https://kyuubi.readthedocs.io/en/latest/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #1129 from i7xh/fixAggWithHavingInMaxOutput.

Closes #1085

7577f4d3 [h] update
5955c89e [h] update
384b2333 [h] fix issue
5b0af156 [h] update
46327fc6 [h] Fix issue
6119a039 [h] fix issue
a7b87dd7 [h] Fix issue
2570444e [h] BugFix: Aggregate with having statement

Authored-by: h <h@zhihu.com>
Signed-off-by: ulysses-you <ulyssesyou@apache.org>
3 changed files with 274 additions and 78 deletions

`KyuubiSparkSQLExtension.scala`:

```diff
@@ -20,7 +20,7 @@ package org.apache.kyuubi.sql
 import org.apache.spark.sql.SparkSessionExtensions
 
 import org.apache.kyuubi.sql.sqlclassification.KyuubiSqlClassification
-import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxHivePartitionStrategy}
+import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrderRule, MaxHivePartitionStrategy}
 import org.apache.kyuubi.sql.zorder.{InsertZorderBeforeWritingDatasource, InsertZorderBeforeWritingHive}
 import org.apache.kyuubi.sql.zorder.ResolveZorder
 import org.apache.kyuubi.sql.zorder.ZorderSparkSqlExtensionsParser
@@ -43,16 +43,18 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
     // should be applied before
     // RepartitionBeforeWrite and RepartitionBeforeWriteHive
     // because we can only apply one of them (i.e. Global Sort or Repartition)
+    extensions.injectResolutionRule(MarkAggregateOrderRule)
     extensions.injectPostHocResolutionRule(InsertZorderBeforeWritingDatasource)
     extensions.injectPostHocResolutionRule(InsertZorderBeforeWritingHive)
     extensions.injectPostHocResolutionRule(KyuubiSqlClassification)
     extensions.injectPostHocResolutionRule(RepartitionBeforeWrite)
     extensions.injectPostHocResolutionRule(RepartitionBeforeWriteHive)
     extensions.injectPostHocResolutionRule(FinalStageConfigIsolationCleanRule)
-    extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
     extensions.injectQueryStagePrepRule(_ => InsertShuffleNodeBeforeJoin)
     extensions.injectQueryStagePrepRule(FinalStageConfigIsolation(_))
     extensions.injectPlannerStrategy(MaxHivePartitionStrategy)
+    extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
   }
 }
```
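As a usage note (a minimal sketch, not part of this patch): the watchdog only activates when `KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS` is set, and assuming its key is `spark.sql.watchdog.forcedMaxOutputRows`, enabling the extension plus the limit looks like:

```scala
// Sketch only: enable the Kyuubi extension and the watchdog limit, then
// verify an un-limited query gets wrapped in a GlobalLimit.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit

val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.extensions", "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
  .config("spark.sql.watchdog.forcedMaxOutputRows", "10") // assumed conf key
  .getOrCreate()

spark.range(100).createOrReplaceTempView("t1")
val analyzed = spark.sql("SELECT * FROM t1").queryExecution.analyzed
assert(analyzed.isInstanceOf[GlobalLimit]) // the watchdog capped the output
```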

`ForcedMaxOutputRowsRule.scala` (`org.apache.kyuubi.sql.watchdog`):

```diff
@@ -19,11 +19,18 @@ package org.apache.kyuubi.sql.watchdog
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Limit, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, Filter, Limit, LogicalPlan, Project, Sort, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 
 import org.apache.kyuubi.sql.KyuubiSQLConf
 
+object ForcedMaxOutputRowsConstraint {
+  val CHILD_AGGREGATE: TreeNodeTag[String] = TreeNodeTag[String]("__kyuubi_child_agg__")
+  val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__"
+}
+
 /*
  * Add ForcedMaxOutputRows rule for output rows limitation
  * to avoid huge output rows of non_limit query unexpectedly
@@ -45,19 +52,31 @@ import org.apache.kyuubi.sql.KyuubiSQLConf
  * */
 case class ForcedMaxOutputRowsRule(session: SparkSession) extends Rule[LogicalPlan] {
 
+  private def isChildAggregate(a: Aggregate): Boolean = a
+    .aggregateExpressions.exists(p => p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE)
+      .contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG))
+
+  private def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
+    case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false
+    case agg: Aggregate => !isChildAggregate(agg)
+    case _: Distinct => true
+    case _: Filter => true
+    case _: Project => true
+    case Limit(_, _) => true
+    case _: Sort => true
+    case _: Union => true
+    case _ => false
+  }
+
   private def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = {
     maxOutputRowsOpt match {
-      case Some(forcedMaxOutputRows) => val supported = p match {
-        case _: Project => true
-        case _: Aggregate => true
-        case Limit(_, _) => true
-        case _ => false
-      }
-        supported && !p.maxRows.exists(_ <= forcedMaxOutputRows)
+      case Some(forcedMaxOutputRows) => canInsertLimitInner(p) &&
+        !p.maxRows.exists(_ <= forcedMaxOutputRows)
       case None => false
     }
   }
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
@@ -70,3 +89,41 @@ case class ForcedMaxOutputRowsRule(session: SparkSession) extends Rule[LogicalPl
   }
 }
+
+case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPlan] {
+
+  private def markChildAggregate(a: Aggregate): Unit = {
+    // mark child aggregate
+    a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue(
+      ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE,
+      ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)
+    )
+  }
+
+  private def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match {
+    /*
+     * The case mainly process order not aggregate column but grouping column as below
+     * SELECT c1, COUNT(*) as cnt
+     * FROM t1
+     * GROUP BY c1
+     * ORDER BY c1
+     * */
+    case a: Aggregate if a.aggregateExpressions
+        .exists(x => x.resolved && x.name.equals("aggOrder")) => markChildAggregate(a)
+      plan
+    case _ => plan.children.foreach(_.foreach {
+        case agg: Aggregate => markChildAggregate(agg)
+        case _ => Unit
+      }
+    )
+      plan
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf(
+    KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS
+  ) match {
+    case Some(_) => findAndMarkChildAggregate(plan)
+    case _ => plan
+  }
+}
```
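The marking mechanism relies on Spark's `TreeNodeTag`, which attaches typed metadata to a plan or expression node and can be read back later by another rule. A standalone sketch of that round trip (illustrative only; `demoTag` and the `Literal` are stand-ins, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.trees.TreeNodeTag

// A tag is just a typed key; this mimics the __kyuubi_child_agg__ marker above.
val demoTag = TreeNodeTag[String]("__kyuubi_child_agg__")

val expr = Literal(1)
expr.setTagValue(demoTag, "__kyuubi_child_agg__")                   // mark the node
assert(expr.getTagValue(demoTag).contains("__kyuubi_child_agg__")) // read it back
```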

`WatchDogSuite.scala`:

```diff
@@ -23,6 +23,10 @@ import org.apache.kyuubi.sql.KyuubiSQLConf
 import org.apache.kyuubi.sql.watchdog.MaxHivePartitionExceedException
 
 class WatchDogSuite extends KyuubiSparkSQLExtensionTest {
+
+  case class LimitAndExpected(limit: Int, expected: Int)
+  val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10))
+
   test("test watchdog with scan maxHivePartitions") {
     withTable("test", "temp") {
       sql(
@@ -59,60 +63,117 @@ class WatchDogSuite extends KyuubiSparkSQLExtensionTest {
     }
   }
 
-  test("test watchdog with query forceMaxOutputRows") {
+  test("test watchdog: simple SELECT STATEMENT") {
     withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
-      assert(sql("SELECT * FROM t1")
-        .queryExecution.analyzed.isInstanceOf[GlobalLimit])
+      List("", "ORDER BY c1", "ORDER BY c2").foreach { sort =>
+        List("", " DISTINCT").foreach { distinct =>
+          assert(sql(
+            s"""
+               |SELECT $distinct *
+               |FROM t1
+               |$sort
+               |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
+        }
+      }
 
-      assert(sql("SELECT * FROM t1 LIMIT 1")
-        .queryExecution.analyzed.asInstanceOf[GlobalLimit].maxRows.contains(1))
+      limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+        List("", "ORDER BY c1", "ORDER BY c2").foreach { sort =>
+          List("", "DISTINCT").foreach { distinct =>
+            assert(sql(
+              s"""
+                 |SELECT $distinct *
+                 |FROM t1
+                 |$sort
+                 |LIMIT $limit
+                 |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)
+            )
+          }
+        }
+      }
+    }
+  }
 
-      assert(sql("SELECT * FROM t1 LIMIT 11")
-        .queryExecution.analyzed.asInstanceOf[GlobalLimit].maxRows.contains(10))
+  test("test watchdog: SELECT ... WITH AGGREGATE STATEMENT ") {
+    withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
 
       assert(!sql("SELECT count(*) FROM t1")
         .queryExecution.analyzed.isInstanceOf[GlobalLimit])
 
-      assert(sql(
-        """
-          |SELECT c1, COUNT(*)
-          |FROM t1
-          |GROUP BY c1
-          |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
+      val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
+      val havingConditions = List("", "HAVING cnt > 1")
 
-      assert(sql(
-        """
-          |WITH custom_cte AS (
-          |SELECT * FROM t1
-          |)
-          |
-          |SELECT * FROM custom_cte
-          |""".stripMargin).queryExecution
-        .analyzed.isInstanceOf[GlobalLimit])
+      havingConditions.foreach { having =>
+        sorts.foreach { sort =>
+          assert(sql(
+            s"""
+               |SELECT c1, COUNT(*) as cnt
+               |FROM t1
+               |GROUP BY c1
+               |$having
+               |$sort
+               |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
+        }
+      }
 
-      assert(sql(
-        """
-          |WITH custom_cte AS (
-          |SELECT * FROM t1
-          |)
-          |
-          |SELECT * FROM custom_cte
-          |LIMIT 1
-          |""".stripMargin).queryExecution
-        .analyzed.asInstanceOf[GlobalLimit].maxRows.contains(1))
+      limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+        havingConditions.foreach { having =>
+          sorts.foreach { sort =>
+            assert(sql(
+              s"""
+                 |SELECT c1, COUNT(*) as cnt
+                 |FROM t1
+                 |GROUP BY c1
+                 |$having
+                 |$sort
+                 |LIMIT $limit
+                 |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
+          }
+        }
+      }
+    }
+  }
 
-      assert(sql(
-        """
-          |WITH custom_cte AS (
-          |SELECT * FROM t1
-          |)
-          |
-          |SELECT * FROM custom_cte
-          |LIMIT 11
-          |""".stripMargin).queryExecution
-        .analyzed.asInstanceOf[GlobalLimit].maxRows.contains(10))
+  test("test watchdog: SELECT with CTE forceMaxOutputRows") {
+    withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+      val sorts = List("", "ORDER BY c1", "ORDER BY c2")
+      sorts.foreach { sort =>
+        assert(sql(
+          s"""
+             |WITH custom_cte AS (
+             |SELECT * FROM t1
+             |)
+             |SELECT *
+             |FROM custom_cte
+             |$sort
+             |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
+      }
+
+      limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+        sorts.foreach { sort =>
+          assert(sql(
+            s"""
+               |WITH custom_cte AS (
+               |SELECT * FROM t1
+               |)
+               |SELECT *
+               |FROM custom_cte
+               |$sort
+               |LIMIT $limit
+               |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
+        }
+      }
+    }
+  }
 
+  test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") {
+    withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
       assert(!sql(
         """
@@ -120,34 +181,110 @@ class WatchDogSuite extends KyuubiSparkSQLExtensionTest {
           |SELECT * FROM t1
           |)
           |
-          |SELECT COUNT(*) FROM custom_cte
+          |SELECT COUNT(*)
+          |FROM custom_cte
           |""".stripMargin).queryExecution
         .analyzed.isInstanceOf[GlobalLimit])
 
-      assert(sql(
-        """
-          |WITH custom_cte AS (
-          |SELECT * FROM t1
-          |)
-          |
-          |SELECT c1, COUNT(*)
-          |FROM custom_cte
-          |GROUP BY c1
-          |""".stripMargin).queryExecution
-        .analyzed.isInstanceOf[GlobalLimit])
+      val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
+      val havingConditions = List("", "HAVING cnt > 1")
 
-      assert(sql(
-        """
-          |WITH custom_cte AS (
-          |SELECT * FROM t1
-          |)
-          |
-          |SELECT c1, COUNT(*)
-          |FROM custom_cte
-          |GROUP BY c1
-          |LIMIT 11
-          |""".stripMargin).queryExecution
-        .analyzed.asInstanceOf[GlobalLimit].maxRows.contains(10))
+      havingConditions.foreach { having =>
+        sorts.foreach { sort =>
+          assert(sql(
+            s"""
+               |WITH custom_cte AS (
+               |SELECT * FROM t1
+               |)
+               |
+               |SELECT c1, COUNT(*) as cnt
+               |FROM custom_cte
+               |GROUP BY c1
+               |$having
+               |$sort
+               |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
+        }
+      }
+
+      limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+        havingConditions.foreach { having =>
+          sorts.foreach { sort =>
+            assert(sql(
+              s"""
+                 |WITH custom_cte AS (
+                 |SELECT * FROM t1
+                 |)
+                 |
+                 |SELECT c1, COUNT(*) as cnt
+                 |FROM custom_cte
+                 |GROUP BY c1
+                 |$having
+                 |$sort
+                 |LIMIT $limit
+                 |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
+          }
+        }
+      }
     }
   }
 
+  test("test watchdog: UNION Statement for forceMaxOutputRows") {
+    withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
+      List("", "ALL").foreach { x =>
+        assert(sql(
+          s"""
+             |SELECT c1, c2 FROM t1
+             |UNION $x
+             |SELECT c1, c2 FROM t2
+             |UNION $x
+             |SELECT c1, c2 FROM t3
+             |""".stripMargin)
+          .queryExecution.analyzed.isInstanceOf[GlobalLimit])
+      }
+
+      val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
+      val havingConditions = List("", "HAVING cnt > 1")
+
+      List("", "ALL").foreach { x =>
+        havingConditions.foreach { having =>
+          sorts.foreach { sort =>
+            assert(sql(
+              s"""
+                 |SELECT c1, count(c2) as cnt
+                 |FROM t1
+                 |GROUP BY c1
+                 |$having
+                 |UNION $x
+                 |SELECT c1, COUNT(c2) as cnt
+                 |FROM t2
+                 |GROUP BY c1
+                 |$having
+                 |UNION $x
+                 |SELECT c1, COUNT(c2) as cnt
+                 |FROM t3
+                 |GROUP BY c1
+                 |$having
+                 |$sort
+                 |""".stripMargin)
+              .queryExecution.analyzed.isInstanceOf[GlobalLimit])
+          }
+        }
+      }
+
+      limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
+        assert(sql(
+          s"""
+             |SELECT c1, c2 FROM t1
+             |UNION
+             |SELECT c1, c2 FROM t2
+             |UNION
+             |SELECT c1, c2 FROM t3
+             |LIMIT $limit
+             |""".stripMargin)
+          .queryExecution.analyzed.maxRows.contains(expected))
+      }
+    }
+  }
 }
```