diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch index f02902d64..257a8fe83 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch @@ -95,7 +95,7 @@ index 00000000000..5e190c512df + ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") + .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") + .booleanConf -+ .createWithDefault(false) ++ .createWithDefault(true) + + private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() + private val stageRerunEnabled = new AtomicBoolean() diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch index 2b0aecf0b..ef0cdc318 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -95,7 +95,7 @@ index 00000000000..5e190c512df + ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") + .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") + .booleanConf -+ .createWithDefault(false) ++ .createWithDefault(true) + + private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() + private val stageRerunEnabled = new AtomicBoolean() diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch index 03b26e07a..ddb33ad4f 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch @@ -95,7 +95,7 @@ index 00000000000..5e190c512df + ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") + .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") + .booleanConf -+ .createWithDefault(false) ++ .createWithDefault(true) + + private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() + private val stageRerunEnabled = new AtomicBoolean() diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch index b9fc88b79..67f53889f 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch @@ -95,7 +95,7 @@ index 00000000000..5e190c512df + ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") + .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") + .booleanConf -+ .createWithDefault(false) ++ .createWithDefault(true) + + private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() + private val stageRerunEnabled = new AtomicBoolean() diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5_6.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5_6.patch new file mode 100644 index 000000000..362af6b08 --- /dev/null +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5_6.patch @@ -0,0 +1,408 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +index 9a7a3b0c0e7..543423dadd9 100644 +--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +@@ -34,6 +34,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut + import org.roaringbitmap.RoaringBitmap + + import org.apache.spark.broadcast.{Broadcast, BroadcastManager} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.internal.config._ + import org.apache.spark.io.CompressionCodec +@@ -916,6 +917,7 @@ private[spark] class MapOutputTrackerMaster( + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } + } ++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId) + } + + /** +diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala +index edad91a0c6f..76b377729a0 100644 +--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala ++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala +@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration + import org.apache.spark.annotation.DeveloperApi + import org.apache.spark.api.python.PythonWorkerFactory + import org.apache.spark.broadcast.BroadcastManager ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.executor.ExecutorBackend + import org.apache.spark.internal.{config, Logging} + import org.apache.spark.internal.config._ +@@ -419,6 +420,7 @@ object SparkEnv extends Logging { + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) ++ CelebornShuffleState.init(envInstance) + } + + envInstance +diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +new file mode 100644 +index 00000000000..5e190c512df +--- /dev/null ++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +@@ -0,0 +1,75 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.celeborn ++ ++import java.util.concurrent.ConcurrentHashMap ++import java.util.concurrent.atomic.AtomicBoolean ++ ++import org.apache.spark.SparkEnv ++import org.apache.spark.internal.config.ConfigBuilder ++ ++object CelebornShuffleState { ++ ++ private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") ++ .booleanConf ++ .createWithDefault(false) ++ ++ private val CELEBORN_STAGE_RERUN_ENABLED = ++ ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") ++ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") ++ .booleanConf ++ .createWithDefault(true) ++ ++ private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() ++ private val stageRerunEnabled = new AtomicBoolean() ++ private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ ++ // call this from SparkEnv.create ++ def init(env: SparkEnv): Unit = { ++ // cleanup existing state (if required) - and initialize ++ skewShuffleIds.clear() ++ ++ // use env.conf for all initialization, and not SQLConf ++ celebornOptimizeSkewedPartitionReadEnabled.set( ++ env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)) ++ stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED)) ++ } ++ ++ def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.remove(shuffleId) ++ } ++ ++ def registerCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.add(shuffleId) ++ } ++ ++ def isCelebornSkewedShuffle(shuffleId: Int): Boolean = { ++ skewShuffleIds.contains(shuffleId) ++ } ++ ++ def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = { ++ celebornOptimizeSkewedPartitionReadEnabled.get() ++ } ++ ++ def celebornStageRerunEnabled: Boolean = { ++ stageRerunEnabled.get() ++ } ++ ++} +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index 5ae29a5cd02..b54265fa6e7 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} + + import org.apache.spark._ + import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.errors.SparkCoreErrors + import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} + import org.apache.spark.internal.Logging +@@ -1486,7 +1487,10 @@ private[spark] class DAGScheduler( + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { +- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => ++ case sms: ShuffleMapStage if (stage.isIndeterminate || ++ CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && !sms.isAvailable => ++ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + ++ s" shuffle ${sms.shuffleDep.shuffleId}") + // already executed at least once + if (sms.getNextAttemptId > 0) { + // While we previously validated possible rollbacks during the handling of a FetchFailure, +@@ -1495,7 +1499,7 @@ private[spark] class DAGScheduler( + // loss. Moreover, because this check occurs later in the process, if a result stage task + // has successfully completed, we can detect this and abort the job, as rolling back a + // result stage is not possible. +- val stagesToRollback = collectSucceedingStages(sms) ++ val stagesToRollback = collectSucceedingStages(sms, CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) + abortStageWithInvalidRollBack(stagesToRollback) + // stages which cannot be rolled back were aborted which leads to removing the + // the dependant job(s) from the active jobs set +@@ -1880,7 +1884,18 @@ private[spark] class DAGScheduler( + // tasks complete, they still count and we can mark the corresponding partitions as + // finished if the stage is determinate. Here we notify the task scheduler to skip running + // tasks for the same partition to save resource. +- if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()) { ++ // finished. Here we notify the task scheduler to skip running tasks for the same partition, ++ // to save resource. ++ // CELEBORN-1856, if stage is indeterminate or shuffleMapStage is skewed and read by ++ // Celeborn chunkOffsets, should not call notifyPartitionCompletion, otherwise will ++ // skip running tasks for the same partition because TaskSetManager.dequeueTaskFromList ++ // will skip running task which TaskSetManager.successful(taskIndex) is true. ++ // TODO: ResultStage has result commit and other issues ++ val isCelebornShuffleIndeterminate = stage.isInstanceOf[ShuffleMapStage] && ++ CelebornShuffleState.isCelebornSkewedShuffle( ++ stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId) ++ if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber() ++ && !isCelebornShuffleIndeterminate) { + taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) + } + +@@ -1935,7 +1950,7 @@ private[spark] class DAGScheduler( + case smt: ShuffleMapTask => + val shuffleStage = stage.asInstanceOf[ShuffleMapStage] + // Ignore task completion for old attempt of indeterminate stage +- val ignoreIndeterminate = stage.isIndeterminate && ++ val ignoreIndeterminate = (stage.isIndeterminate || isCelebornShuffleIndeterminate) && + task.stageAttemptId < stage.latestInfo.attemptNumber() + if (!ignoreIndeterminate) { + shuffleStage.pendingPartitions -= task.partitionId +@@ -1970,6 +1985,14 @@ private[spark] class DAGScheduler( + val failedStage = stageIdToStage(task.stageId) + val mapStage = shuffleIdToMapStage(shuffleId) + ++ // In Celeborn-1139 we support read skew partition by Celeborn chunkOffsets, ++ // it will make shuffle be indeterminate, so abort the ResultStage directly here. ++ if (failedStage.isInstanceOf[ResultStage] && CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) { ++ val shuffleFailedReason = s"ResultStage:${failedStage.id} fetch failed and the shuffle:$shuffleId " + ++ s"is skewed partition read by Celeborn, so abort it." ++ abortStage(failedStage, shuffleFailedReason, None) ++ } ++ + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + +@@ -2068,8 +2091,9 @@ private[spark] class DAGScheduler( + // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is + // guaranteed to be determinate, so the input data of the reducers will not change + // even if the map tasks are re-tried. +- if (mapStage.isIndeterminate) { +- val stagesToRollback = collectSucceedingStages(mapStage) ++ val isCelebornShuffleIndeterminate = CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) ++ if (mapStage.isIndeterminate || isCelebornShuffleIndeterminate) { ++ val stagesToRollback = collectSucceedingStages(mapStage, isCelebornShuffleIndeterminate) + val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback) + logInfo(s"The shuffle map stage $mapStage with indeterminate output was failed, " + + s"we will roll back and rerun below stages which include itself and all its " + +@@ -2233,7 +2257,7 @@ private[spark] class DAGScheduler( + } + } + +- private def collectSucceedingStages(mapStage: ShuffleMapStage): HashSet[Stage] = { ++ private def collectSucceedingStages(mapStage: ShuffleMapStage, isCelebornShuffleIndeterminate: Boolean): HashSet[Stage] = { + // TODO: perhaps materialize this if we are going to compute it often enough ? + // It's a little tricky to find all the succeeding stages of `mapStage`, because + // each stage only know its parents not children. Here we traverse the stages from +@@ -2245,7 +2269,17 @@ private[spark] class DAGScheduler( + + def collectSucceedingStagesInternal(stageChain: List[Stage]): Unit = { + if (succeedingStages.contains(stageChain.head)) { +- stageChain.drop(1).foreach(s => succeedingStages += s) ++ stageChain.drop(1).foreach(s => { ++ succeedingStages += s ++ if (isCelebornShuffleIndeterminate) { ++ s match { ++ case currentMapStage: ShuffleMapStage => ++ CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId) ++ case _: ResultStage => ++ // do nothing, should abort celeborn skewed read stage ++ } ++ } ++ }) + } else { + stageChain.head.parents.foreach { s => + collectSucceedingStagesInternal(s :: stageChain) +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +new file mode 100644 +index 00000000000..3dc60678461 +--- /dev/null ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +@@ -0,0 +1,35 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.sql.execution.adaptive ++ ++import java.util.Locale ++ ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} ++ ++object CelebornShuffleUtil { ++ ++ def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = { ++ shuffleExchangeLike match { ++ case exec: ShuffleExchangeExec => ++ exec.shuffleDependency.shuffleHandle ++ .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn") ++ case _ => false ++ } ++ } ++ ++} +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +index abd096b9c7c..ff0363f87d8 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], +- targetSize: Long): Seq[ShufflePartitionSpec] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = { + val smallPartitionFactor = + conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR) + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { + val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( +- shuffleId, reduceIndex, targetSize, smallPartitionFactor) ++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { +@@ -77,8 +78,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + return shuffle + } + ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle) + val newPartitionsSpec = optimizeSkewedPartitions( +- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) ++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle) + // return origin plan if we can not optimize partitions + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +index 37cdea084d8..4694a06919e 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +@@ -152,8 +152,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) ++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + +@@ -166,8 +168,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + } + + val rightParts = if (isRightSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) ++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +index 9370b3d8d1d..599eb976591 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive + import scala.collection.mutable.ArrayBuffer + + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.Logging + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} + +@@ -382,13 +383,21 @@ object ShufflePartitionsUtil extends Logging { + shuffleId: Int, + reducerId: Int, + targetSize: Long, +- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) +- : Option[Seq[PartialReducerPartitionSpec]] = { ++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR, ++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = { + val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) + if (mapPartitionSizes.exists(_ < 0)) return None + val mapStartIndices = splitSizeListByTargetSize( + mapPartitionSizes, targetSize, smallPartitionFactor) + if (mapStartIndices.length > 1) { ++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ ++ val stageRerunEnabled = CelebornShuffleState.celebornStageRerunEnabled ++ if (stageRerunEnabled && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ logInfo(s"Celeborn shuffle retry enabled and shuffle $shuffleId is skewed") ++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId) ++ } + Some(mapStartIndices.indices.map { i => + val startMapIndex = mapStartIndices(i) + val endMapIndex = if (i == mapStartIndices.length - 1) { +@@ -402,7 +411,15 @@ object ShufflePartitionsUtil extends Logging { + dataSize += mapPartitionSizes(mapIndex) + mapIndex += 1 + } +- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ ++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ // These `dataSize` variables may not be accurate as they only represent the sum of ++ // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. ++ // Please not to use these dataSize variables in any other part of the codebase. ++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize) ++ } else { ++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ } + }) + } else { + None diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark4_0.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark4_0.patch new file mode 100644 index 000000000..1bd62f687 --- /dev/null +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark4_0.patch @@ -0,0 +1,408 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +index a660bccd2e6..a231ecfe72f 100644 +--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala ++++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +@@ -34,6 +34,7 @@ import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOut + import org.roaringbitmap.RoaringBitmap + + import org.apache.spark.broadcast.{Broadcast, BroadcastManager} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.{Logging, MDC, MessageWithContext} + import org.apache.spark.internal.LogKeys._ + import org.apache.spark.internal.config._ +@@ -929,6 +930,7 @@ private[spark] class MapOutputTrackerMaster( + shuffleStatus.invalidateSerializedMapOutputStatusCache() + shuffleStatus.invalidateSerializedMergeOutputStatusCache() + } ++ CelebornShuffleState.unregisterCelebornSkewedShuffle(shuffleId) + } + + /** +diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala +index bf6e30f5afa..15dd28475a1 100644 +--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala ++++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala +@@ -31,6 +31,7 @@ import org.apache.hadoop.conf.Configuration + import org.apache.spark.annotation.DeveloperApi + import org.apache.spark.api.python.{PythonWorker, PythonWorkerFactory} + import org.apache.spark.broadcast.BroadcastManager ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.executor.ExecutorBackend + import org.apache.spark.internal.{config, Logging, MDC} + import org.apache.spark.internal.LogKeys +@@ -494,6 +495,7 @@ object SparkEnv extends Logging { + if (isDriver) { + val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + envInstance.driverTmpDir = Some(sparkFilesDir) ++ CelebornShuffleState.init(envInstance) + } + + envInstance +diff --git a/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +new file mode 100644 +index 00000000000..5e190c512df +--- /dev/null ++++ b/core/src/main/scala/org/apache/spark/celeborn/CelebornShuffleState.scala +@@ -0,0 +1,75 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.celeborn ++ ++import java.util.concurrent.ConcurrentHashMap ++import java.util.concurrent.atomic.AtomicBoolean ++ ++import org.apache.spark.SparkEnv ++import org.apache.spark.internal.config.ConfigBuilder ++ ++object CelebornShuffleState { ++ ++ private val CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ = ++ ConfigBuilder("spark.celeborn.client.adaptive.optimizeSkewedPartitionRead.enabled") ++ .booleanConf ++ .createWithDefault(false) ++ ++ private val CELEBORN_STAGE_RERUN_ENABLED = ++ ConfigBuilder("spark.celeborn.client.spark.stageRerun.enabled") ++ .withAlternative("spark.celeborn.client.spark.fetch.throwsFetchFailure") ++ .booleanConf ++ .createWithDefault(true) ++ ++ private val celebornOptimizeSkewedPartitionReadEnabled = new AtomicBoolean() ++ private val stageRerunEnabled = new AtomicBoolean() ++ private val skewShuffleIds = ConcurrentHashMap.newKeySet[Int]() ++ ++ // call this from SparkEnv.create ++ def init(env: SparkEnv): Unit = { ++ // cleanup existing state (if required) - and initialize ++ skewShuffleIds.clear() ++ ++ // use env.conf for all initialization, and not SQLConf ++ celebornOptimizeSkewedPartitionReadEnabled.set( ++ env.conf.get("spark.shuffle.manager", "sort").contains("celeborn") && ++ env.conf.get(CELEBORN_CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ)) ++ stageRerunEnabled.set(env.conf.get(CELEBORN_STAGE_RERUN_ENABLED)) ++ } ++ ++ def unregisterCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.remove(shuffleId) ++ } ++ ++ def registerCelebornSkewedShuffle(shuffleId: Int): Unit = { ++ skewShuffleIds.add(shuffleId) ++ } ++ ++ def isCelebornSkewedShuffle(shuffleId: Int): Boolean = { ++ skewShuffleIds.contains(shuffleId) ++ } ++ ++ def celebornAdaptiveOptimizeSkewedPartitionReadEnabled: Boolean = { ++ celebornOptimizeSkewedPartitionReadEnabled.get() ++ } ++ ++ def celebornStageRerunEnabled: Boolean = { ++ stageRerunEnabled.get() ++ } ++ ++} +diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +index baf0ed4df53..2d7e36344a0 100644 +--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala ++++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +@@ -35,6 +35,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} + + import org.apache.spark._ + import org.apache.spark.broadcast.Broadcast ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.errors.SparkCoreErrors + import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} + import org.apache.spark.internal.{config, Logging, LogKeys, MDC} +@@ -1551,7 +1552,10 @@ private[spark] class DAGScheduler( + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { +- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => ++ case sms: ShuffleMapStage if (stage.isIndeterminate || ++ CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && !sms.isAvailable => ++ logInfo(log"Unregistering shuffle output for stage ${MDC(STAGE_ID, stage.id)}" + ++ log" shuffle ${MDC(SHUFFLE_ID, sms.shuffleDep.shuffleId)}") + // already executed at least once + if (sms.getNextAttemptId > 0) { + // While we previously validated possible rollbacks during the handling of a FetchFailure, +@@ -1560,7 +1564,7 @@ private[spark] class DAGScheduler( + // loss. Moreover, because this check occurs later in the process, if a result stage task + // has successfully completed, we can detect this and abort the job, as rolling back a + // result stage is not possible. +- val stagesToRollback = collectSucceedingStages(sms) ++ val stagesToRollback = collectSucceedingStages(sms, CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) + abortStageWithInvalidRollBack(stagesToRollback) + // stages which cannot be rolled back were aborted which leads to removing the + // the dependant job(s) from the active jobs set +@@ -1951,7 +1955,18 @@ private[spark] class DAGScheduler( + // tasks complete, they still count and we can mark the corresponding partitions as + // finished if the stage is determinate. Here we notify the task scheduler to skip running + // tasks for the same partition to save resource. +- if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()) { ++ // finished. Here we notify the task scheduler to skip running tasks for the same partition, ++ // to save resource. ++ // CELEBORN-1856, if stage is indeterminate or shuffleMapStage is skewed and read by ++ // Celeborn chunkOffsets, should not call notifyPartitionCompletion, otherwise will ++ // skip running tasks for the same partition because TaskSetManager.dequeueTaskFromList ++ // will skip running task which TaskSetManager.successful(taskIndex) is true. ++ // TODO: ResultStage has result commit and other issues ++ val isCelebornShuffleIndeterminate = stage.isInstanceOf[ShuffleMapStage] && ++ CelebornShuffleState.isCelebornSkewedShuffle( ++ stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId) ++ if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber() ++ && !isCelebornShuffleIndeterminate) { + taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) + } + +@@ -2007,7 +2022,7 @@ private[spark] class DAGScheduler( + case smt: ShuffleMapTask => + val shuffleStage = stage.asInstanceOf[ShuffleMapStage] + // Ignore task completion for old attempt of indeterminate stage +- val ignoreIndeterminate = stage.isIndeterminate && ++ val ignoreIndeterminate = (stage.isIndeterminate || isCelebornShuffleIndeterminate) && + task.stageAttemptId < stage.latestInfo.attemptNumber() + if (!ignoreIndeterminate) { + shuffleStage.pendingPartitions -= task.partitionId +@@ -2043,6 +2058,14 @@ private[spark] class DAGScheduler( + val failedStage = stageIdToStage(task.stageId) + val mapStage = shuffleIdToMapStage(shuffleId) + ++ // In Celeborn-1139 we support read skew partition by Celeborn chunkOffsets, ++ // it will make shuffle be indeterminate, so abort the ResultStage directly here. ++ if (failedStage.isInstanceOf[ResultStage] && CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) { ++ val shuffleFailedReason = s"ResultStage:${failedStage.id} fetch failed and the shuffle:$shuffleId " + ++ s"is skewed partition read by Celeborn, so abort it." ++ abortStage(failedStage, shuffleFailedReason, None) ++ } ++ + if (failedStage.latestInfo.attemptNumber() != task.stageAttemptId) { + logInfo(log"Ignoring fetch failure from " + + log"${MDC(TASK_ID, task)} as it's from " + +@@ -2148,8 +2171,9 @@ private[spark] class DAGScheduler( + // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is + // guaranteed to be determinate, so the input data of the reducers will not change + // even if the map tasks are re-tried. +- if (mapStage.isIndeterminate) { +- val stagesToRollback = collectSucceedingStages(mapStage) ++ val isCelebornShuffleIndeterminate = CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) ++ if (mapStage.isIndeterminate || isCelebornShuffleIndeterminate) { ++ val stagesToRollback = collectSucceedingStages(mapStage, isCelebornShuffleIndeterminate) + val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback) + logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " + + log"we will roll back and rerun below stages which include itself and all its " + +@@ -2314,7 +2338,7 @@ private[spark] class DAGScheduler( + } + } + +- private def collectSucceedingStages(mapStage: ShuffleMapStage): HashSet[Stage] = { ++ private def collectSucceedingStages(mapStage: ShuffleMapStage, isCelebornShuffleIndeterminate: Boolean): HashSet[Stage] = { + // TODO: perhaps materialize this if we are going to compute it often enough ? + // It's a little tricky to find all the succeeding stages of `mapStage`, because + // each stage only know its parents not children. Here we traverse the stages from +@@ -2326,7 +2350,17 @@ private[spark] class DAGScheduler( + + def collectSucceedingStagesInternal(stageChain: List[Stage]): Unit = { + if (succeedingStages.contains(stageChain.head)) { +- stageChain.drop(1).foreach(s => succeedingStages += s) ++ stageChain.drop(1).foreach(s => { ++ succeedingStages += s ++ if (isCelebornShuffleIndeterminate) { ++ s match { ++ case currentMapStage: ShuffleMapStage => ++ CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId) ++ case _: ResultStage => ++ // do nothing, should abort celeborn skewed read stage ++ } ++ } ++ }) + } else { + stageChain.head.parents.foreach { s => + collectSucceedingStagesInternal(s :: stageChain) +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +new file mode 100644 +index 00000000000..3dc60678461 +--- /dev/null ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala +@@ -0,0 +1,35 @@ ++/* ++ * Licensed to the Apache Software Foundation (ASF) under one or more ++ * contributor license agreements. See the NOTICE file distributed with ++ * this work for additional information regarding copyright ownership. ++ * The ASF licenses this file to You under the Apache License, Version 2.0 ++ * (the "License"); you may not use this file except in compliance with ++ * the License. You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++package org.apache.spark.sql.execution.adaptive ++ ++import java.util.Locale ++ ++import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike} ++ ++object CelebornShuffleUtil { ++ ++ def isCelebornShuffle(shuffleExchangeLike: ShuffleExchangeLike): Boolean = { ++ shuffleExchangeLike match { ++ case exec: ShuffleExchangeExec => ++ exec.shuffleDependency.shuffleHandle ++ .getClass.getName.toLowerCase(Locale.ROOT).contains("celeborn") ++ case _ => false ++ } ++ } ++ ++} +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +index abd096b9c7c..ff0363f87d8 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +@@ -47,14 +47,15 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], +- targetSize: Long): Seq[ShufflePartitionSpec] = { ++ targetSize: Long, ++ isCelebornShuffle: Boolean = false): Seq[ShufflePartitionSpec] = { + val smallPartitionFactor = + conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR) + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { + val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( +- shuffleId, reduceIndex, targetSize, smallPartitionFactor) ++ shuffleId, reduceIndex, targetSize, smallPartitionFactor, isCelebornShuffle) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { +@@ -77,8 +78,9 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { + return shuffle + } + ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(shuffle.shuffle) + val newPartitionsSpec = optimizeSkewedPartitions( +- mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) ++ mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize, isCelebornShuffle) + // return origin plan if we can not optimize partitions + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +index c256b3fcb6b..c366ff346e2 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +@@ -151,8 +151,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(left.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) ++ left.mapStats.get.shuffleId, partitionIndex, leftTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${Utils.bytesToString(leftSize)}) is skewed, " + +@@ -165,8 +167,10 @@ case class OptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + } + + val rightParts = if (isRightSkew) { ++ val isCelebornShuffle = CelebornShuffleUtil.isCelebornShuffle(right.shuffle) + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( +- right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) ++ right.mapStats.get.shuffleId, partitionIndex, rightTargetSize, ++ isCelebornShuffle = isCelebornShuffle) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${Utils.bytesToString(rightSize)}) is skewed, " + +diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +index 1ea4df02546..22bc98f4f86 100644 +--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala ++++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.adaptive + import scala.collection.mutable.ArrayBuffer + + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} ++import org.apache.spark.celeborn.CelebornShuffleState + import org.apache.spark.internal.{Logging, LogKeys, MDC} + import org.apache.spark.sql.execution.{CoalescedPartitionSpec, PartialReducerPartitionSpec, ShufflePartitionSpec} + +@@ -384,13 +385,21 @@ object ShufflePartitionsUtil extends Logging { + shuffleId: Int, + reducerId: Int, + targetSize: Long, +- smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) +- : Option[Seq[PartialReducerPartitionSpec]] = { ++ smallPartitionFactor: Double = SMALL_PARTITION_FACTOR, ++ isCelebornShuffle: Boolean = false): Option[Seq[PartialReducerPartitionSpec]] = { + val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) + if (mapPartitionSizes.exists(_ < 0)) return None + val mapStartIndices = splitSizeListByTargetSize( + mapPartitionSizes, targetSize, smallPartitionFactor) + if (mapStartIndices.length > 1) { ++ val celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled = ++ CelebornShuffleState.celebornAdaptiveOptimizeSkewedPartitionReadEnabled && isCelebornShuffle ++ ++ val stageRerunEnabled = CelebornShuffleState.celebornStageRerunEnabled ++ if (stageRerunEnabled && celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ logInfo(log"Celeborn shuffle retry enabled and shuffle ${MDC(LogKeys.SHUFFLE_ID, shuffleId)} is skewed") ++ CelebornShuffleState.registerCelebornSkewedShuffle(shuffleId) ++ } + Some(mapStartIndices.indices.map { i => + val startMapIndex = mapStartIndices(i) + val endMapIndex = if (i == mapStartIndices.length - 1) { +@@ -404,7 +413,15 @@ object ShufflePartitionsUtil extends Logging { + dataSize += mapPartitionSizes(mapIndex) + mapIndex += 1 + } +- PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ ++ if (celebornClientAdaptiveOptimizeSkewedPartitionReadEnabled) { ++ // These `dataSize` variables may not be accurate as they only represent the sum of ++ // `dataSize` when the Celeborn optimize skewed partition read feature is enabled. ++ // Please not to use these dataSize variables in any other part of the codebase. ++ PartialReducerPartitionSpec(reducerId, mapStartIndices.length, i, dataSize) ++ } else { ++ PartialReducerPartitionSpec(reducerId, startMapIndex, endMapIndex, dataSize) ++ } + }) + } else { + None