From 4bacd1f211fa500460fb389ff96ae2c9e0591360 Mon Sep 17 00:00:00 2001 From: wangshengjie3 Date: Mon, 24 Mar 2025 22:03:15 +0800 Subject: [PATCH] [CELEBORN-1856] Support stage-rerun when read partition by chunkOffsets when enable optimize skew partition read ### What changes were proposed in this pull request? Support stage-rerun when read partition by chunkOffsets when enable optimize skew partition read ### Why are the changes needed? In [CELEBORN-1319](https://issues.apache.org/jira/browse/CELEBORN-1319), we have already implemented the skew partition read optimization based on chunk offsets, but we don't support skew partition shuffle retry, so we need support the stage rerun. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Cluster test Closes #3118 from wangshengjie123/support-stage-rerun. Lead-authored-by: wangshengjie3 Co-authored-by: Wang, Fei Signed-off-by: Shuang --- ...rn-Optimize-Skew-Partitions-spark3_2.patch | 79 ++++++++++++++-- ...rn-Optimize-Skew-Partitions-spark3_3.patch | 79 ++++++++++++++-- ...rn-Optimize-Skew-Partitions-spark3_4.patch | 77 ++++++++++++++-- ...rn-Optimize-Skew-Partitions-spark3_5.patch | 89 +++++++++++++++++-- .../shuffle/celeborn/SparkShuffleManager.java | 5 ++ .../spark/shuffle/celeborn/SparkUtils.java | 14 +++ .../celeborn/client/LifecycleManager.scala | 18 +++- 7 files changed, 326 insertions(+), 35 deletions(-) diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch index 0cb1fc812..0e3be7a8e 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_2.patch @@ -135,7 +135,7 @@ index 00000000000..5e190c512df + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index b950c07f3d8..2cb430c3c3d 100644 +index b950c07f3d8..9e339db4fb4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -33,6 +33,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} @@ -146,15 +146,76 @@ index b950c07f3d8..2cb430c3c3d 100644 import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging import org.apache.spark.internal.config -@@ -1780,7 +1781,7 @@ private[spark] class DAGScheduler( - failedStage.failedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || -- disallowStageRetryForTest -+ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) +@@ -1369,7 +1370,10 @@ private[spark] class DAGScheduler( + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { +- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => ++ case sms: ShuffleMapStage if (stage.isIndeterminate || ++ CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && !sms.isAvailable => ++ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + ++ s" shuffle ${sms.shuffleDep.shuffleId}") + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() + case _ => +@@ -1689,7 +1693,15 @@ private[spark] class DAGScheduler( + // tasks complete, they still count and we can mark the corresponding partitions as + // finished. Here we notify the task scheduler to skip running tasks for the same partition, + // to save resource. +- if (task.stageAttemptId < stage.latestInfo.attemptNumber()) { ++ // CELEBORN-1856, if stage is indeterminate or shuffleMapStage is skewed and read by ++ // Celeborn chunkOffsets, should not call notifyPartitionCompletion, otherwise will ++ // skip running tasks for the same partition because TaskSetManager.dequeueTaskFromList ++ // will skip running task which TaskSetManager.successful(taskIndex) is true. ++ // TODO: Suggest cherry-pick SPARK-45182 and SPARK-45498, ResultStage may has result commit and other issues ++ val isStageIndeterminate = stage.isInstanceOf[ShuffleMapStage] && ++ CelebornShuffleState.isCelebornSkewedShuffle( ++ stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId) ++ if (task.stageAttemptId < stage.latestInfo.attemptNumber() && !isStageIndeterminate) { + taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) + } - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is +@@ -1772,6 +1784,14 @@ private[spark] class DAGScheduler( + val failedStage = stageIdToStage(task.stageId) + val mapStage = shuffleIdToMapStage(shuffleId) + ++ // In Celeborn-1139 we support read skew partition by Celeborn chunkOffsets, ++ // it will make shuffle be indeterminate, so abort the ResultStage directly here. ++ if (failedStage.isInstanceOf[ResultStage] && CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) { ++ val shuffleFailedReason = s"ResultStage:${failedStage.id} fetch failed and the shuffle:$shuffleId " + ++ s"is skewed partition read by Celeborn, so abort it." ++ abortStage(failedStage, shuffleFailedReason, None) ++ } ++ + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + +@@ -1850,7 +1870,7 @@ private[spark] class DAGScheduler( + // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is + // guaranteed to be determinate, so the input data of the reducers will not change + // even if the map tasks are re-tried. +- if (mapStage.isIndeterminate) { ++ if (mapStage.isIndeterminate || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) { + // It's a little tricky to find all the succeeding stages of `mapStage`, because + // each stage only know its parents not children. Here we traverse the stages from + // the leaf nodes (the result stages of active jobs), and rollback all the stages +@@ -1861,7 +1881,15 @@ private[spark] class DAGScheduler( + + def collectStagesToRollback(stageChain: List[Stage]): Unit = { + if (stagesToRollback.contains(stageChain.head)) { +- stageChain.drop(1).foreach(s => stagesToRollback += s) ++ stageChain.drop(1).foreach(s => { ++ stagesToRollback += s ++ s match { ++ case currentMapStage: ShuffleMapStage => ++ CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId) ++ case _: ResultStage => ++ // do nothing, should abort celeborn skewed read stage ++ } ++ }) + } else { + stageChain.head.parents.foreach { s => + collectStagesToRollback(s :: stageChain) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala new file mode 100644 index 00000000000..3dc60678461 diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch index f8e38615c..6bb8be966 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_3.patch @@ -135,7 +135,7 @@ index 00000000000..5e190c512df + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index bd2823bcac1..d0c88081527 100644 +index bd2823bcac1..e97218b046b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -33,6 +33,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} @@ -146,15 +146,76 @@ index bd2823bcac1..d0c88081527 100644 import org.apache.spark.errors.SparkCoreErrors import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging -@@ -1851,7 +1852,7 @@ private[spark] class DAGScheduler( - failedStage.failedAttemptIds.add(task.stageAttemptId) - val shouldAbortStage = - failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || -- disallowStageRetryForTest -+ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) +@@ -1404,7 +1405,10 @@ private[spark] class DAGScheduler( + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { +- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => ++ case sms: ShuffleMapStage if (stage.isIndeterminate || ++ CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && !sms.isAvailable => ++ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + ++ s" shuffle ${sms.shuffleDep.shuffleId}") + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() + case _ => +@@ -1760,7 +1764,15 @@ private[spark] class DAGScheduler( + // tasks complete, they still count and we can mark the corresponding partitions as + // finished. Here we notify the task scheduler to skip running tasks for the same partition, + // to save resource. +- if (task.stageAttemptId < stage.latestInfo.attemptNumber()) { ++ // CELEBORN-1856, if stage is indeterminate or shuffleMapStage is skewed and read by ++ // Celeborn chunkOffsets, should not call notifyPartitionCompletion, otherwise will ++ // skip running tasks for the same partition because TaskSetManager.dequeueTaskFromList ++ // will skip running task which TaskSetManager.successful(taskIndex) is true. ++ // TODO: Suggest cherry-pick SPARK-45182 and SPARK-45498, ResultStage may has result commit and other issues ++ val isStageIndeterminate = stage.isInstanceOf[ShuffleMapStage] && ++ CelebornShuffleState.isCelebornSkewedShuffle( ++ stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId) ++ if (task.stageAttemptId < stage.latestInfo.attemptNumber() && !isStageIndeterminate) { + taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) + } - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is +@@ -1843,6 +1855,14 @@ private[spark] class DAGScheduler( + val failedStage = stageIdToStage(task.stageId) + val mapStage = shuffleIdToMapStage(shuffleId) + ++ // In Celeborn-1139 we support read skew partition by Celeborn chunkOffsets, ++ // it will make shuffle be indeterminate, so abort the ResultStage directly here. ++ if (failedStage.isInstanceOf[ResultStage] && CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) { ++ val shuffleFailedReason = s"ResultStage:${failedStage.id} fetch failed and the shuffle:$shuffleId " + ++ s"is skewed partition read by Celeborn, so abort it." ++ abortStage(failedStage, shuffleFailedReason, None) ++ } ++ + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + +@@ -1921,7 +1941,7 @@ private[spark] class DAGScheduler( + // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is + // guaranteed to be determinate, so the input data of the reducers will not change + // even if the map tasks are re-tried. +- if (mapStage.isIndeterminate) { ++ if (mapStage.isIndeterminate || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) { + // It's a little tricky to find all the succeeding stages of `mapStage`, because + // each stage only know its parents not children. Here we traverse the stages from + // the leaf nodes (the result stages of active jobs), and rollback all the stages +@@ -1932,7 +1952,15 @@ private[spark] class DAGScheduler( + + def collectStagesToRollback(stageChain: List[Stage]): Unit = { + if (stagesToRollback.contains(stageChain.head)) { +- stageChain.drop(1).foreach(s => stagesToRollback += s) ++ stageChain.drop(1).foreach(s => { ++ stagesToRollback += s ++ s match { ++ case currentMapStage: ShuffleMapStage => ++ CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId) ++ case _: ResultStage => ++ // do nothing, should abort celeborn skewed read stage ++ } ++ }) + } else { + stageChain.head.parents.foreach { s => + collectStagesToRollback(s :: stageChain) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala new file mode 100644 index 00000000000..3dc60678461 diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch index 9aed835fe..9f38d8026 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_4.patch @@ -135,7 +135,7 @@ index 00000000000..5e190c512df + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index 26be8c72bbc..81feaba962c 100644 +index 26be8c72bbc..4323b6d1a75 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} @@ -146,15 +146,76 @@ index 26be8c72bbc..81feaba962c 100644 import org.apache.spark.errors.SparkCoreErrors import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging -@@ -1897,7 +1898,7 @@ private[spark] class DAGScheduler( +@@ -1435,7 +1436,10 @@ private[spark] class DAGScheduler( + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { +- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => ++ case sms: ShuffleMapStage if (stage.isIndeterminate || ++ CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && !sms.isAvailable => ++ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + ++ s" shuffle ${sms.shuffleDep.shuffleId}") + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() + case _ => +@@ -1796,7 +1800,15 @@ private[spark] class DAGScheduler( + // tasks complete, they still count and we can mark the corresponding partitions as + // finished. Here we notify the task scheduler to skip running tasks for the same partition, + // to save resource. +- if (task.stageAttemptId < stage.latestInfo.attemptNumber()) { ++ // CELEBORN-1856, if stage is indeterminate or shuffleMapStage is skewed and read by ++ // Celeborn chunkOffsets, should not call notifyPartitionCompletion, otherwise will ++ // skip running tasks for the same partition because TaskSetManager.dequeueTaskFromList ++ // will skip running task which TaskSetManager.successful(taskIndex) is true. ++ // TODO: Suggest cherry-pick SPARK-45182 and SPARK-45498, ResultStage may has result commit and other issues ++ val isStageIndeterminate = stage.isInstanceOf[ShuffleMapStage] && ++ CelebornShuffleState.isCelebornSkewedShuffle( ++ stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId) ++ if (task.stageAttemptId < stage.latestInfo.attemptNumber() && !isStageIndeterminate) { + taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) + } - val shouldAbortStage = - failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || -- disallowStageRetryForTest -+ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) +@@ -1879,6 +1891,14 @@ private[spark] class DAGScheduler( + val failedStage = stageIdToStage(task.stageId) + val mapStage = shuffleIdToMapStage(shuffleId) - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is ++ // In Celeborn-1139 we support read skew partition by Celeborn chunkOffsets, ++ // it will make shuffle be indeterminate, so abort the ResultStage directly here. ++ if (failedStage.isInstanceOf[ResultStage] && CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) { ++ val shuffleFailedReason = s"ResultStage:${failedStage.id} fetch failed and the shuffle:$shuffleId " + ++ s"is skewed partition read by Celeborn, so abort it." ++ abortStage(failedStage, shuffleFailedReason, None) ++ } ++ + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + +@@ -1977,7 +1997,7 @@ private[spark] class DAGScheduler( + // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is + // guaranteed to be determinate, so the input data of the reducers will not change + // even if the map tasks are re-tried. +- if (mapStage.isIndeterminate) { ++ if (mapStage.isIndeterminate || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) { + // It's a little tricky to find all the succeeding stages of `mapStage`, because + // each stage only know its parents not children. Here we traverse the stages from + // the leaf nodes (the result stages of active jobs), and rollback all the stages +@@ -1988,7 +2008,15 @@ private[spark] class DAGScheduler( + + def collectStagesToRollback(stageChain: List[Stage]): Unit = { + if (stagesToRollback.contains(stageChain.head)) { +- stageChain.drop(1).foreach(s => stagesToRollback += s) ++ stageChain.drop(1).foreach(s => { ++ stagesToRollback += s ++ s match { ++ case currentMapStage: ShuffleMapStage => ++ CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId) ++ case _: ResultStage => ++ // do nothing, should abort celeborn skewed read stage ++ } ++ }) + } else { + stageChain.head.parents.foreach { s => + collectStagesToRollback(s :: stageChain) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala new file mode 100644 index 00000000000..3dc60678461 diff --git a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch index 553bdeae6..71d0f9859 100644 --- a/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch +++ b/assets/spark-patch/Celeborn-Optimize-Skew-Partitions-spark3_5.patch @@ -135,7 +135,7 @@ index 00000000000..5e190c512df + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala -index 89d16e57934..3b9094f3254 100644 +index 89d16e57934..36ce50093c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -34,6 +34,7 @@ import com.google.common.util.concurrent.{Futures, SettableFuture} @@ -146,15 +146,88 @@ index 89d16e57934..3b9094f3254 100644 import org.apache.spark.errors.SparkCoreErrors import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging -@@ -1962,7 +1963,7 @@ private[spark] class DAGScheduler( +@@ -1480,7 +1481,10 @@ private[spark] class DAGScheduler( + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { +- case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => ++ case sms: ShuffleMapStage if (stage.isIndeterminate || ++ CelebornShuffleState.isCelebornSkewedShuffle(sms.shuffleDep.shuffleId)) && !sms.isAvailable => ++ logInfo(s"Unregistering shuffle output for stage ${stage.id}" + ++ s" shuffle ${sms.shuffleDep.shuffleId}") + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() + case _ => +@@ -1854,7 +1858,18 @@ private[spark] class DAGScheduler( + // tasks complete, they still count and we can mark the corresponding partitions as + // finished if the stage is determinate. Here we notify the task scheduler to skip running + // tasks for the same partition to save resource. +- if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()) { ++ // finished. Here we notify the task scheduler to skip running tasks for the same partition, ++ // to save resource. ++ // CELEBORN-1856, if stage is indeterminate or shuffleMapStage is skewed and read by ++ // Celeborn chunkOffsets, should not call notifyPartitionCompletion, otherwise will ++ // skip running tasks for the same partition because TaskSetManager.dequeueTaskFromList ++ // will skip running task which TaskSetManager.successful(taskIndex) is true. ++ // TODO: ResultStage has result commit and other issues ++ val isCelebornShuffleIndeterminate = stage.isInstanceOf[ShuffleMapStage] && ++ CelebornShuffleState.isCelebornSkewedShuffle( ++ stage.asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId) ++ if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber() ++ && !isCelebornShuffleIndeterminate) { + taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) + } - val shouldAbortStage = - failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || -- disallowStageRetryForTest -+ disallowStageRetryForTest || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId) +@@ -1909,7 +1924,7 @@ private[spark] class DAGScheduler( + case smt: ShuffleMapTask => + val shuffleStage = stage.asInstanceOf[ShuffleMapStage] + // Ignore task completion for old attempt of indeterminate stage +- val ignoreIndeterminate = stage.isIndeterminate && ++ val ignoreIndeterminate = (stage.isIndeterminate || isCelebornShuffleIndeterminate) && + task.stageAttemptId < stage.latestInfo.attemptNumber() + if (!ignoreIndeterminate) { + shuffleStage.pendingPartitions -= task.partitionId +@@ -1944,6 +1959,14 @@ private[spark] class DAGScheduler( + val failedStage = stageIdToStage(task.stageId) + val mapStage = shuffleIdToMapStage(shuffleId) - // It is likely that we receive multiple FetchFailed for a single stage (because we have - // multiple tasks running concurrently on different executors). In that case, it is ++ // In Celeborn-1139 we support read skew partition by Celeborn chunkOffsets, ++ // it will make shuffle be indeterminate, so abort the ResultStage directly here. ++ if (failedStage.isInstanceOf[ResultStage] && CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) { ++ val shuffleFailedReason = s"ResultStage:${failedStage.id} fetch failed and the shuffle:$shuffleId " + ++ s"is skewed partition read by Celeborn, so abort it." ++ abortStage(failedStage, shuffleFailedReason, None) ++ } ++ + if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + +@@ -2042,7 +2065,7 @@ private[spark] class DAGScheduler( + // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is + // guaranteed to be determinate, so the input data of the reducers will not change + // even if the map tasks are re-tried. +- if (mapStage.isIndeterminate) { ++ if (mapStage.isIndeterminate || CelebornShuffleState.isCelebornSkewedShuffle(shuffleId)) { + // It's a little tricky to find all the succeeding stages of `mapStage`, because + // each stage only know its parents not children. Here we traverse the stages from + // the leaf nodes (the result stages of active jobs), and rollback all the stages +@@ -2053,7 +2076,15 @@ private[spark] class DAGScheduler( + + def collectStagesToRollback(stageChain: List[Stage]): Unit = { + if (stagesToRollback.contains(stageChain.head)) { +- stageChain.drop(1).foreach(s => stagesToRollback += s) ++ stageChain.drop(1).foreach(s => { ++ stagesToRollback += s ++ s match { ++ case currentMapStage: ShuffleMapStage => ++ CelebornShuffleState.registerCelebornSkewedShuffle(currentMapStage.shuffleDep.shuffleId) ++ case _: ResultStage => ++ // do nothing, should abort celeborn skewed read stage ++ } ++ }) + } else { + stageChain.head.parents.foreach { s => + collectStagesToRollback(s :: stageChain) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CelebornShuffleUtil.scala new file mode 100644 index 00000000000..3dc60678461 diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index df28143c6..234fba1fe 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -150,6 +150,11 @@ public class SparkShuffleManager implements ShuffleManager { lifecycleManager.registerShuffleTrackerCallback( shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId)); + + if (celebornConf.clientAdaptiveOptimizeSkewedPartitionReadEnabled()) { + lifecycleManager.registerCelebornSkewShuffleCheckCallback( + SparkUtils::isCelebornSkewShuffleOrChildShuffle); + } } } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 6c2e5120e..6443c2163 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -462,4 +462,18 @@ public class SparkUtils { sparkContext.addSparkListener(listener); } } + + private static final DynMethods.UnboundMethod isCelebornSkewShuffle_METHOD = + DynMethods.builder("isCelebornSkewedShuffle") + .hiddenImpl("org.apache.spark.celeborn.CelebornShuffleState", Integer.TYPE) + .orNoop() + .build(); + + public static boolean isCelebornSkewShuffleOrChildShuffle(int appShuffleId) { + if (!isCelebornSkewShuffle_METHOD.isNoop()) { + return isCelebornSkewShuffle_METHOD.asStatic().invoke(appShuffleId); + } else { + return false; + } + } } diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 285da2296..f706eeb90 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -909,7 +909,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends // For barrier stages, all tasks are re-executed when it is re-run : similar to indeterminate stage. // So if a barrier stage is getting reexecuted, previous stage/attempt needs to // be cleaned up as it is entirely unusuable - if (determinate && !isBarrierStage) + if (determinate && !isBarrierStage && !isCelebornSkewShuffleOrChildShuffle( + appShuffleId)) shuffleIds.values.toSeq.reverse.find(e => e._2 == true) else None @@ -1057,6 +1058,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } } + private def isCelebornSkewShuffleOrChildShuffle(appShuffleId: Int): Boolean = { + celebornSkewShuffleCheckCallback match { + case Some(skewShuffleCallback) => + skewShuffleCallback.apply(appShuffleId) + case None => false + } + } + private def handleStageEnd(shuffleId: Int): Unit = { // check whether shuffle has registered if (!registeredShuffle.contains(shuffleId)) { @@ -1843,6 +1852,13 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends registerShuffleResponseRpcCache.invalidate(shuffleId) } + @volatile private var celebornSkewShuffleCheckCallback + : Option[function.Function[Integer, Boolean]] = None + def registerCelebornSkewShuffleCheckCallback(callback: function.Function[Integer, Boolean]) + : Unit = { + celebornSkewShuffleCheckCallback = Some(callback) + } + // Initialize at the end of LifecycleManager construction. initialize()