diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml
index a0198a40b..36ceafe50 100644
--- a/.github/workflows/master.yml
+++ b/.github/workflows/master.yml
@@ -35,8 +35,8 @@ jobs:
profiles:
- ''
- '-Pspark-3.0 -Dspark.archive.mirror=https://archive.apache.org/dist/spark/spark-3.1.1 -Dspark.archive.name=spark-3.1.1-bin-hadoop2.7.tgz -Dmaven.plugin.scalatest.exclude.tags=org.apache.kyuubi.tags.ExtendedSQLTest,org.apache.kyuubi.tags.DataLakeTest'
- - '-Pspark-3.1'
- - '-Pspark-3.1 -Pspark-hadoop-3.2'
+ - '-Pspark-3.1 -Pkyuubi-sql-spark_3.1'
+ - '-Pspark-3.1 -Pkyuubi-sql-spark_3.1 -Pspark-hadoop-3.2'
- '-Pspark-3.2-snapshot -pl :kyuubi-spark-sql-engine,:kyuubi-common,:kyuubi-ha,:kyuubi-zookeeper -Dmaven.plugin.scalatest.exclude.tags=org.apache.kyuubi.tags.ExtendedSQLTest,org.apache.kyuubi.tags.DataLakeTest'
- '-DwildcardSuites=org.apache.kyuubi.operation.tpcds.TPCDSOutputSchemaSuite,org.apache.kyuubi.operation.tpcds.TPCDSDDLSuite -Dmaven.plugin.scalatest.exclude.tags=""'
env:
diff --git a/dev/kyuubi-extension-spark_3.1/pom.xml b/dev/kyuubi-extension-spark_3.1/pom.xml
new file mode 100644
index 000000000..232e729d3
--- /dev/null
+++ b/dev/kyuubi-extension-spark_3.1/pom.xml
@@ -0,0 +1,106 @@
+
+
+
+
+
+ kyuubi
+ org.apache.kyuubi
+ 1.2.0-SNAPSHOT
+ ../../pom.xml
+
+ 4.0.0
+
+ kyuubi-extension-spark_3.1
+ jar
+ Kyuubi Project Dev Spark Extensions
+
+
+
+ org.apache.spark
+ spark-hive_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+
+ org.apache.spark
+ spark-catalyst_${scala.binary.version}
+ ${spark.version}
+ provided
+
+
+
+ org.apache.hadoop
+ hadoop-client-runtime
+ provided
+
+
+
+ org.apache.spark
+ spark-hive_${scala.binary.version}
+ ${spark.version}
+ test-jar
+ test
+
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark.version}
+ test-jar
+ test
+
+
+
+ org.apache.spark
+ spark-catalyst_${scala.binary.version}
+ ${spark.version}
+ test-jar
+ test
+
+
+
+ org.apache.spark
+ spark-core_${scala.binary.version}
+ ${spark.version}
+ test-jar
+ test
+
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ ${scalatest.version}
+ test
+
+
+
+
+ target/scala-${scala.binary.version}/classes
+ target/scala-${scala.binary.version}/test-classes
+
+
diff --git a/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiAnalysis.scala b/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiAnalysis.scala
new file mode 100644
index 000000000..c9a1d0bb2
--- /dev/null
+++ b/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiAnalysis.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import java.util.Random
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal, Multiply, Rand}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
+import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveTable}
+import org.apache.spark.sql.internal.StaticSQLConf
+import org.apache.spark.sql.types.IntegerType
+
+import org.apache.kyuubi.sql.RepartitionBeforeWriteHelper._
+
+/**
+ * For datasource table, there are two commands that can write data to a table:
+ * 1. InsertIntoHadoopFsRelationCommand
+ * 2. CreateDataSourceTableAsSelectCommand
+ * This rule adds a repartition node between the write and the query
+ */
+case class RepartitionBeforeWrite(session: SparkSession) extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE)) {
+ addRepartition(plan)
+ } else {
+ plan
+ }
+ }
+
+ private def addRepartition(plan: LogicalPlan): LogicalPlan = plan match {
+ case i @ InsertIntoHadoopFsRelationCommand(_, sp, _, pc, bucket, _, _, query, _, _, _, _)
+ if query.resolved && bucket.isEmpty && canInsertRepartitionByExpression(query) =>
+ val dynamicPartitionColumns = pc.filterNot(attr => sp.contains(attr.name))
+ if (dynamicPartitionColumns.isEmpty) {
+ i.copy(query =
+ RepartitionByExpression(
+ Seq.empty,
+ query,
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_NUM)))
+ } else {
+ val extended = dynamicPartitionColumns ++ dynamicPartitionExtraExpression(
+ conf.getConf(KyuubiSQLConf.DYNAMIC_PARTITION_INSERTION_REPARTITION_NUM))
+ i.copy(query =
+ RepartitionByExpression(
+ extended,
+ query,
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_NUM)))
+ }
+
+ case c @ CreateDataSourceTableAsSelectCommand(table, _, query, _)
+ if query.resolved && table.bucketSpec.isEmpty && canInsertRepartitionByExpression(query) =>
+ val dynamicPartitionColumns =
+ query.output.filter(attr => table.partitionColumnNames.contains(attr.name))
+ if (dynamicPartitionColumns.isEmpty) {
+ c.copy(query =
+ RepartitionByExpression(
+ Seq.empty,
+ query,
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_NUM)))
+ } else {
+ val extended = dynamicPartitionColumns ++ dynamicPartitionExtraExpression(
+ conf.getConf(KyuubiSQLConf.DYNAMIC_PARTITION_INSERTION_REPARTITION_NUM))
+ c.copy(query =
+ RepartitionByExpression(
+ extended,
+ query,
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_NUM)))
+ }
+
+ case _ => plan
+ }
+}
+
+/**
+ * For Hive table, there are two commands that can write data to a table:
+ * 1. InsertIntoHiveTable
+ * 2. CreateHiveTableAsSelectCommand
+ * This rule adds a repartition node between the write and the query
+ */
+case class RepartitionBeforeWriteHive(session: SparkSession) extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (conf.getConf(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive" &&
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE)) {
+ addRepartition(plan)
+ } else {
+ plan
+ }
+ }
+
+ def addRepartition(plan: LogicalPlan): LogicalPlan = plan match {
+ case i @ InsertIntoHiveTable(table, partition, query, _, _, _)
+ if query.resolved && table.bucketSpec.isEmpty && canInsertRepartitionByExpression(query) =>
+ val dynamicPartitionColumns = partition.filter(_._2.isEmpty).keys
+ .flatMap(name => query.output.find(_.name == name)).toSeq
+ if (dynamicPartitionColumns.isEmpty) {
+ i.copy(query =
+ RepartitionByExpression(
+ Seq.empty,
+ query,
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_NUM)))
+ } else {
+ val extended = dynamicPartitionColumns ++ dynamicPartitionExtraExpression(
+ conf.getConf(KyuubiSQLConf.DYNAMIC_PARTITION_INSERTION_REPARTITION_NUM))
+ i.copy(query =
+ RepartitionByExpression(
+ extended,
+ query,
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_NUM)))
+ }
+
+ case c @ CreateHiveTableAsSelectCommand(table, query, _, _)
+ if query.resolved && table.bucketSpec.isEmpty && canInsertRepartitionByExpression(query) =>
+ val dynamicPartitionColumns =
+ query.output.filter(attr => table.partitionColumnNames.contains(attr.name))
+ if (dynamicPartitionColumns.isEmpty) {
+ c.copy(query =
+ RepartitionByExpression(
+ Seq.empty,
+ query,
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_NUM)))
+ } else {
+ val extended = dynamicPartitionColumns ++ dynamicPartitionExtraExpression(
+ conf.getConf(KyuubiSQLConf.DYNAMIC_PARTITION_INSERTION_REPARTITION_NUM))
+ c.copy(query =
+ RepartitionByExpression(
+ extended,
+ query,
+ conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_NUM)))
+ }
+
+ case _ => plan
+ }
+}
+
+object RepartitionBeforeWriteHelper {
+ def canInsertRepartitionByExpression(plan: LogicalPlan): Boolean = plan match {
+ case Limit(_, _) => false
+ case _: Sort => false
+ case _: RepartitionByExpression => false
+ case _: Repartition => false
+ case _ => true
+ }
+
+ def dynamicPartitionExtraExpression(partitionNumber: Int): Seq[Expression] = {
+ // Dynamic partition insertion will add repartition by partition column, but it could cause
+ // data skew (one partition value has large data). So we add an extra partition column for the
+ // same dynamic partition to avoid skew.
+ Cast(Multiply(
+ new Rand(Literal(new Random().nextLong())),
+ Literal(partitionNumber.toDouble)
+ ), IntegerType) :: Nil
+ }
+}
diff --git a/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala b/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala
new file mode 100644
index 000000000..8f8942cb1
--- /dev/null
+++ b/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiEnsureRequirements.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{SortExec, SparkPlan}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+
+/**
+ * Copy from Apache Spark `EnsureRequirements`
+ * 1. remove reorder join predicates
+ * 2. remove shuffle pruning
+ */
+object KyuubiEnsureRequirements extends Rule[SparkPlan] {
+ private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
+ val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
+ val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
+ var children: Seq[SparkPlan] = operator.children
+ assert(requiredChildDistributions.length == children.length)
+ assert(requiredChildOrderings.length == children.length)
+
+ // Ensure that the operator's children satisfy their output distribution requirements.
+ children = children.zip(requiredChildDistributions).map {
+ case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+ child
+ case (child, BroadcastDistribution(mode)) =>
+ BroadcastExchangeExec(mode, child)
+ case (child, distribution) =>
+ val numPartitions = distribution.requiredNumPartitions
+ .getOrElse(conf.numShufflePartitions)
+ ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
+ }
+
+ // Get the indexes of children which have specified distribution requirements and need to have
+ // same number of partitions.
+ val childrenIndexes = requiredChildDistributions.zipWithIndex.filter {
+ case (UnspecifiedDistribution, _) => false
+ case (_: BroadcastDistribution, _) => false
+ case _ => true
+ }.map(_._2)
+
+ val childrenNumPartitions =
+ childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet
+
+ if (childrenNumPartitions.size > 1) {
+ // Get the number of partitions which is explicitly required by the distributions.
+ val requiredNumPartitions = {
+ val numPartitionsSet = childrenIndexes.flatMap {
+ index => requiredChildDistributions(index).requiredNumPartitions
+ }.toSet
+ assert(numPartitionsSet.size <= 1,
+ s"$operator have incompatible requirements of the number of partitions for its children")
+ numPartitionsSet.headOption
+ }
+
+ // If there are non-shuffle children that satisfy the required distribution, we have
+ // some tradeoffs when picking the expected number of shuffle partitions:
+ // 1. We should avoid shuffling these children.
+ // 2. We should have a reasonable parallelism.
+ val nonShuffleChildrenNumPartitions =
+ childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
+ .map(_.outputPartitioning.numPartitions)
+ val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) {
+ if (nonShuffleChildrenNumPartitions.length == childrenIndexes.length) {
+ // Here we pick the max number of partitions among these non-shuffle children.
+ nonShuffleChildrenNumPartitions.max
+ } else {
+ // Here we pick the max number of partitions among these non-shuffle children as the
+ // expected number of shuffle partitions. However, if it's smaller than
+ // `conf.numShufflePartitions`, we pick `conf.numShufflePartitions` as the
+ // expected number of shuffle partitions.
+ math.max(nonShuffleChildrenNumPartitions.max, conf.defaultNumShufflePartitions)
+ }
+ } else {
+ childrenNumPartitions.max
+ }
+
+ val targetNumPartitions = requiredNumPartitions.getOrElse(expectedChildrenNumPartitions)
+
+ children = children.zip(requiredChildDistributions).zipWithIndex.map {
+ case ((child, distribution), index) if childrenIndexes.contains(index) =>
+ if (child.outputPartitioning.numPartitions == targetNumPartitions) {
+ child
+ } else {
+ val defaultPartitioning = distribution.createPartitioning(targetNumPartitions)
+ child match {
+ // If child is an exchange, we replace it with a new one having defaultPartitioning.
+ case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c)
+ case _ => ShuffleExchangeExec(defaultPartitioning, child)
+ }
+ }
+
+ case ((child, _), _) => child
+ }
+ }
+
+ // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
+ children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+ // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort.
+ if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) {
+ child
+ } else {
+ SortExec(requiredOrdering, global = false, child = child)
+ }
+ }
+
+ operator.withNewChildren(children)
+ }
+
+ def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+ case operator: SparkPlan =>
+ ensureDistributionAndOrdering(operator)
+ }
+}
diff --git a/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiQueryStagePreparation.scala b/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiQueryStagePreparation.scala
new file mode 100644
index 000000000..6c48a3891
--- /dev/null
+++ b/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiQueryStagePreparation.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{SortExec, SparkPlan}
+import org.apache.spark.sql.execution.adaptive.QueryStageExec
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.execution.command.{ResetCommand, SetCommand}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, Exchange, ShuffleExchangeExec, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.internal.SQLConf
+
+import org.apache.kyuubi.sql.KyuubiSQLConf._
+
+/**
+ * Insert shuffle node before join if it doesn't exist to make `OptimizeSkewedJoin` works.
+ */
+object InsertShuffleNodeBeforeJoin extends Rule[SparkPlan] {
+
+ override def apply(plan: SparkPlan): SparkPlan = {
+ // this rule has no meaning without AQE
+ if (!conf.getConf(FORCE_SHUFFLE_BEFORE_JOIN) ||
+ !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED)) {
+ return plan
+ }
+
+ val newPlan = insertShuffleBeforeJoin(plan)
+ if (plan.fastEquals(newPlan)) {
+ plan
+ } else {
+ // make sure the output partitioning and ordering will not be broken.
+ KyuubiEnsureRequirements.apply(newPlan)
+ }
+ }
+
+ private def insertShuffleBeforeJoin(plan: SparkPlan): SparkPlan = plan transformUp {
+ case smj @ SortMergeJoinExec(_, _, _, _, l, r, _) =>
+ smj.withNewChildren(checkAndInsertShuffle(smj.requiredChildDistribution.head, l) ::
+ checkAndInsertShuffle(smj.requiredChildDistribution(1), r) :: Nil)
+
+ case shj: ShuffledHashJoinExec =>
+ if (!shj.left.isInstanceOf[Exchange] && !shj.right.isInstanceOf[Exchange]) {
+ shj.withNewChildren(withShuffleExec(shj.requiredChildDistribution.head, shj.left) ::
+ withShuffleExec(shj.requiredChildDistribution(1), shj.right) :: Nil)
+ } else if (!shj.left.isInstanceOf[Exchange]) {
+ shj.withNewChildren(
+ withShuffleExec(shj.requiredChildDistribution.head, shj.left) :: shj.right :: Nil)
+ } else if (!shj.right.isInstanceOf[Exchange]) {
+ shj.withNewChildren(
+ shj.left :: withShuffleExec(shj.requiredChildDistribution(1), shj.right) :: Nil)
+ } else {
+ shj
+ }
+ }
+
+ private def checkAndInsertShuffle(
+ distribution: Distribution, child: SparkPlan): SparkPlan = child match {
+ case SortExec(_, _, _: Exchange, _) =>
+ child
+ case SortExec(_, _, _: QueryStageExec, _) =>
+ child
+ case sort @ SortExec(_, _, agg: BaseAggregateExec, _) =>
+ sort.withNewChildren(withShuffleExec(distribution, agg) :: Nil)
+ case _ =>
+ withShuffleExec(distribution, child)
+ }
+
+ private def withShuffleExec(distribution: Distribution, child: SparkPlan): SparkPlan = {
+ val numPartitions = distribution.requiredNumPartitions
+ .getOrElse(conf.numShufflePartitions)
+ ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
+ }
+}
+
+
+/**
+ * This rule split stage into two parts:
+ * 1. previous stage
+ * 2. final stage
+ * For final stage, we can inject extra config. It's useful if we use repartition to optimize
+ * small files that need a bigger shuffle partition size than the previous stages.
+ *
+ * Let's say we have a query with 3 stages, then the logical machine like:
+ *
+ * Set/Reset Command -> cleanup previousStage config if user set the spark config.
+ * Query -> AQE -> stage1 -> preparation (use previousStage to overwrite spark config)
+ * -> AQE -> stage2 -> preparation (use spark config)
+ * -> AQE -> stage3 -> preparation (use finalStage config to overwrite spark config,
+ * store spark config to previousStage.)
+ *
+ * An example of the new finalStage config:
+ * `spark.sql.adaptive.advisoryPartitionSizeInBytes` ->
+ * `spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes`
+ */
+case class FinalStageConfigIsolation(session: SparkSession) extends Rule[SparkPlan] {
+ import FinalStageConfigIsolation._
+
+ override def apply(plan: SparkPlan): SparkPlan = {
+ // this rule has no meaning without AQE
+ if (!conf.getConf(FINAL_STAGE_CONFIG_ISOLATION) ||
+ !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED)) {
+ return plan
+ }
+
+ if (isFinalStage(plan)) {
+ // set config for final stage
+ session.conf.getAll.filter(_._1.startsWith(FINAL_STAGE_CONFIG_PREFIX)).foreach {
+ case (k, v) =>
+ val sparkConfigKey = s"spark.sql.${k.substring(FINAL_STAGE_CONFIG_PREFIX.length)}"
+ val previousStageConfigKey =
+ s"$PREVIOUS_STAGE_CONFIG_PREFIX${k.substring(FINAL_STAGE_CONFIG_PREFIX.length)}"
+ // store the previous config only if we have not stored it, to avoid a query with only
+ // one stage overwriting the real config.
+ if (!session.sessionState.conf.contains(previousStageConfigKey)) {
+ val originalValue = if (session.conf.getOption(sparkConfigKey).isDefined) {
+ session.sessionState.conf.getConfString(sparkConfigKey)
+ } else {
+ // the default value of the config is None, so we need to use an internal tag
+ INTERNAL_UNSET_CONFIG_TAG
+ }
+ logInfo(s"Store config: $sparkConfigKey to previousStage, " +
+ s"original value: $originalValue ")
+ session.sessionState.conf.setConfString(previousStageConfigKey, originalValue)
+ }
+ logInfo(s"For final stage: set $sparkConfigKey = $v.")
+ session.conf.set(sparkConfigKey, v)
+ }
+ } else {
+ // reset config for previous stage
+ session.conf.getAll.filter(_._1.startsWith(PREVIOUS_STAGE_CONFIG_PREFIX)).foreach {
+ case (k, v) =>
+ val sparkConfigKey = s"spark.sql.${k.substring(PREVIOUS_STAGE_CONFIG_PREFIX.length)}"
+ logInfo(s"For previous stage: set $sparkConfigKey = $v.")
+ if (v == INTERNAL_UNSET_CONFIG_TAG) {
+ session.conf.unset(sparkConfigKey)
+ } else {
+ session.conf.set(sparkConfigKey, v)
+ }
+ // unset config so that we do not need to reset configs for every previous stage
+ session.conf.unset(k)
+ }
+ }
+
+ plan
+ }
+
+ /**
+ * Currently the formula depends on AQE in Spark 3.1.1; not sure it will work in the future.
+ */
+ private def isFinalStage(plan: SparkPlan): Boolean = {
+ var shuffleNum = 0
+ var broadcastNum = 0
+ var queryStageNum = 0
+
+ def collectNumber(p: SparkPlan): SparkPlan = {
+ p transform {
+ case shuffle: ShuffleExchangeLike =>
+ shuffleNum += 1
+ shuffle
+
+ case broadcast: BroadcastExchangeLike =>
+ broadcastNum += 1
+ broadcast
+
+ // a query stage is a leaf node, so we need to transform it manually
+ case queryStage: QueryStageExec =>
+ queryStageNum += 1
+ collectNumber(queryStage.plan)
+ queryStage
+ }
+ }
+ collectNumber(plan)
+
+ if (shuffleNum == 0) {
+ // we do not care about the broadcast stage here since it won't change the partition number.
+ true
+ } else if (shuffleNum + broadcastNum == queryStageNum) {
+ true
+ } else {
+ false
+ }
+ }
+}
+object FinalStageConfigIsolation {
+ final val SQL_PREFIX = "spark.sql."
+ final val FINAL_STAGE_CONFIG_PREFIX = "spark.sql.finalStage."
+ final val PREVIOUS_STAGE_CONFIG_PREFIX = "spark.sql.previousStage."
+ final val INTERNAL_UNSET_CONFIG_TAG = "__INTERNAL_UNSET_CONFIG_TAG__"
+
+ def getPreviousStageConfigKey(configKey: String): Option[String] = {
+ if (configKey.startsWith(SQL_PREFIX)) {
+ Some(s"$PREVIOUS_STAGE_CONFIG_PREFIX${configKey.substring(SQL_PREFIX.length)}")
+ } else {
+ None
+ }
+ }
+}
+
+case class FinalStageConfigIsolationCleanRule(session: SparkSession) extends Rule[LogicalPlan] {
+ import FinalStageConfigIsolation._
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ case set @ SetCommand(Some((k, Some(_)))) if k.startsWith(SQL_PREFIX) =>
+ checkAndUnsetPreviousStageConfig(k)
+ set
+
+ case reset @ ResetCommand(Some(k)) if k.startsWith(SQL_PREFIX) =>
+ checkAndUnsetPreviousStageConfig(k)
+ reset
+
+ case other => other
+ }
+
+ private def checkAndUnsetPreviousStageConfig(configKey: String): Unit = {
+ getPreviousStageConfigKey(configKey).foreach { previousStageConfigKey =>
+ if (session.sessionState.conf.contains(previousStageConfigKey)) {
+ logInfo(s"For previous stage: unset $previousStageConfigKey")
+ session.conf.unset(previousStageConfigKey)
+ }
+ }
+ }
+}
diff --git a/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala b/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala
new file mode 100644
index 000000000..099be65e8
--- /dev/null
+++ b/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiSQLConf.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf._
+
+object KyuubiSQLConf {
+
+ val INSERT_REPARTITION_BEFORE_WRITE =
+ buildConf("spark.sql.optimizer.insertRepartitionBeforeWrite.enabled")
+ .doc("Add repartition node at the top of plan. A approach of merging small files.")
+ .version("1.2.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val INSERT_REPARTITION_NUM =
+ buildConf("spark.sql.optimizer.insertRepartitionNum")
+ .doc(s"The partition number if ${INSERT_REPARTITION_BEFORE_WRITE.key} is enabled. " +
+ s"If AQE is disabled, the default value is ${SQLConf.SHUFFLE_PARTITIONS}. " +
+ s"If AQE is enabled, the default value is none that means depend on AQE.")
+ .version("1.2.0")
+ .intConf
+ .createOptional
+
+ val DYNAMIC_PARTITION_INSERTION_REPARTITION_NUM =
+ buildConf("spark.sql.optimizer.dynamicPartitionInsertionRepartitionNum")
+ .doc(s"The partition number of each dynamic partition if " +
+ s"${INSERT_REPARTITION_BEFORE_WRITE.key} is enabled. " +
+ s"We will repartition by dynamic partition columns to reduce the small file but that " +
+ s"can cause data skew. This config is to extend the partition of dynamic " +
+ s"partition column to avoid skew but may generate some small files.")
+ .version("1.2.0")
+ .intConf
+ .createWithDefault(100)
+
+ val FORCE_SHUFFLE_BEFORE_JOIN =
+ buildConf("spark.sql.optimizer.forceShuffleBeforeJoin.enabled")
+ .doc("Ensure shuffle node exists before shuffled join (shj and smj) to make AQE " +
+ "`OptimizeSkewedJoin` works (extra shuffle, multi table join).")
+ .version("1.2.0")
+ .booleanConf
+ .createWithDefault(false)
+
+ val FINAL_STAGE_CONFIG_ISOLATION =
+ buildConf("spark.sql.optimizer.finalStageConfigIsolation.enabled")
+ .doc("If true, the final stage support use different config with previous stage. The final " +
+ "stage config key prefix should be `spark.sql.finalStage.`." +
+ "For example, the raw spark config: `spark.sql.adaptive.advisoryPartitionSizeInBytes`, " +
+ "then the final stage config should be: " +
+ "`spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes`.")
+ .version("1.2.0")
+ .booleanConf
+ .createWithDefault(false)
+}
diff --git a/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala b/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala
new file mode 100644
index 000000000..bcd2afd7b
--- /dev/null
+++ b/dev/kyuubi-extension-spark_3.1/src/main/scala/org/apache/kyuubi/sql/KyuubiSparkSQLExtension.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kyuubi.sql
+
+import org.apache.spark.sql.SparkSessionExtensions
+
+// scalastyle:off line.size.limit
+/**
+ * Depending on the Spark SQL Extension framework, we can use this extension with these steps:
+ * 1. move this jar into $SPARK_HOME/jars
+ * 2. add config into `spark-defaults.conf`: `spark.sql.extensions=org.apache.kyuubi.sql.KyuubiSparkSQLExtension`
+ */
+// scalastyle:on line.size.limit
+class KyuubiSparkSQLExtension extends (SparkSessionExtensions => Unit) {
+ override def apply(extensions: SparkSessionExtensions): Unit = {
+ extensions.injectPostHocResolutionRule(RepartitionBeforeWrite)
+ extensions.injectPostHocResolutionRule(RepartitionBeforeWriteHive)
+ extensions.injectPostHocResolutionRule(FinalStageConfigIsolationCleanRule)
+ extensions.injectQueryStagePrepRule(_ => InsertShuffleNodeBeforeJoin)
+ extensions.injectQueryStagePrepRule(FinalStageConfigIsolation(_))
+ }
+}
diff --git a/dev/kyuubi-extension-spark_3.1/src/test/scala/org/apache/spark/sql/KyuubiExtensionSuite.scala b/dev/kyuubi-extension-spark_3.1/src/test/scala/org/apache/spark/sql/KyuubiExtensionSuite.scala
new file mode 100644
index 000000000..3a128e8f1
--- /dev/null
+++ b/dev/kyuubi-extension-spark_3.1/src/test/scala/org/apache/spark/sql/KyuubiExtensionSuite.scala
@@ -0,0 +1,364 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Multiply}
+import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, CustomShuffleReaderExec}
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike}
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.test.SQLTestData.TestData
+import org.apache.spark.sql.test.SQLTestUtils
+
+import org.apache.kyuubi.sql.{FinalStageConfigIsolation, KyuubiSQLConf}
+
+class KyuubiExtensionSuite extends QueryTest with SQLTestUtils with AdaptiveSparkPlanHelper {
+
+ var _spark: SparkSession = _
+ override def spark: SparkSession = _spark
+
+ protected override def beforeAll(): Unit = {
+ _spark = SparkSession.builder()
+ .master("local[1]")
+ .config(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key,
+ "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
+ .config(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+ .config("spark.hadoop.hive.exec.dynamic.partition.mode", "nonstrict")
+ .config("spark.ui.enabled", "false")
+ .enableHiveSupport()
+ .getOrCreate()
+ setupData()
+ super.beforeAll()
+ }
+
+ protected override def afterAll(): Unit = {
+ super.afterAll()
+ cleanupData()
+ if (_spark != null) {
+ _spark.stop()
+ }
+ }
+
+ private def setupData(): Unit = {
+ val self = _spark
+ import self.implicits._
+ spark.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString)), 10)
+ .toDF("c1", "c2").createOrReplaceTempView("t1")
+ spark.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(i, i.toString)), 5)
+ .toDF("c1", "c2").createOrReplaceTempView("t2")
+ spark.sparkContext.parallelize(
+ (1 to 50).map(i => TestData(i, i.toString)), 2)
+ .toDF("c1", "c2").createOrReplaceTempView("t3")
+ }
+
+ private def cleanupData(): Unit = {
+ spark.sql("DROP VIEW IF EXISTS t1")
+ spark.sql("DROP VIEW IF EXISTS t2")
+ spark.sql("DROP VIEW IF EXISTS t3")
+ }
+
+ test("check repartition exists") {
+ def check(df: DataFrame): Unit = {
+ assert(
+ df.queryExecution.analyzed.collect {
+ case r: RepartitionByExpression =>
+ assert(r.optNumPartitions ===
+ spark.sessionState.conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_NUM))
+ r
+ }.size == 1
+ )
+ }
+
+ // It's better to set the config explicitly in case we change the default value.
+ withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true") {
+ Seq("USING PARQUET", "").foreach { storage =>
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)")
+ check(sql("INSERT INTO TABLE tmp1 PARTITION(c2='a') " +
+ "SELECT * FROM VALUES(1),(2) AS t(c1)"))
+ }
+
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage")
+ check(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)"))
+ check(sql("INSERT INTO TABLE tmp1 " +
+ "SELECT * FROM VALUES(1),(2),(3) AS t(c1) DISTRIBUTE BY c1"))
+ }
+
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 $storage AS SELECT * FROM VALUES(1),(2),(3) AS t(c1)")
+ }
+
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 $storage PARTITIONED BY(c2) AS " +
+ s"SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")
+ }
+ }
+ }
+ }
+
+ test("check repartition does not exists") {
+ def check(df: DataFrame): Unit = {
+ assert(
+ df.queryExecution.analyzed.collect {
+ case r: RepartitionByExpression => r
+ }.isEmpty
+ )
+ }
+
+ withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true") {
+ // test no write command
+ check(sql("SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
+ check(sql("SELECT count(*) FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
+
+ // test not supported plan
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) PARTITIONED BY (c2 string)")
+ check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
+ "SELECT /*+ repartition(10) */ * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
+ check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
+ "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2) ORDER BY c1"))
+ check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
+ "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2) LIMIT 10"))
+ }
+ }
+
+ withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "false") {
+ Seq("USING PARQUET", "").foreach { storage =>
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)")
+ check(sql("INSERT INTO TABLE tmp1 PARTITION(c2) " +
+ "SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)"))
+ }
+
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage")
+ check(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)"))
+ }
+
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 $storage AS SELECT * FROM VALUES(1),(2),(3) AS t(c1)")
+ }
+
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 $storage PARTITIONED BY(c2) AS " +
+ s"SELECT * FROM VALUES(1, 'a'),(2, 'b') AS t(c1, c2)")
+ }
+ }
+ }
+ }
+
+ test("test dynamic partition write") {
+ def checkRepartitionExpression(df: DataFrame): Unit = {
+ assert(df.queryExecution.analyzed.collect {
+ case r: RepartitionByExpression if r.partitionExpressions.size == 2 =>
+ assert(r.partitionExpressions.head.asInstanceOf[Attribute].name === "c2")
+ assert(r.partitionExpressions(1).asInstanceOf[Cast].child.asInstanceOf[Multiply]
+ .left.asInstanceOf[Attribute].name.startsWith("_nondeterministic"))
+ r
+ }.size == 1)
+ }
+
+ withSQLConf(KyuubiSQLConf.INSERT_REPARTITION_BEFORE_WRITE.key -> "true",
+ KyuubiSQLConf.DYNAMIC_PARTITION_INSERTION_REPARTITION_NUM.key -> "2") {
+ Seq("USING PARQUET", "").foreach { storage =>
+ withTable("tmp1") {
+ sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)")
+ checkRepartitionExpression(sql("INSERT INTO TABLE tmp1 SELECT 1 as c1, 'a' as c2 "))
+ }
+
+ withTable("tmp1") {
+ checkRepartitionExpression(
+ sql("CREATE TABLE tmp1 PARTITIONED BY(C2) SELECT 1 as c1, 'a' as c2 "))
+ }
+ }
+ }
+ }
+
+ test("force shuffle before join") {
+ def checkShuffleNodeNum(sqlString: String, num: Int): Unit = {
+ var expectedResult: Seq[Row] = Seq.empty
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+ expectedResult = sql(sqlString).collect()
+ }
+ val df = sql(sqlString)
+ checkAnswer(df, expectedResult)
+ assert(
+ collect(df.queryExecution.executedPlan) {
+ case shuffle: ShuffleExchangeLike
+ if shuffle.shuffleOrigin == ENSURE_REQUIREMENTS => shuffle
+ }.size == num)
+ }
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ KyuubiSQLConf.FORCE_SHUFFLE_BEFORE_JOIN.key -> "true") {
+ Seq("SHUFFLE_HASH", "MERGE").foreach { joinHint =>
+ // positive case
+ checkShuffleNodeNum(
+ s"""
+ |SELECT /*+ $joinHint(t2, t3) */ t1.c1, t1.c2, t2.c1, t3.c1 from t1
+ | JOIN t2 ON t1.c1 = t2.c1
+ | JOIN t3 ON t1.c1 = t3.c1
+ | """.stripMargin, 4)
+
+ // negative case
+ checkShuffleNodeNum(
+ s"""
+ |SELECT /*+ $joinHint(t2, t3) */ t1.c1, t1.c2, t2.c1, t3.c1 from t1
+ | JOIN t2 ON t1.c1 = t2.c1
+ | JOIN t3 ON t1.c2 = t3.c2
+ | """.stripMargin, 4)
+ }
+
+ checkShuffleNodeNum(
+ """
+ |SELECT t1.c1, t2.c1, t3.c2 from t1
+ | JOIN t2 ON t1.c1 = t2.c1
+ | JOIN (
+ | SELECT c2, count(*) FROM t1 GROUP BY c2
+ | ) t3 ON t1.c1 = t3.c2
+ | """.stripMargin, 5)
+
+ checkShuffleNodeNum(
+ """
+ |SELECT t1.c1, t2.c1, t3.c1 from t1
+ | JOIN t2 ON t1.c1 = t2.c1
+ | JOIN (
+ | SELECT c1, count(*) FROM t1 GROUP BY c1
+ | ) t3 ON t1.c1 = t3.c1
+ | """.stripMargin, 5)
+ }
+ }
+
+ test("final stage config set reset check") {
+ withSQLConf(KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key -> "true",
+ "spark.sql.finalStage.adaptive.coalescePartitions.minPartitionNum" -> "1",
+ "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes" -> "100") {
+ // loop to double-check that the final stage config of one query does not affect other queries
+ (1 to 3).foreach { _ =>
+ sql("SELECT COUNT(*) FROM VALUES(1) as t(c)").collect()
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.previousStage.adaptive.coalescePartitions.minPartitionNum") ===
+ FinalStageConfigIsolation.INTERNAL_UNSET_CONFIG_TAG)
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.adaptive.coalescePartitions.minPartitionNum") ===
+ "1")
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.finalStage.adaptive.coalescePartitions.minPartitionNum") ===
+ "1")
+
+ // 64MB
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.previousStage.adaptive.advisoryPartitionSizeInBytes") ===
+ "67108864b")
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.adaptive.advisoryPartitionSizeInBytes") ===
+ "100")
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes") ===
+ "100")
+ }
+
+ sql("SET spark.sql.adaptive.advisoryPartitionSizeInBytes=1")
+ assert(spark.sessionState.conf.getConfString(
+ "spark.sql.adaptive.advisoryPartitionSizeInBytes") ===
+ "1")
+ assert(!spark.sessionState.conf.contains(
+ "spark.sql.previousStage.adaptive.advisoryPartitionSizeInBytes"))
+
+ sql("SET a=1")
+ assert(spark.sessionState.conf.getConfString("a") === "1")
+
+ sql("RESET spark.sql.adaptive.coalescePartitions.minPartitionNum")
+ assert(!spark.sessionState.conf.contains(
+ "spark.sql.adaptive.coalescePartitions.minPartitionNum"))
+ assert(!spark.sessionState.conf.contains(
+ "spark.sql.previousStage.adaptive.coalescePartitions.minPartitionNum"))
+
+ sql("RESET a")
+ assert(!spark.sessionState.conf.contains("a"))
+ }
+ }
+
+ test("final stage config isolation") {
+ def checkPartitionNum(
+ sqlString: String, previousPartitionNum: Int, finalPartitionNum: Int): Unit = {
+ val df = sql(sqlString)
+ df.collect()
+ val shuffleReaders = collect(df.queryExecution.executedPlan) {
+ case customShuffleReader: CustomShuffleReaderExec => customShuffleReader
+ }
+ assert(shuffleReaders.nonEmpty)
+ shuffleReaders.tail.foreach { readers =>
+ assert(readers.partitionSpecs.length === previousPartitionNum)
+ }
+ assert(shuffleReaders.head.partitionSpecs.length === finalPartitionNum)
+ assert(df.rdd.partitions.length === finalPartitionNum)
+ }
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+ KyuubiSQLConf.FINAL_STAGE_CONFIG_ISOLATION.key -> "true",
+ "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "1",
+ "spark.sql.finalStage.adaptive.advisoryPartitionSizeInBytes" -> "10000000") {
+
+ // loop to double-check that the final stage config of one query does not affect other queries
+ (1 to 3).foreach { _ =>
+ checkPartitionNum(
+ "SELECT c1, count(*) FROM t1 GROUP BY c1",
+ 1,
+ 1)
+
+ checkPartitionNum(
+ "SELECT c2, count(*) FROM (SELECT c1, count(*) as c2 FROM t1 GROUP BY c1) GROUP BY c2",
+ 3,
+ 1)
+
+ checkPartitionNum(
+ "SELECT t1.c1, count(*) FROM t1 JOIN t2 ON t1.c2 = t2.c2 GROUP BY t1.c1",
+ 3,
+ 1)
+
+ checkPartitionNum(
+ """
+ | SELECT /*+ REPARTITION */
+ | t1.c1, count(*) FROM t1
+ | JOIN t2 ON t1.c2 = t2.c2
+ | JOIN t3 ON t1.c1 = t3.c1
+ | GROUP BY t1.c1
+ |""".stripMargin,
+ 3,
+ 1)
+
+ // one shuffle reader
+ checkPartitionNum(
+ """
+ | SELECT /*+ BROADCAST(t1) */
+ | t1.c1, t2.c2 FROM t1
+ | JOIN t2 ON t1.c2 = t2.c2
+ | DISTRIBUTE BY c1
+ |""".stripMargin,
+ 1,
+ 1)
+ }
+ }
+ }
+}
diff --git a/kyuubi-assembly/pom.xml b/kyuubi-assembly/pom.xml
index 1e7257edd..44b60fe5c 100644
--- a/kyuubi-assembly/pom.xml
+++ b/kyuubi-assembly/pom.xml
@@ -101,4 +101,16 @@
-
\ No newline at end of file
+
+
+ kyuubi-extension-spark_3.1
+
+
+ org.apache.kyuubi
+ kyuubi-extension-spark_3.1
+ ${project.version}
+
+
+
+
+
diff --git a/pom.xml b/pom.xml
index 78407a40a..2747fba9b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1584,5 +1584,15 @@
true
+
+
+ kyuubi-extension-spark_3.1
+
+ 3.1.1
+
+
+ dev/kyuubi-extension-spark_3.1
+
+