diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala index c11d65c39..7e4c78053 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala @@ -20,7 +20,7 @@ package org.apache.kyuubi.sql import org.apache.spark.sql.SparkSessionExtensions import org.apache.kyuubi.sql.sqlclassification.KyuubiSqlClassification -import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxHivePartitionStrategy} +import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrderRule, MaxHivePartitionStrategy} import org.apache.kyuubi.sql.zorder.{InsertZorderBeforeWritingDatasource, InsertZorderBeforeWritingHive} import org.apache.kyuubi.sql.zorder.ResolveZorder import org.apache.kyuubi.sql.zorder.ZorderSparkSqlExtensionsParser @@ -43,16 +43,18 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) { // should be applied before // RepartitionBeforeWrite and RepartitionBeforeWriteHive // because we can only apply one of them (i.e. 
Global Sort or Repartition) + extensions.injectResolutionRule(MarkAggregateOrderRule) + extensions.injectPostHocResolutionRule(InsertZorderBeforeWritingDatasource) extensions.injectPostHocResolutionRule(InsertZorderBeforeWritingHive) - extensions.injectPostHocResolutionRule(KyuubiSqlClassification) extensions.injectPostHocResolutionRule(RepartitionBeforeWrite) extensions.injectPostHocResolutionRule(RepartitionBeforeWriteHive) extensions.injectPostHocResolutionRule(FinalStageConfigIsolationCleanRule) + extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) + extensions.injectQueryStagePrepRule(_ => InsertShuffleNodeBeforeJoin) extensions.injectQueryStagePrepRule(FinalStageConfigIsolation(_)) extensions.injectPlannerStrategy(MaxHivePartitionStrategy) - extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule) } } diff --git a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala index c2e3ee4fc..d82eead64 100644 --- a/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala +++ b/dev/kyuubi-extension-spark-3-1/src/main/scala/org/apache/kyuubi/sql/watchdog/ForcedMaxOutputRowsRule.scala @@ -19,11 +19,18 @@ package org.apache.kyuubi.sql.watchdog import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Limit, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, Filter, Limit, LogicalPlan, Project, Sort, Union} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.kyuubi.sql.KyuubiSQLConf +object ForcedMaxOutputRowsConstraint { + val CHILD_AGGREGATE: TreeNodeTag[String] = 
TreeNodeTag[String]("__kyuubi_child_agg__") + val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__" +} + /* * Add ForcedMaxOutputRows rule for output rows limitation * to avoid huge output rows of non_limit query unexpectedly @@ -45,19 +52,31 @@ import org.apache.kyuubi.sql.KyuubiSQLConf * */ case class ForcedMaxOutputRowsRule(session: SparkSession) extends Rule[LogicalPlan] { + private def isChildAggregate(a: Aggregate): Boolean = a + .aggregateExpressions.exists(p => p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE) + .contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG)) + + private def canInsertLimitInner(p: LogicalPlan): Boolean = p match { + + case Aggregate(_, Alias(_, "havingCondition")::Nil, _) => false + case agg: Aggregate => !isChildAggregate(agg) + case _: Distinct => true + case _: Filter => true + case _: Project => true + case Limit(_, _) => true + case _: Sort => true + case _: Union => true + case _ => false + + } + private def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = { maxOutputRowsOpt match { - case Some(forcedMaxOutputRows) => val supported = p match { - case _: Project => true - case _: Aggregate => true - case Limit(_, _) => true - case _ => false - } - supported && !p.maxRows.exists(_ <= forcedMaxOutputRows) + case Some(forcedMaxOutputRows) => canInsertLimitInner(p) && + !p.maxRows.exists(_ <= forcedMaxOutputRows) case None => false } - } override def apply(plan: LogicalPlan): LogicalPlan = { @@ -70,3 +89,41 @@ case class ForcedMaxOutputRowsRule(session: SparkSession) extends Rule[LogicalPl } } + +case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPlan] { + + private def markChildAggregate(a: Aggregate): Unit = { + // mark child aggregate + a.aggregateExpressions.filter(_.resolved).foreach(_.setTagValue( + ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE, + ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG) + ) + } + + private def findAndMarkChildAggregate(plan: 
LogicalPlan): LogicalPlan = plan match { + /* + * This case mainly handles ORDER BY on a grouping column (not an aggregate column), as below + * SELECT c1, COUNT(*) as cnt + * FROM t1 + * GROUP BY c1 + * ORDER BY c1 + * */ + case a: Aggregate if a.aggregateExpressions + .exists(x => x.resolved && x.name.equals("aggOrder")) => markChildAggregate(a) + plan + + case _ => plan.children.foreach(_.foreach { + case agg: Aggregate => markChildAggregate(agg) + case _ => () + } + ) + plan + } + + override def apply(plan: LogicalPlan): LogicalPlan = conf.getConf( + KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS + ) match { + case Some(_) => findAndMarkChildAggregate(plan) + case _ => plan + } +} diff --git a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala index 483a2bf54..656e5bcbb 100644 --- a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala +++ b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/WatchDogSuite.scala @@ -23,6 +23,10 @@ import org.apache.kyuubi.sql.KyuubiSQLConf import org.apache.kyuubi.sql.watchdog.MaxHivePartitionExceedException class WatchDogSuite extends KyuubiSparkSQLExtensionTest { + + case class LimitAndExpected(limit: Int, expected: Int) + val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10)) + test("test watchdog with scan maxHivePartitions") { withTable("test", "temp") { sql( @@ -59,60 +63,117 @@ class WatchDogSuite extends KyuubiSparkSQLExtensionTest { } } - test("test watchdog with query forceMaxOutputRows") { + test("test watchdog: simple SELECT STATEMENT") { withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { - assert(sql("SELECT * FROM t1") - .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => + List("", " DISTINCT").foreach{ distinct => + assert(sql( + s""" + |SELECT 
$distinct * + |FROM t1 + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } - assert(sql("SELECT * FROM t1 LIMIT 1") - .queryExecution.analyzed.asInstanceOf[GlobalLimit].maxRows.contains(1)) + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + List("", "ORDER BY c1", "ORDER BY c2").foreach { sort => + List("", "DISTINCT").foreach{ distinct => + assert(sql( + s""" + |SELECT $distinct * + |FROM t1 + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected) + ) + } + } + } + } + } - assert(sql("SELECT * FROM t1 LIMIT 11") - .queryExecution.analyzed.asInstanceOf[GlobalLimit].maxRows.contains(10)) + test("test watchdog: SELECT ... WITH AGGREGATE STATEMENT ") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { assert(!sql("SELECT count(*) FROM t1") .queryExecution.analyzed.isInstanceOf[GlobalLimit]) - assert(sql( - """ - |SELECT c1, COUNT(*) - |FROM t1 - |GROUP BY c1 - |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") - assert(sql( - """ - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT * FROM custom_cte - |""".stripMargin).queryExecution - .analyzed.isInstanceOf[GlobalLimit]) + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, COUNT(*) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } - assert(sql( - """ - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT * FROM custom_cte - |LIMIT 1 - |""".stripMargin).queryExecution - .analyzed.asInstanceOf[GlobalLimit].maxRows.contains(1)) + limitAndExpecteds.foreach{ case LimitAndExpected(limit, expected) => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT 
c1, COUNT(*) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) + } + } + } + } + } - assert(sql( - """ - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT * FROM custom_cte - |LIMIT 11 - |""".stripMargin).queryExecution - .analyzed.asInstanceOf[GlobalLimit].maxRows.contains(10)) + test("test watchdog: SELECT with CTE forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + val sorts = List("", "ORDER BY c1", "ORDER BY c2") + + sorts.foreach { sort => + assert(sql( + s""" + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + |SELECT * + |FROM custom_cte + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + sorts.foreach { sort => + assert(sql( + s""" + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + |SELECT * + |FROM custom_cte + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) + } + } + } + } + + test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { assert(!sql( """ @@ -120,34 +181,110 @@ class WatchDogSuite extends KyuubiSparkSQLExtensionTest { |SELECT * FROM t1 |) | - |SELECT COUNT(*) FROM custom_cte + |SELECT COUNT(*) + |FROM custom_cte |""".stripMargin).queryExecution .analyzed.isInstanceOf[GlobalLimit]) - assert(sql( - """ - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT c1, COUNT(*) - |FROM custom_cte - |GROUP BY c1 - |""".stripMargin).queryExecution - .analyzed.isInstanceOf[GlobalLimit]) + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") - assert(sql( - """ - |WITH custom_cte AS ( - |SELECT * FROM t1 - |) - | - |SELECT c1, COUNT(*) - |FROM custom_cte - |GROUP BY c1 - 
|LIMIT 11 - |""".stripMargin).queryExecution - .analyzed.asInstanceOf[GlobalLimit].maxRows.contains(10)) + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + |SELECT c1, COUNT(*) as cnt + |FROM custom_cte + |GROUP BY c1 + |$having + |$sort + |""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + havingConditions.foreach { having => + sorts.foreach { sort => + assert(sql( + s""" + |WITH custom_cte AS ( + |SELECT * FROM t1 + |) + | + |SELECT c1, COUNT(*) as cnt + |FROM custom_cte + |GROUP BY c1 + |$having + |$sort + |LIMIT $limit + |""".stripMargin).queryExecution.analyzed.maxRows.contains(expected)) + } + } + } + } + } + + test("test watchdog: UNION Statement for forceMaxOutputRows") { + + withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") { + + List("", "ALL").foreach { x => + assert(sql( + s""" + |SELECT c1, c2 FROM t1 + |UNION $x + |SELECT c1, c2 FROM t2 + |UNION $x + |SELECT c1, c2 FROM t3 + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + + val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt") + val havingConditions = List("", "HAVING cnt > 1") + + List("", "ALL").foreach { x => + havingConditions.foreach{ having => + sorts.foreach { sort => + assert(sql( + s""" + |SELECT c1, count(c2) as cnt + |FROM t1 + |GROUP BY c1 + |$having + |UNION $x + |SELECT c1, COUNT(c2) as cnt + |FROM t2 + |GROUP BY c1 + |$having + |UNION $x + |SELECT c1, COUNT(c2) as cnt + |FROM t3 + |GROUP BY c1 + |$having + |$sort + |""".stripMargin) + .queryExecution.analyzed.isInstanceOf[GlobalLimit]) + } + } + } + + limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) => + assert(sql( + s""" + |SELECT c1, c2 FROM t1 + |UNION + |SELECT c1, c2 FROM t2 + |UNION + |SELECT c1, c2 FROM t3 + |LIMIT $limit + 
|""".stripMargin) + .queryExecution.analyzed.maxRows.contains(expected)) + } } } }