[KYUUBI #6989] Calculate expected join partitions based on scanned table size

### Why are the changes needed?

Avoid an unstable test case caused by table size changes, which are likely to happen when upgrading Parquet/ORC/Spark.

### How was this patch tested?

unit test

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #6989 from wForget/minor_fix.

Closes #6989

9cdd36973 [wforget] address comments
f79fcca0d [wforget] Calculate expected join partitions based on scanned table size

Authored-by: wforget <643348094@qq.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
This commit is contained in:
wforget 2025-03-18 20:23:35 +08:00 committed by Cheng Pan
parent 86d7b3b348
commit cb36e748ed
No known key found for this signature in database
GPG Key ID: 8001952629BCC75D

View File

@ -16,6 +16,7 @@
*/
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.{CommandResultExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec}
@ -51,6 +52,15 @@ class DynamicShufflePartitionsSuite extends KyuubiSparkSQLExtensionTest {
sql("ANALYZE TABLE table2 COMPUTE STATISTICS")
val initialPartitionNum: Int = 2
val advisoryPartitionSizeInBytes: Long = 500
val t1Size = spark.sessionState.catalog.getTableMetadata(TableIdentifier("table1"))
.stats.get.sizeInBytes.toLong
val t2Size = spark.sessionState.catalog.getTableMetadata(TableIdentifier("table2"))
.stats.get.sizeInBytes.toLong
val scanSize = t1Size + t2Size
val expectedJoinPartitionNum = Math.ceil(scanSize.toDouble / advisoryPartitionSizeInBytes)
Seq(false, true).foreach { dynamicShufflePartitions =>
val maxDynamicShufflePartitions = if (dynamicShufflePartitions) {
Seq(8, 2000)
@ -63,34 +73,37 @@ class DynamicShufflePartitionsSuite extends KyuubiSparkSQLExtensionTest {
DYNAMIC_SHUFFLE_PARTITIONS_MAX_NUM.key -> maxDynamicShufflePartitionNum.toString,
AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> initialPartitionNum.toString,
ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "500") {
ADVISORY_PARTITION_SIZE_IN_BYTES.key -> advisoryPartitionSizeInBytes.toString) {
val df = sql("insert overwrite table3 " +
" select a.c1 as c1, b.c2 as c2 from table1 a join table2 b on a.c1 = b.c1")
val exchanges = collectExchanges(df.queryExecution.executedPlan)
val (joinExchanges, rebalanceExchanges) = exchanges
.partition(_.shuffleOrigin == ENSURE_REQUIREMENTS)
// table scan size: 7369 3287
assert(joinExchanges.size == 2)
if (dynamicShufflePartitions) {
joinExchanges.foreach(e =>
assert(e.outputPartitioning.numPartitions
== Math.min(22, maxDynamicShufflePartitionNum)))
joinExchanges.foreach { e =>
val expected = Math.min(expectedJoinPartitionNum, maxDynamicShufflePartitionNum)
assert(e.outputPartitioning.numPartitions == expected)
}
} else {
joinExchanges.foreach(e =>
assert(e.outputPartitioning.numPartitions == initialPartitionNum))
joinExchanges.foreach { e =>
assert(e.outputPartitioning.numPartitions == initialPartitionNum)
}
}
assert(rebalanceExchanges.size == 1)
if (dynamicShufflePartitions) {
if (maxDynamicShufflePartitionNum == 8) {
// shuffle query size: 1424 451
assert(rebalanceExchanges.head.outputPartitioning.numPartitions ==
Math.min(4, maxDynamicShufflePartitionNum))
// shuffle query size: 1424 451 (the size may change with Spark version updates
// or shuffle configuration updates)
val expected = Math.min(4, maxDynamicShufflePartitionNum)
assert(rebalanceExchanges.head.outputPartitioning.numPartitions == expected)
} else {
// shuffle query size: 2057 664
assert(rebalanceExchanges.head.outputPartitioning.numPartitions ==
Math.min(6, maxDynamicShufflePartitionNum))
// shuffle query size: 2057 664 (the size may change with Spark version updates
// or shuffle configuration updates)
val expected = Math.min(6, maxDynamicShufflePartitionNum)
assert(rebalanceExchanges.head.outputPartitioning.numPartitions == expected)
}
} else {
assert(
@ -104,7 +117,7 @@ class DynamicShufflePartitionsSuite extends KyuubiSparkSQLExtensionTest {
DYNAMIC_SHUFFLE_PARTITIONS_MAX_NUM.key -> maxDynamicShufflePartitionNum.toString,
AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> initialPartitionNum.toString,
ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "500",
ADVISORY_PARTITION_SIZE_IN_BYTES.key -> advisoryPartitionSizeInBytes.toString,
CONVERT_METASTORE_PARQUET.key -> "false") {
val df = sql("insert overwrite table3 " +
" select a.c1 as c1, b.c2 as c2 from table1 a join table2 b on a.c1 = b.c1")
@ -112,21 +125,23 @@ class DynamicShufflePartitionsSuite extends KyuubiSparkSQLExtensionTest {
val exchanges = collectExchanges(df.queryExecution.executedPlan)
val (joinExchanges, rebalanceExchanges) = exchanges
.partition(_.shuffleOrigin == ENSURE_REQUIREMENTS)
// table scan size: 7369 3287
assert(joinExchanges.size == 2)
if (dynamicShufflePartitions) {
joinExchanges.foreach(e =>
assert(e.outputPartitioning.numPartitions ==
Math.min(22, maxDynamicShufflePartitionNum)))
joinExchanges.foreach { e =>
val expected = Math.min(expectedJoinPartitionNum, maxDynamicShufflePartitionNum)
assert(e.outputPartitioning.numPartitions == expected)
}
} else {
joinExchanges.foreach(e =>
assert(e.outputPartitioning.numPartitions == initialPartitionNum))
joinExchanges.foreach { e =>
assert(e.outputPartitioning.numPartitions == initialPartitionNum)
}
}
// shuffle query size: 5154 720
// shuffle query size: 5154 720 (the size may change with Spark version updates
// or shuffle configuration updates)
assert(rebalanceExchanges.size == 1)
if (dynamicShufflePartitions) {
assert(rebalanceExchanges.head.outputPartitioning.numPartitions
== Math.min(12, maxDynamicShufflePartitionNum))
val expected = Math.min(12, maxDynamicShufflePartitionNum)
assert(rebalanceExchanges.head.outputPartitioning.numPartitions == expected)
} else {
assert(rebalanceExchanges.head.outputPartitioning.numPartitions ==
initialPartitionNum)