From 5ea0964e26c6bbcad58e407cf52287d43565cd9c Mon Sep 17 00:00:00 2001 From: odone Date: Wed, 26 Jan 2022 09:32:54 +0800 Subject: [PATCH] [KYUUBI #1832] Fixed: forcedMaxOutputRows extension for subquery ### _Why are the changes needed?_ Fixed #1832 ### _How was this patch tested?_ - [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible - [ ] Add screenshots for manual tests if appropriate - [ ] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before making a pull request Closes #1833 from iodone/1832-bug. Closes #1832 52769746 [odone] [KYUUBI #1832] Added: Add some test case 61d733dc [odone] [KYUUBI #1832] fixed: spotless d8ee657a [odone] [KYUUBI #1832] fixed: forcedMaxOutputRows extension for subquery efd87a75 [odone] [KYUUBI #1832] fixed: forcedMaxOutputRows extension for subquery Authored-by: odone Signed-off-by: ulysses-you --- .../kyuubi/sql/KyuubiSparkSQLExtension.scala | 6 +- .../watchdog/ForcedMaxOutputRowsRule.scala | 55 +--------------- .../kyuubi/sql/KyuubiSparkSQLExtension.scala | 2 +- .../watchdog/ForcedMaxOutputRowsRule.scala | 12 +++- .../watchdog/ForcedMaxOutputRowsBase.scala | 3 +- .../apache/spark/sql/WatchDogSuiteBase.scala | 65 ++++++++++++++----- 6 files changed, 65 insertions(+), 78 deletions(-) diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index f6b1ef0f7..cd312de95 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -20,7 +20,7 @@ package org.apache.kyuubi.sql import org.apache.spark.sql.SparkSessionExtensions import org.apache.kyuubi.sql.sqlclassification.KyuubiSqlClassification -import 
org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrderRule, MaxPartitionStrategy} +import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxPartitionStrategy} // scalastyle:off line.size.limit /** @@ -39,9 +39,7 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { extensions.injectPostHocResolutionRule(DropIgnoreNonexistent) // watchdog extension - // a help rule for ForcedMaxOutputRowsRule - extensions.injectResolutionRule(MarkAggregateOrderRule) - extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) + extensions.injectOptimizerRule(ForcedMaxOutputRowsRule) extensions.injectPlannerStrategy(MaxPartitionStrategy) } } diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala index e92a69f71..15d1a47d7 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -19,60 +19,7 @@ package org.apache.kyuubi.sql.watchdog import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreeNodeTag - -import org.apache.kyuubi.sql.KyuubiSQLConf - -object ForcedMaxOutputRowsConstraint { - val CHILD_AGGREGATE: TreeNodeTag[String] = TreeNodeTag[String]("__kyuubi_child_agg__") - val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__" -} case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase { - override protected def isChildAggregate(a: Aggregate): Boolean = - a.aggregateExpressions.exists(p => - p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE) - 
.contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) -} - -/** - * After SPARK-35712, we don't need mark child aggregate for spark 3.2.x or higher version, - * for more detail, please see https://github.com/apache/spark/pull/32470 - */ -case class MarkAggregateOrderRule(sparkSession: SparkSession) extends Rule[LogicalPlan] { - - private def markChildAggregate(a: Aggregate): Unit = { - // mark child aggregate - a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue( - ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE, - ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) - } - - protected def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match { - /* - * The case mainly process order not aggregate column but grouping column as below - * SELECT c1, COUNT(*) as cnt - * FROM t1 - * GROUP BY c1 - * ORDER BY c1 - * */ - case a: Aggregate - if a.aggregateExpressions - .exists(x => x.resolved && x.name.equals("aggOrder")) => - markChildAggregate(a) - plan - case _ => - plan.children.foreach(_.foreach { - case agg: Aggregate => markChildAggregate(agg) - case _ => Unit - }) - plan - } - - override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf( - KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS) match { - case Some(_) => findAndMarkChildAggregate(plan) - case _ => plan - } + override protected def isChildAggregate(a: Aggregate): Boolean = false } diff --git a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index 38426b8e6..ef9da41be 100644 --- a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -37,7 +37,7 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { 
extensions.injectPostHocResolutionRule(DropIgnoreNonexistent) // watchdog extension - extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) + extensions.injectOptimizerRule(ForcedMaxOutputRowsRule) extensions.injectPlannerStrategy(MaxPartitionStrategy) } } diff --git a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala index 03d243f85..a3d990b10 100644 --- a/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala +++ b/dev/kyuubi-extension-spark-3-2/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -18,7 +18,8 @@ package org.apache.kyuubi.sql.watchdog import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, WithCTE} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, CommandResult, LogicalPlan, Union, WithCTE} +import org.apache.spark.sql.execution.command.DataWritingCommand case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase { @@ -26,7 +27,14 @@ case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMax override protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match { case WithCTE(plan, _) => this.canInsertLimitInner(plan) - case plan: LogicalPlan => super.canInsertLimitInner(plan) + case plan: LogicalPlan => plan match { + case Union(children, _, _) => !children.exists { + case _: DataWritingCommand => true + case p: CommandResult if p.commandLogicalPlan.isInstanceOf[DataWritingCommand] => true + case _ => false + } + case _ => super.canInsertLimitInner(plan) + } } override protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = { diff --git 
a/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala index 7f846d961..e9fd9fe06 100644 --- a/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala +++ b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsBase.scala @@ -17,7 +17,7 @@ package org.apache.kyuubi.sql.watchdog -import org.apache.spark.sql.catalyst.analysis.AnalysisContext +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, MultiInstanceRelation} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.logical._ @@ -69,6 +69,7 @@ trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] { } else { true } + case _: MultiInstanceRelation => true case _ => false } diff --git a/dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala b/dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala index a604308a5..07dfb4cec 100644 --- a/dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala +++ b/dev/kyuubi-extension-spark-common/src/test/scala/org/apache/spark/sql/WatchDogSuiteBase.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit +import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LogicalPlan} import org.apache.kyuubi.sql.KyuubiSQLConf import org.apache.kyuubi.sql.watchdog.MaxPartitionExceedException @@ -29,6 +29,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { } case class LimitAndExpected(limit: Int, expected: Int) + val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10)) private def checkMaxPartition: 
Unit = { @@ -100,7 +101,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |SELECT $distinct * |FROM t1 |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) } } @@ -113,7 +114,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |FROM t1 |$sort |LIMIT $limit - |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) + |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected)) } } } @@ -125,7 +126,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { assert(!sql("SELECT count(*) FROM t1") - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") val havingConditions = List("", "HAVING cnt > 1") @@ -139,7 +140,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |GROUP BY c1 |$having |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) } } @@ -154,7 +155,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |$having |$sort |LIMIT $limit - |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) + |""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected)) } } } @@ -191,7 +192,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |$withQuery |SELECT * FROM t2 |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + |""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) } } @@ -242,7 +243,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |GROUP BY c1 |$having |$sort - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + 
|""".stripMargin).queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) } } @@ -281,7 +282,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |UNION $x |SELECT c1, c2 FROM t3 |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) } val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") @@ -308,7 +309,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |$having |$sort |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) } } } @@ -323,7 +324,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |SELECT c1, c2 FROM t3 |LIMIT $limit |""".stripMargin) - .queryExecution.analyzed.maxRows.contains(expected)) + .queryExecution.optimizedPlan.maxRows.contains(expected)) } } } @@ -344,14 +345,14 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |SELECT * FROM |tmp_table |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) assert(sql( s""" |SELECT * FROM |tmp_view |""".stripMargin) - .queryExecution.analyzed.maxRows.contains(3)) + .queryExecution.optimizedPlan.maxRows.contains(3)) assert(sql( s""" @@ -359,7 +360,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |tmp_view |limit 11 |""".stripMargin) - .queryExecution.analyzed.maxRows.contains(3)) + .queryExecution.optimizedPlan.maxRows.contains(3)) assert(sql( s""" @@ -394,7 +395,7 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |insert into $multiInsertTableName1 select * limit 2 |insert into $multiInsertTableName2 select * |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) } } } @@ -411,8 +412,40 @@ trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest { |FROM tmp_table |DISTRIBUTE BY 
a |""".stripMargin) - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + .queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) } } } + + test("test watchdog: Subquery for forceMaxOutputRows") { + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "1") { + withTable("tmp_table1") { + sql("CREATE TABLE spark_catalog.`default`.tmp_table1(KEY INT, VALUE STRING) USING PARQUET") + sql("INSERT INTO TABLE spark_catalog.`default`.tmp_table1 " + + "VALUES (1, 'aa'),(2,'bb'),(3, 'cc'),(4,'aa'),(5,'cc'),(6, 'aa')") + assert( + sql("select * from tmp_table1").queryExecution.optimizedPlan.isInstanceOf[GlobalLimit]) + val testSqlText = + """ + |select count(*) + |from tmp_table1 + |where tmp_table1.key in ( + |select distinct tmp_table1.key + |from tmp_table1 + |where tmp_table1.value = "aa" + |) + |""".stripMargin + val plan = sql(testSqlText).queryExecution.optimizedPlan + assert(!findGlobalLimit(plan)) + checkAnswer(sql(testSqlText), Row(3) :: Nil) + } + + def findGlobalLimit(plan: LogicalPlan): Boolean = plan match { + case _: GlobalLimit => true + case p if p.children.isEmpty => false + case p => p.children.exists(findGlobalLimit) + } + + } + } }