[KYUUBI #1591] Watchdog support for Spark-3.2

<!--
Thanks for sending a pull request!

Here are some tips for you:
  1. If this is your first time, please read our contributor guidelines: https://kyuubi.readthedocs.io/en/latest/community/contributions.html
  2. If the PR is related to an issue in https://github.com/apache/incubator-kyuubi/issues, add '[KYUUBI #XXXX]' in your PR title, e.g., '[KYUUBI #XXXX] Your PR title ...'.
  3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][KYUUBI #XXXX] Your PR title ...'.
-->

### _Why are the changes needed?_
<!--
Please clarify why the changes are needed. For instance,
  1. If you add a feature, you can talk about the use case of it.
  2. If you fix a bug, you can clarify why it is a bug.
-->
<!-- 2. move spark-3.1 `MarkAggregateOrderRule` to spark-common and rename to `MarkAggregateOrderBase` -->

1. Move the spark-3.1 `ForcedMaxOutputRowsRule` to spark-common and rename it to `ForcedMaxOutputRowsBase`
2. Handle the `WithCTE` logical plan in spark-3.2 (see the sketch below)
3. Move the spark-3.1 `MaxPartitionStrategy` to spark-common
4. Add a nested CTE unit test for `ForcedMaxOutputRowsRule`
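
In Spark 3.2, the analyzer wraps CTE queries in a `WithCTE` node, so the rule has to look through that wrapper before running the shared checks. A minimal sketch of the idea, with a hypothetical `unwrapCte` helper that is illustrative only and not part of this patch:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, WithCTE}

// Illustrative helper: peel off the Spark 3.2 WithCTE wrapper to reach the
// actual query body, mirroring what the spark-3.2 ForcedMaxOutputRowsRule
// does in its canInsertLimit/canInsertLimitInner overrides (see the diff).
def unwrapCte(plan: LogicalPlan): LogicalPlan = plan match {
  case WithCTE(body, _) => unwrapCte(body)
  case other => other
}
```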

### _How was this patch tested?_
- [ ] Add some test cases that check the changes thoroughly, including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run tests](https://kyuubi.readthedocs.io/en/latest/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #1591 from cfmcgrady/watchdog-spark32.

Closes #1591

5399a3f0 [Fu Chen] fix style
0ce83ba5 [Fu Chen] remove MarkAggregateOrderRule
364fc26e [Fu Chen] add license header
44726dee [Fu Chen] fix style
4847dbf3 [Fu Chen] watchdog support for spark-3.2

Authored-by: Fu Chen <cfmcgrady@gmail.com>
Signed-off-by: ulysses-you <ulyssesyou@apache.org>
Fu Chen 2021-12-21 12:17:30 +08:00 committed by ulysses-you
parent d529402168
commit 2b8304d154
11 changed files with 592 additions and 466 deletions

View File

@@ -32,15 +32,16 @@ import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MarkAggregateOrd
class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
override def apply(extensions: SparkSessionExtensions): Unit = {
KyuubiSparkSQLCommonExtension.injectCommonExtensions(extensions)
// a help rule for ForcedMaxOutputRowsRule
extensions.injectResolutionRule(MarkAggregateOrderRule)
extensions.injectPostHocResolutionRule(KyuubiSqlClassification)
extensions.injectPostHocResolutionRule(RepartitionBeforeWritingDatasource)
extensions.injectPostHocResolutionRule(RepartitionBeforeWritingHive)
extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
extensions.injectPostHocResolutionRule(DropIgnoreNonexistent)
// watchdog extension
// a help rule for ForcedMaxOutputRowsRule
extensions.injectResolutionRule(MarkAggregateOrderRule)
extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
extensions.injectPlannerStrategy(MaxPartitionStrategy)
}
}

View File

@@ -18,13 +18,9 @@
package org.apache.kyuubi.sql.watchdog
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.AnalysisContext
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, Filter, Limit, LogicalPlan, Project, RepartitionByExpression, Sort, Union}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.kyuubi.sql.KyuubiSQLConf
@@ -33,81 +29,18 @@ object ForcedMaxOutputRowsConstraint {
val CHILD_AGGREGATE_FLAG: String = "__kyuubi_child_agg__"
}
/*
* Add the ForcedMaxOutputRows rule to cap the number of output rows and
* avoid unexpectedly huge results from non-LIMIT queries.
* It mainly applies to cases such as:
*
* case 1:
* {{{
* SELECT [c1, c2, ...]
* }}}
*
* case 2:
* {{{
* WITH CTE AS (
* ...)
* SELECT [c1, c2, ...] FROM CTE ...
* }}}
*
* This logical rule adds a GlobalLimit node above the root project.
*/
case class ForcedMaxOutputRowsRule(session: SparkSession) extends Rule[LogicalPlan] {
private def isChildAggregate(a: Aggregate): Boolean = a
.aggregateExpressions.exists(p =>
case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase {
override protected def isChildAggregate(a: Aggregate): Boolean =
a.aggregateExpressions.exists(p =>
p.getTagValue(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE)
.contains(ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG))
private def isView: Boolean = {
val nestedViewDepth = AnalysisContext.get.nestedViewDepth
nestedViewDepth > 0
}
private def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false
case agg: Aggregate => !isChildAggregate(agg)
case _: RepartitionByExpression => true
case _: Distinct => true
case _: Filter => true
case _: Project => true
case Limit(_, _) => true
case _: Sort => true
case Union(children, _, _) =>
if (children.exists(_.isInstanceOf[DataWritingCommand])) {
false
} else {
true
}
case _ => false
}
private def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = {
maxOutputRowsOpt match {
case Some(forcedMaxOutputRows) => canInsertLimitInner(p) &&
!p.maxRows.exists(_ <= forcedMaxOutputRows) &&
!isView
case None => false
}
}
override def apply(plan: LogicalPlan): LogicalPlan = {
val maxOutputRowsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS)
plan match {
case p if p.resolved && canInsertLimit(p, maxOutputRowsOpt) =>
Limit(
maxOutputRowsOpt.get,
plan)
case _ => plan
}
}
}
case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPlan] {
/**
* After SPARK-35712, we don't need mark child aggregate for spark 3.2.x or higher version,
* for more detail, please see https://github.com/apache/spark/pull/32470
*/
case class MarkAggregateOrderRule(sparkSession: SparkSession) extends Rule[LogicalPlan] {
private def markChildAggregate(a: Aggregate): Unit = {
// mark child aggregate
@@ -116,7 +49,7 @@ case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPla
ForcedMaxOutputRowsConstraint.CHILD_AGGREGATE_FLAG))
}
private def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match {
protected def findAndMarkChildAggregate(plan: LogicalPlan): LogicalPlan = plan match {
/*
* This case mainly handles ordering by a grouping column rather than an aggregate column, as below:
* SELECT c1, COUNT(*) as cnt
@@ -129,7 +62,6 @@ case class MarkAggregateOrderRule(session: SparkSession) extends Rule[LogicalPla
.exists(x => x.resolved && x.name.equals("aggOrder")) =>
markChildAggregate(a)
plan
case _ =>
plan.children.foreach(_.foreach {
case agg: Aggregate => markChildAggregate(agg)

View File

@@ -17,387 +17,4 @@
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit
import org.apache.kyuubi.sql.KyuubiSQLConf
import org.apache.kyuubi.sql.watchdog.MaxPartitionExceedException
class WatchDogSuite extends KyuubiSparkSQLExtensionTest {
override protected def beforeAll(): Unit = {
super.beforeAll()
setupData()
}
case class LimitAndExpected(limit: Int, expected: Int)
val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10))
private def checkMaxPartition: Unit = {
withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "100") {
checkAnswer(sql("SELECT count(distinct(p)) FROM test"), Row(10) :: Nil)
}
withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "5") {
sql("SELECT * FROM test where p=1").queryExecution.sparkPlan
sql(s"SELECT * FROM test WHERE p in (${Range(0, 5).toList.mkString(",")})")
.queryExecution.sparkPlan
intercept[MaxPartitionExceedException](
sql("SELECT * FROM test where p != 1").queryExecution.sparkPlan)
intercept[MaxPartitionExceedException](
sql("SELECT * FROM test").queryExecution.sparkPlan)
intercept[MaxPartitionExceedException](sql(
s"SELECT * FROM test WHERE p in (${Range(0, 6).toList.mkString(",")})")
.queryExecution.sparkPlan)
}
}
test("watchdog with scan maxPartitions -- hive") {
Seq("textfile", "parquet").foreach { format =>
withTable("test", "temp") {
sql(
s"""
|CREATE TABLE test(i int)
|PARTITIONED BY (p int)
|STORED AS $format""".stripMargin)
spark.range(0, 10, 1).selectExpr("id as col")
.createOrReplaceTempView("temp")
for (part <- Range(0, 10)) {
sql(
s"""
|INSERT OVERWRITE TABLE test PARTITION (p='$part')
|select col from temp""".stripMargin)
}
checkMaxPartition
}
}
}
test("watchdog with scan maxPartitions -- data source") {
withTempDir { dir =>
withTempView("test") {
spark.range(10).selectExpr("id", "id as p")
.write
.partitionBy("p")
.mode("overwrite")
.save(dir.getCanonicalPath)
spark.read.load(dir.getCanonicalPath).createOrReplaceTempView("test")
checkMaxPartition
}
}
}
test("test watchdog: simple SELECT STATEMENT") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
List("", "ORDER BY c1", "ORDER BY c2").foreach { sort =>
List("", " DISTINCT").foreach { distinct =>
assert(sql(
s"""
|SELECT $distinct *
|FROM t1
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
List("", "ORDER BY c1", "ORDER BY c2").foreach { sort =>
List("", "DISTINCT").foreach { distinct =>
assert(sql(
s"""
|SELECT $distinct *
|FROM t1
|$sort
|LIMIT $limit
|""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
}
}
}
}
}
test("test watchdog: SELECT ... WITH AGGREGATE STATEMENT ") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
assert(!sql("SELECT count(*) FROM t1")
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
val havingConditions = List("", "HAVING cnt > 1")
havingConditions.foreach { having =>
sorts.foreach { sort =>
assert(sql(
s"""
|SELECT c1, COUNT(*) as cnt
|FROM t1
|GROUP BY c1
|$having
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
havingConditions.foreach { having =>
sorts.foreach { sort =>
assert(sql(
s"""
|SELECT c1, COUNT(*) as cnt
|FROM t1
|GROUP BY c1
|$having
|$sort
|LIMIT $limit
|""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
}
}
}
}
}
test("test watchdog: SELECT with CTE forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
val sorts = List("", "ORDER BY c1", "ORDER BY c2")
sorts.foreach { sort =>
assert(sql(
s"""
|WITH custom_cte AS (
|SELECT * FROM t1
|)
|SELECT *
|FROM custom_cte
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
sorts.foreach { sort =>
assert(sql(
s"""
|WITH custom_cte AS (
|SELECT * FROM t1
|)
|SELECT *
|FROM custom_cte
|$sort
|LIMIT $limit
|""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
}
}
}
}
test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
assert(!sql(
"""
|WITH custom_cte AS (
|SELECT * FROM t1
|)
|
|SELECT COUNT(*)
|FROM custom_cte
|""".stripMargin).queryExecution
.analyzed.isInstanceOf[GlobalLimit])
val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
val havingConditions = List("", "HAVING cnt > 1")
havingConditions.foreach { having =>
sorts.foreach { sort =>
assert(sql(
s"""
|WITH custom_cte AS (
|SELECT * FROM t1
|)
|
|SELECT c1, COUNT(*) as cnt
|FROM custom_cte
|GROUP BY c1
|$having
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
havingConditions.foreach { having =>
sorts.foreach { sort =>
assert(sql(
s"""
|WITH custom_cte AS (
|SELECT * FROM t1
|)
|
|SELECT c1, COUNT(*) as cnt
|FROM custom_cte
|GROUP BY c1
|$having
|$sort
|LIMIT $limit
|""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
}
}
}
}
}
test("test watchdog: UNION Statement for forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
List("", "ALL").foreach { x =>
assert(sql(
s"""
|SELECT c1, c2 FROM t1
|UNION $x
|SELECT c1, c2 FROM t2
|UNION $x
|SELECT c1, c2 FROM t3
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
val havingConditions = List("", "HAVING cnt > 1")
List("", "ALL").foreach { x =>
havingConditions.foreach { having =>
sorts.foreach { sort =>
assert(sql(
s"""
|SELECT c1, count(c2) as cnt
|FROM t1
|GROUP BY c1
|$having
|UNION $x
|SELECT c1, COUNT(c2) as cnt
|FROM t2
|GROUP BY c1
|$having
|UNION $x
|SELECT c1, COUNT(c2) as cnt
|FROM t3
|GROUP BY c1
|$having
|$sort
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
}
limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
assert(sql(
s"""
|SELECT c1, c2 FROM t1
|UNION
|SELECT c1, c2 FROM t2
|UNION
|SELECT c1, c2 FROM t3
|LIMIT $limit
|""".stripMargin)
.queryExecution.analyzed.maxRows.contains(expected))
}
}
}
test("test watchdog: Select View Statement for forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "3") {
withTable("tmp_table", "tmp_union") {
withView("tmp_view", "tmp_view2") {
sql(s"create table tmp_table (a int, b int)")
sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)")
sql(s"create table tmp_union (a int, b int)")
sql(s"insert into tmp_union values (6,60),(7,70),(8,80),(9,90),(10,100)")
sql(s"create view tmp_view2 as select * from tmp_union")
assert(!sql(
s"""
|CREATE VIEW tmp_view
|as
|SELECT * FROM
|tmp_table
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
assert(sql(
s"""
|SELECT * FROM
|tmp_view
|""".stripMargin)
.queryExecution.analyzed.maxRows.contains(3))
assert(sql(
s"""
|SELECT * FROM
|tmp_view
|limit 11
|""".stripMargin)
.queryExecution.analyzed.maxRows.contains(3))
assert(sql(
s"""
|SELECT * FROM
|(select * from tmp_view
|UNION
|select * from tmp_view2)
|ORDER BY a
|DESC
|""".stripMargin)
.collect().head.get(0).equals(10))
}
}
}
}
test("test watchdog: Insert Statement for forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
withTable("tmp_table", "tmp_insert") {
spark.sql(s"create table tmp_table (a int, b int)")
spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)")
val multiInsertTableName1: String = "tmp_tbl1"
val multiInsertTableName2: String = "tmp_tbl2"
sql(s"drop table if exists $multiInsertTableName1")
sql(s"drop table if exists $multiInsertTableName2")
sql(s"create table $multiInsertTableName1 like tmp_table")
sql(s"create table $multiInsertTableName2 like tmp_table")
assert(!sql(
s"""
|FROM tmp_table
|insert into $multiInsertTableName1 select * limit 2
|insert into $multiInsertTableName2 select *
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
}
test("test watchdog: Distribute by for forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
withTable("tmp_table") {
spark.sql(s"create table tmp_table (a int, b int)")
spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)")
assert(sql(
s"""
|SELECT *
|FROM tmp_table
|DISTRIBUTE BY a
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
}
}
class WatchDogSuite extends WatchDogSuiteBase {}

View File

@@ -19,6 +19,8 @@ package org.apache.kyuubi.sql
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.kyuubi.sql.watchdog.{ForcedMaxOutputRowsRule, MaxPartitionStrategy}
// scalastyle:off line.size.limit
/**
* Building on the Spark SQL extension framework, we can use this extension with the following steps
@@ -33,5 +35,9 @@ class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
extensions.injectPostHocResolutionRule(RebalanceBeforeWritingDatasource)
extensions.injectPostHocResolutionRule(RebalanceBeforeWritingHive)
extensions.injectPostHocResolutionRule(DropIgnoreNonexistent)
// watchdog extension
extensions.injectPostHocResolutionRule(ForcedMaxOutputRowsRule)
extensions.injectPlannerStrategy(MaxPartitionStrategy)
}
}
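
For reference, wiring the extension up is a matter of Spark configuration. A minimal sketch at session build time; the two `spark.sql.watchdog.*` keys are assumptions based on `KyuubiSQLConf`, not taken verbatim from this diff:

```scala
import org.apache.spark.sql.SparkSession

// Sketch: enable the Kyuubi extension and the watchdog caps.
// The watchdog conf keys below are assumed from KyuubiSQLConf.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.sql.extensions", "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
  .config("spark.sql.watchdog.forcedMaxOutputRows", "10000")
  .config("spark.sql.watchdog.maxPartitions", "2000")
  .getOrCreate()
```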

View File

@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.sql.watchdog
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, WithCTE}
case class ForcedMaxOutputRowsRule(sparkSession: SparkSession) extends ForcedMaxOutputRowsBase {
override protected def isChildAggregate(a: Aggregate): Boolean = false
override protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
case WithCTE(plan, _) => this.canInsertLimitInner(plan)
case plan: LogicalPlan => super.canInsertLimitInner(plan)
}
override protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = {
p match {
case WithCTE(plan, _) => this.canInsertLimit(plan, maxOutputRowsOpt)
case _ => super.canInsertLimit(p, maxOutputRowsOpt)
}
}
}
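
The net effect on a CTE query, which the new unit tests below verify; a sketch assuming a session with the extension enabled, a table `t1`, and the conf key from `KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS`:

```scala
import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit

// Sketch: on Spark 3.2 the analyzed root of a CTE query is WithCTE; the
// rule unwraps it, so the forced limit is still inserted at the root.
spark.sql("SET spark.sql.watchdog.forcedMaxOutputRows=10")
val analyzed = spark.sql(
  """WITH cte AS (SELECT * FROM t1)
    |SELECT * FROM cte ORDER BY c1""".stripMargin).queryExecution.analyzed
assert(analyzed.isInstanceOf[GlobalLimit])
```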

View File

@@ -0,0 +1,20 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
class WatchDogSuite extends WatchDogSuiteBase {}

View File

@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.sql.watchdog
import org.apache.spark.sql.catalyst.analysis.AnalysisContext
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.kyuubi.sql.KyuubiSQLConf
/*
* Add the ForcedMaxOutputRows rule to cap the number of output rows and
* avoid unexpectedly huge results from non-LIMIT queries.
* It mainly applies to cases such as:
*
* case 1:
* {{{
* SELECT [c1, c2, ...]
* }}}
*
* case 2:
* {{{
* WITH CTE AS (
* ...)
* SELECT [c1, c2, ...] FROM CTE ...
* }}}
*
* This logical rule adds a GlobalLimit node above the root project.
*/
trait ForcedMaxOutputRowsBase extends Rule[LogicalPlan] {
protected def isChildAggregate(a: Aggregate): Boolean
protected def isView: Boolean = {
val nestedViewDepth = AnalysisContext.get.nestedViewDepth
nestedViewDepth > 0
}
protected def canInsertLimitInner(p: LogicalPlan): Boolean = p match {
case Aggregate(_, Alias(_, "havingCondition") :: Nil, _) => false
case agg: Aggregate => !isChildAggregate(agg)
case _: RepartitionByExpression => true
case _: Distinct => true
case _: Filter => true
case _: Project => true
case Limit(_, _) => true
case _: Sort => true
case Union(children, _, _) =>
if (children.exists(_.isInstanceOf[DataWritingCommand])) {
false
} else {
true
}
case _ => false
}
protected def canInsertLimit(p: LogicalPlan, maxOutputRowsOpt: Option[Int]): Boolean = {
maxOutputRowsOpt match {
case Some(forcedMaxOutputRows) => canInsertLimitInner(p) &&
!p.maxRows.exists(_ <= forcedMaxOutputRows) &&
!isView
case None => false
}
}
override def apply(plan: LogicalPlan): LogicalPlan = {
val maxOutputRowsOpt = conf.getConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS)
plan match {
case p if p.resolved && canInsertLimit(p, maxOutputRowsOpt) =>
Limit(
maxOutputRowsOpt.get,
plan)
case _ => plan
}
}
}
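
As a usage note, `canInsertLimit` skips plans whose `maxRows` already fits within the cap, so explicit LIMIT clauses win over the watchdog; a sketch under the same assumptions as above:

```scala
// With spark.sql.watchdog.forcedMaxOutputRows=10 (key assumed), an
// unbounded query is wrapped in GlobalLimit(10)...
spark.sql("SELECT * FROM t1").queryExecution.analyzed

// ...while a query already bounded below the cap keeps its own limit,
// because the guard !p.maxRows.exists(_ <= 10) is false for maxRows = Some(5).
spark.sql("SELECT * FROM t1 LIMIT 5").queryExecution.analyzed
```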

View File

@@ -19,7 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.types.StructType
trait PruneFileSourcePartitionHelper extends PredicateHelper {

View File

@@ -0,0 +1,418 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit
import org.apache.kyuubi.sql.KyuubiSQLConf
import org.apache.kyuubi.sql.watchdog.MaxPartitionExceedException
trait WatchDogSuiteBase extends KyuubiSparkSQLExtensionTest {
override protected def beforeAll(): Unit = {
super.beforeAll()
setupData()
}
case class LimitAndExpected(limit: Int, expected: Int)
val limitAndExpecteds = List(LimitAndExpected(1, 1), LimitAndExpected(11, 10))
private def checkMaxPartition: Unit = {
withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "100") {
checkAnswer(sql("SELECT count(distinct(p)) FROM test"), Row(10) :: Nil)
}
withSQLConf(KyuubiSQLConf.WATCHDOG_MAX_PARTITIONS.key -> "5") {
sql("SELECT * FROM test where p=1").queryExecution.sparkPlan
sql(s"SELECT * FROM test WHERE p in (${Range(0, 5).toList.mkString(",")})")
.queryExecution.sparkPlan
intercept[MaxPartitionExceedException](
sql("SELECT * FROM test where p != 1").queryExecution.sparkPlan)
intercept[MaxPartitionExceedException](
sql("SELECT * FROM test").queryExecution.sparkPlan)
intercept[MaxPartitionExceedException](sql(
s"SELECT * FROM test WHERE p in (${Range(0, 6).toList.mkString(",")})")
.queryExecution.sparkPlan)
}
}
test("watchdog with scan maxPartitions -- hive") {
Seq("textfile", "parquet").foreach { format =>
withTable("test", "temp") {
sql(
s"""
|CREATE TABLE test(i int)
|PARTITIONED BY (p int)
|STORED AS $format""".stripMargin)
spark.range(0, 10, 1).selectExpr("id as col")
.createOrReplaceTempView("temp")
for (part <- Range(0, 10)) {
sql(
s"""
|INSERT OVERWRITE TABLE test PARTITION (p='$part')
|select col from temp""".stripMargin)
}
checkMaxPartition
}
}
}
test("watchdog with scan maxPartitions -- data source") {
withTempDir { dir =>
withTempView("test") {
spark.range(10).selectExpr("id", "id as p")
.write
.partitionBy("p")
.mode("overwrite")
.save(dir.getCanonicalPath)
spark.read.load(dir.getCanonicalPath).createOrReplaceTempView("test")
checkMaxPartition
}
}
}
test("test watchdog: simple SELECT STATEMENT") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
List("", "ORDER BY c1", "ORDER BY c2").foreach { sort =>
List("", " DISTINCT").foreach { distinct =>
assert(sql(
s"""
|SELECT $distinct *
|FROM t1
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
List("", "ORDER BY c1", "ORDER BY c2").foreach { sort =>
List("", "DISTINCT").foreach { distinct =>
assert(sql(
s"""
|SELECT $distinct *
|FROM t1
|$sort
|LIMIT $limit
|""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
}
}
}
}
}
test("test watchdog: SELECT ... WITH AGGREGATE STATEMENT ") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
assert(!sql("SELECT count(*) FROM t1")
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
val havingConditions = List("", "HAVING cnt > 1")
havingConditions.foreach { having =>
sorts.foreach { sort =>
assert(sql(
s"""
|SELECT c1, COUNT(*) as cnt
|FROM t1
|GROUP BY c1
|$having
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
havingConditions.foreach { having =>
sorts.foreach { sort =>
assert(sql(
s"""
|SELECT c1, COUNT(*) as cnt
|FROM t1
|GROUP BY c1
|$having
|$sort
|LIMIT $limit
|""".stripMargin).queryExecution.analyzed.maxRows.contains(expected))
}
}
}
}
}
test("test watchdog: SELECT with CTE forceMaxOutputRows") {
// simple CTE
val q1 =
"""
|WITH t2 AS (
| SELECT * FROM t1
|)
|""".stripMargin
// nested CTE
val q2 =
"""
|WITH
| t AS (SELECT * FROM t1),
| t2 AS (
| WITH t3 AS (SELECT * FROM t1)
| SELECT * FROM t3
| )
|""".stripMargin
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
val sorts = List("", "ORDER BY c1", "ORDER BY c2")
sorts.foreach { sort =>
Seq(q1, q2).foreach { withQuery =>
assert(sql(
s"""
|$withQuery
|SELECT * FROM t2
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
sorts.foreach { sort =>
Seq(q1, q2).foreach { withQuery =>
assert(sql(
s"""
|$withQuery
|SELECT * FROM t2
|$sort
|LIMIT $limit
|""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected))
}
}
}
}
}
test("test watchdog: SELECT AGGREGATE WITH CTE forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
assert(!sql(
"""
|WITH custom_cte AS (
|SELECT * FROM t1
|)
|
|SELECT COUNT(*)
|FROM custom_cte
|""".stripMargin).queryExecution
.analyzed.isInstanceOf[GlobalLimit])
val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
val havingConditions = List("", "HAVING cnt > 1")
havingConditions.foreach { having =>
sorts.foreach { sort =>
assert(sql(
s"""
|WITH custom_cte AS (
|SELECT * FROM t1
|)
|
|SELECT c1, COUNT(*) as cnt
|FROM custom_cte
|GROUP BY c1
|$having
|$sort
|""".stripMargin).queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
havingConditions.foreach { having =>
sorts.foreach { sort =>
assert(sql(
s"""
|WITH custom_cte AS (
|SELECT * FROM t1
|)
|
|SELECT c1, COUNT(*) as cnt
|FROM custom_cte
|GROUP BY c1
|$having
|$sort
|LIMIT $limit
|""".stripMargin).queryExecution.optimizedPlan.maxRows.contains(expected))
}
}
}
}
}
test("test watchdog: UNION Statement for forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
List("", "ALL").foreach { x =>
assert(sql(
s"""
|SELECT c1, c2 FROM t1
|UNION $x
|SELECT c1, c2 FROM t2
|UNION $x
|SELECT c1, c2 FROM t3
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
val sorts = List("", "ORDER BY cnt", "ORDER BY c1", "ORDER BY cnt, c1", "ORDER BY c1, cnt")
val havingConditions = List("", "HAVING cnt > 1")
List("", "ALL").foreach { x =>
havingConditions.foreach { having =>
sorts.foreach { sort =>
assert(sql(
s"""
|SELECT c1, count(c2) as cnt
|FROM t1
|GROUP BY c1
|$having
|UNION $x
|SELECT c1, COUNT(c2) as cnt
|FROM t2
|GROUP BY c1
|$having
|UNION $x
|SELECT c1, COUNT(c2) as cnt
|FROM t3
|GROUP BY c1
|$having
|$sort
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
}
limitAndExpecteds.foreach { case LimitAndExpected(limit, expected) =>
assert(sql(
s"""
|SELECT c1, c2 FROM t1
|UNION
|SELECT c1, c2 FROM t2
|UNION
|SELECT c1, c2 FROM t3
|LIMIT $limit
|""".stripMargin)
.queryExecution.analyzed.maxRows.contains(expected))
}
}
}
test("test watchdog: Select View Statement for forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "3") {
withTable("tmp_table", "tmp_union") {
withView("tmp_view", "tmp_view2") {
sql(s"create table tmp_table (a int, b int)")
sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)")
sql(s"create table tmp_union (a int, b int)")
sql(s"insert into tmp_union values (6,60),(7,70),(8,80),(9,90),(10,100)")
sql(s"create view tmp_view2 as select * from tmp_union")
assert(!sql(
s"""
|CREATE VIEW tmp_view
|as
|SELECT * FROM
|tmp_table
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
assert(sql(
s"""
|SELECT * FROM
|tmp_view
|""".stripMargin)
.queryExecution.analyzed.maxRows.contains(3))
assert(sql(
s"""
|SELECT * FROM
|tmp_view
|limit 11
|""".stripMargin)
.queryExecution.analyzed.maxRows.contains(3))
assert(sql(
s"""
|SELECT * FROM
|(select * from tmp_view
|UNION
|select * from tmp_view2)
|ORDER BY a
|DESC
|""".stripMargin)
.collect().head.get(0).equals(10))
}
}
}
}
test("test watchdog: Insert Statement for forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
withTable("tmp_table", "tmp_insert") {
spark.sql(s"create table tmp_table (a int, b int)")
spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)")
val multiInsertTableName1: String = "tmp_tbl1"
val multiInsertTableName2: String = "tmp_tbl2"
sql(s"drop table if exists $multiInsertTableName1")
sql(s"drop table if exists $multiInsertTableName2")
sql(s"create table $multiInsertTableName1 like tmp_table")
sql(s"create table $multiInsertTableName2 like tmp_table")
assert(!sql(
s"""
|FROM tmp_table
|insert into $multiInsertTableName1 select * limit 2
|insert into $multiInsertTableName2 select *
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
}
test("test watchdog: Distribute by for forceMaxOutputRows") {
withSQLConf(KyuubiSQLConf.WATCHDOG_FORCED_MAXOUTPUTROWS.key -> "10") {
withTable("tmp_table") {
spark.sql(s"create table tmp_table (a int, b int)")
spark.sql(s"insert into tmp_table values (1,10),(2,20),(3,30),(4,40),(5,50)")
assert(sql(
s"""
|SELECT *
|FROM tmp_table
|DISTRIBUTE BY a
|""".stripMargin)
.queryExecution.analyzed.isInstanceOf[GlobalLimit])
}
}
}
}