From 045411ac34c749d8ca2d7bed8d0b9c3554f71287 Mon Sep 17 00:00:00 2001 From: lijianfu03 Date: Tue, 13 May 2025 16:14:03 +0800 Subject: [PATCH] [CELEBORN-1855] LifecycleManager return appshuffleId for non barrier stage when fetch fail has been reported ### What changes were proposed in this pull request? for non barrier shuffle read stage, LifecycleManager#handleGetShuffleIdForApp always return appshuffleId whether fetch status is true or not. ### Why are the changes needed? As described in [jira](https://issues.apache.org/jira/browse/CELEBORN-1855), If LifecycleManager only returns appshuffleId whose fetch status is success, the task will fail directly to "there is no finished map stage associated with", but previous fetch fail event reported may not be fatal.So just give it a chance ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Closes #3090 from buska88/celeborn-1855. Authored-by: lijianfu03 Signed-off-by: Shuang --- .../spark/shuffle/celeborn/SparkUtils.java | 17 +++++++++----- .../celeborn/CelebornShuffleReader.scala | 22 ++++++++++++++++--- .../spark/shuffle/celeborn/SparkUtils.java | 17 +++++++++----- .../celeborn/CelebornShuffleReader.scala | 22 +++++++++++++++++-- .../celeborn/client/DummyShuffleClient.java | 6 +++-- .../apache/celeborn/client/ShuffleClient.java | 4 +++- .../celeborn/client/ShuffleClientImpl.java | 7 +++--- .../celeborn/client/LifecycleManager.scala | 18 ++++++++++----- common/src/main/proto/TransportMessages.proto | 1 + .../celeborn/tests/spark/SparkTestBase.scala | 3 ++- 10 files changed, 89 insertions(+), 28 deletions(-) diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index 44dd1ea28..b52758581 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -63,6 +63,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.exception.CelebornRuntimeException; import org.apache.celeborn.common.network.protocol.TransportMessage; import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse; import org.apache.celeborn.common.util.JavaUtils; @@ -161,11 +162,17 @@ public class SparkUtils { Boolean isWriter) { if (handle.throwsFetchFailure()) { String appShuffleIdentifier = getAppShuffleIdentifier(handle.shuffleId(), context); - return client.getShuffleId( - handle.shuffleId(), - appShuffleIdentifier, - isWriter, - context instanceof BarrierTaskContext); + Tuple2 res = + client.getShuffleId( + handle.shuffleId(), + appShuffleIdentifier, + isWriter, + context instanceof BarrierTaskContext); + if (!res._2) { + throw new CelebornRuntimeException(String.format("Get invalid shuffle id %s", res._1)); + } else { + return res._1; + } } else { return handle.shuffleId(); } diff --git a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 2269aaf68..df63a94b1 100644 --- a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -33,7 +33,7 @@ import org.apache.celeborn.client.ShuffleClient import org.apache.celeborn.client.read.CelebornInputStream import org.apache.celeborn.client.read.MetricsCallback import org.apache.celeborn.common.CelebornConf -import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException} +import org.apache.celeborn.common.exception.{CelebornIOException, CelebornRuntimeException, PartitionUnRetryAbleException} import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils} @@ -63,8 +63,24 @@ class CelebornShuffleReader[K, C]( override def read(): Iterator[Product2[K, C]] = { val serializerInstance = dep.serializer.newInstance() - - val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle, context, false) + val shuffleId = + try { + SparkUtils.celebornShuffleId(shuffleClient, handle, context, false) + } catch { + case e: CelebornRuntimeException => + logError(s"Failed to get shuffleId for appShuffleId ${handle.shuffleId}", e) + if (handle.throwsFetchFailure) { + throw new FetchFailedException( + null, + handle.shuffleId, + -1, + startPartition, + SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId, + e) + } else { + throw e + } + } shuffleIdTracker.track(handle.shuffleId, shuffleId) logDebug( s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId} attemptNum ${context.stageAttemptNumber()}") diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java index edaeb28bd..b2e64565e 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java @@ -66,6 +66,7 @@ import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.exception.CelebornRuntimeException; import org.apache.celeborn.common.network.protocol.TransportMessage; import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse; import org.apache.celeborn.common.util.JavaUtils; @@ -138,11 +139,17 @@ public class SparkUtils { Boolean isWriter) { if (handle.throwsFetchFailure()) { String appShuffleIdentifier = getAppShuffleIdentifier(handle.shuffleId(), context); - return client.getShuffleId( - handle.shuffleId(), - appShuffleIdentifier, - isWriter, - context instanceof BarrierTaskContext); + Tuple2 res = + client.getShuffleId( + handle.shuffleId(), + appShuffleIdentifier, + isWriter, + context instanceof BarrierTaskContext); + if (!res._2) { + throw new CelebornRuntimeException(String.format("Get invalid shuffle id %s", res._1)); + } else { + return res._1; + } } else { return handle.shuffleId(); } diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala index 3e296b310..3d9f14f23 100644 --- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala @@ -40,7 +40,7 @@ import org.apache.celeborn.client.{ClientUtils, ShuffleClient} import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback} import org.apache.celeborn.common.CelebornConf -import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException} +import org.apache.celeborn.common.exception.{CelebornIOException, CelebornRuntimeException, PartitionUnRetryAbleException} import org.apache.celeborn.common.network.client.TransportClient import org.apache.celeborn.common.network.protocol.TransportMessage import org.apache.celeborn.common.protocol._ @@ -79,7 +79,25 @@ class CelebornShuffleReader[K, C]( val startTime = System.currentTimeMillis() val serializerInstance = newSerializerInstance(dep) - val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle, context, false) + val shuffleId = + try { + SparkUtils.celebornShuffleId(shuffleClient, handle, context, false) + } catch { + case e: CelebornRuntimeException => + logError(s"Failed to get shuffleId for appShuffleId ${handle.shuffleId}", e) + if (throwsFetchFailure) { + throw new FetchFailedException( + null, + handle.shuffleId, + -1, + -1, + startPartition, + SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + handle.shuffleId, + e) + } else { + throw e + } + } shuffleIdTracker.track(handle.shuffleId, shuffleId) logDebug( s"get shuffleId $shuffleId for appShuffleId ${handle.shuffleId} attemptNum ${context.stageAttemptNumber()}") diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index 6b3673b18..dd1a032c8 100644 --- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -32,6 +32,8 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; +import scala.Tuple2; + import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -182,9 +184,9 @@ public class DummyShuffleClient extends ShuffleClient { } @Override - public int getShuffleId( + public Tuple2 getShuffleId( int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean isBarrierStage) { - return appShuffleId; + return Tuple2.apply(appShuffleId, true); } @Override diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index dde2b36c4..bf0192e4a 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -26,6 +26,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.LongAdder; import java.util.function.BiFunction; +import scala.Tuple2; + import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.fs.FileSystem; import org.slf4j.Logger; @@ -285,7 +287,7 @@ public abstract class ShuffleClient { public abstract PushState getPushState(String mapKey); - public abstract int getShuffleId( + public abstract Tuple2 getShuffleId( int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean isBarrierStage); /** diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index c6fc7f6bd..81329e8d1 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -103,7 +103,7 @@ public class ShuffleClientImpl extends ShuffleClient { protected byte[] extension; // key: appShuffleIdentifier, value: shuffleId - protected Map shuffleIdCache = JavaUtils.newConcurrentHashMap(); + protected Map> shuffleIdCache = JavaUtils.newConcurrentHashMap(); // key: shuffleId, value: (partitionId, PartitionLocation) final Map> reducePartitionMap = @@ -626,7 +626,7 @@ public class ShuffleClientImpl extends ShuffleClient { } @Override - public int getShuffleId( + public Tuple2 getShuffleId( int appShuffleId, String appShuffleIdentifier, boolean isWriter, boolean isBarrierStage) { return shuffleIdCache.computeIfAbsent( appShuffleIdentifier, @@ -643,7 +643,8 @@ public class ShuffleClientImpl extends ShuffleClient { pbGetShuffleId, conf.clientRpcRegisterShuffleAskTimeout(), ClassTag$.MODULE$.apply(PbGetShuffleIdResponse.class)); - return pbGetShuffleIdResponse.getShuffleId(); + return Tuple2.apply( + pbGetShuffleIdResponse.getShuffleId(), pbGetShuffleIdResponse.getSuccess()); }); } diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 923d2f838..20e1099d2 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -892,7 +892,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends if (shuffleIds == null) { logWarning(s"unknown appShuffleId $appShuffleId, maybe no shuffle data for this shuffle") val pbGetShuffleIdResponse = - PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).build() + PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).setSuccess( + true).build() context.reply(pbGetShuffleIdResponse) return } @@ -906,7 +907,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleIds.get(appShuffleIdentifier) match { case Some((shuffleId, _)) => val pbGetShuffleIdResponse = - PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build() + PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build() context.reply(pbGetShuffleIdResponse) case None => Option(appShuffleDeterminateMap.get(appShuffleId)).map { determinate => @@ -940,7 +941,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends newShuffleId } val pbGetShuffleIdResponse = - PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build() + PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build() context.reply(pbGetShuffleIdResponse) }.orElse( throw new UnsupportedOperationException( @@ -953,12 +954,17 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends val pbGetShuffleIdResponse = { logDebug( s"get shuffleId $shuffleId for appShuffleId $appShuffleId appShuffleIdentifier $appShuffleIdentifier isWriter $isWriter") - PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).build() + PbGetShuffleIdResponse.newBuilder().setShuffleId(shuffleId).setSuccess(true).build() } context.reply(pbGetShuffleIdResponse) case None => - throw new UnsupportedOperationException( - s"unexpected! there is no finished map stage associated with appShuffleId $appShuffleId") + val pbGetShuffleIdResponse = { + logInfo( + s"there is no finished map stage associated with appShuffleId $appShuffleId") + PbGetShuffleIdResponse.newBuilder().setShuffleId(UNKNOWN_APP_SHUFFLE_ID).setSuccess( + false).build() + } + context.reply(pbGetShuffleIdResponse) } } } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index acf355756..7b0d0bec2 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -403,6 +403,7 @@ message PbGetShuffleId { message PbGetShuffleIdResponse { int32 shuffleId = 1; + bool success = 2; } message PbReportShuffleFetchFailure { diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala index 055b763e5..e29b21a0c 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala @@ -139,7 +139,8 @@ trait SparkTestBase extends AnyFunSuite conf, h.userIdentifier, h.extension) - val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) + val celebornShuffleId = + SparkUtils.celebornShuffleId(shuffleClient, h, context, false) val allFiles = workerDirs.map(dir => { new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") })