diff --git a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/RepartitionBeforeWritingSuite.scala b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/RepartitionBeforeWritingSuite.scala index 8b07f543e..f978623af 100644 --- a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/RepartitionBeforeWritingSuite.scala +++ b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/RepartitionBeforeWritingSuite.scala @@ -26,14 +26,14 @@ import org.apache.kyuubi.sql.KyuubiSQLConf class RepartitionBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { test("check repartition exists") { - def check(df: DataFrame): Unit = { + def check(df: DataFrame, expectedRepartitionNum: Int = 1): Unit = { assert( df.queryExecution.analyzed.collect { case r: RepartitionByExpression => assert(r.optNumPartitions === spark.sessionState.conf.getConf(KyuubiSQLConf.INSERT_REPARTITION_NUM)) r - }.size == 1) + }.size == expectedRepartitionNum) } // It's better to set config explicitly in case of we change the default value. @@ -45,6 +45,18 @@ class RepartitionBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { "SELECT * FROM VALUES(1),(2) AS t(c1)")) } + withTable("tmp1", "tmp2") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)") + sql(s"CREATE TABLE tmp2 (c1 int) $storage PARTITIONED BY (c2 string)") + check( + sql( + """FROM VALUES(1),(2) AS t(c1) + |INSERT INTO TABLE tmp1 PARTITION(c2='a') SELECT * + |INSERT INTO TABLE tmp2 PARTITION(c2='a') SELECT * + |""".stripMargin), + 2) + } + withTable("tmp1") { sql(s"CREATE TABLE tmp1 (c1 int) $storage") check(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)")) @@ -52,6 +64,25 @@ class RepartitionBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { "SELECT * FROM VALUES(1),(2),(3) AS t(c1) DISTRIBUTE BY c1")) } + withTable("tmp1", "tmp2") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage") + sql(s"CREATE TABLE tmp2 (c1 int) $storage") + check( + sql( + """FROM VALUES(1),(2),(3) + |INSERT INTO TABLE tmp1 SELECT * + |INSERT INTO TABLE tmp2 SELECT * + |""".stripMargin), + 2) + check( + sql( + """FROM (SELECT * FROM VALUES(1),(2),(3) AS t(c1) DISTRIBUTE BY c1) + |INSERT INTO TABLE tmp1 SELECT * + |INSERT INTO TABLE tmp2 SELECT * + |""".stripMargin), + 2) + } + withTable("tmp1") { sql(s"CREATE TABLE tmp1 $storage AS SELECT * FROM VALUES(1),(2),(3) AS t(c1)") } diff --git a/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala b/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala index f40b43213..f1a27cdb8 100644 --- a/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala +++ b/dev/kyuubi-extension-spark-3-2/src/test/scala/org/apache/spark/sql/RebalanceBeforeWritingSuite.scala @@ -26,11 +26,11 @@ import org.apache.kyuubi.sql.KyuubiSQLConf class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { test("check rebalance exists") { - def check(df: DataFrame): Unit = { + def check(df: DataFrame, expectedRebalanceNum: Int = 1): Unit = { assert( df.queryExecution.analyzed.collect { case r: RebalancePartitions => r - }.size == 1) + }.size == expectedRebalanceNum) } // It's better to set config explicitly in case of we change the default value. @@ -42,11 +42,35 @@ class RebalanceBeforeWritingSuite extends KyuubiSparkSQLExtensionTest { "SELECT * FROM VALUES(1),(2) AS t(c1)")) } + withTable("tmp1", "tmp2") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage PARTITIONED BY (c2 string)") + sql(s"CREATE TABLE tmp2 (c1 int) $storage PARTITIONED BY (c2 string)") + check( + sql( + """FROM VALUES(1),(2) + |INSERT INTO TABLE tmp1 PARTITION(c2='a') SELECT * + |INSERT INTO TABLE tmp2 PARTITION(c2='a') SELECT * + |""".stripMargin), + 2) + } + withTable("tmp1") { sql(s"CREATE TABLE tmp1 (c1 int) $storage") check(sql("INSERT INTO TABLE tmp1 SELECT * FROM VALUES(1),(2),(3) AS t(c1)")) } + withTable("tmp1", "tmp2") { + sql(s"CREATE TABLE tmp1 (c1 int) $storage") + sql(s"CREATE TABLE tmp2 (c1 int) $storage") + check( + sql( + """FROM VALUES(1),(2),(3) + |INSERT INTO TABLE tmp1 SELECT * + |INSERT INTO TABLE tmp2 SELECT * + |""".stripMargin), + 2) + } + withTable("tmp1") { sql(s"CREATE TABLE tmp1 $storage AS SELECT * FROM VALUES(1),(2),(3) AS t(c1)") } diff --git a/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/RepartitionBeforeWritingBase.scala b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/RepartitionBeforeWritingBase.scala index b987a72aa..33aff09a4 100644 --- a/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/RepartitionBeforeWritingBase.scala +++ b/dev/kyuubi-extension-spark-common/src/main/scala/org/apache/kyuubi/sql/RepartitionBeforeWritingBase.scala @@ -59,6 +59,9 @@ abstract class RepartitionBeforeWritingDatasourceBase extends RepartitionBuilder query.output.filter(attr => table.partitionColumnNames.contains(attr.name)) c.copy(query = buildRepartition(dynamicPartitionColumns, query)) + case u @ Union(children, _, _) => + u.copy(children = children.map(addRepartition)) + case _ => plan } } @@ -98,6 +101,9 @@ abstract class RepartitionBeforeWritingHiveBase extends RepartitionBuilder { query.output.filter(attr => table.partitionColumnNames.contains(attr.name)) c.copy(query = buildRepartition(dynamicPartitionColumns, query)) + case u @ Union(children, _, _) => + u.copy(children = children.map(addRepartition)) + case _ => plan } } @@ -105,6 +111,7 @@ abstract class RepartitionBeforeWritingHiveBase extends RepartitionBuilder { trait RepartitionBeforeWriteHelper { def canInsertRepartitionByExpression(plan: LogicalPlan): Boolean = plan match { case Project(_, child) => canInsertRepartitionByExpression(child) + case SubqueryAlias(_, child) => canInsertRepartitionByExpression(child) case Limit(_, _) => false case _: Sort => false case _: RepartitionByExpression => false