diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 5f84d6921..06035575b 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -95,6 +95,9 @@ public class ShuffleClientImpl extends ShuffleClient { // key: shuffleId-mapId-attemptId protected final Map pushStates = JavaUtils.newConcurrentHashMap(); + private final boolean shuffleClientPushBlacklistEnabled; + private final Set blacklist = ConcurrentHashMap.newKeySet(); + private final ExecutorService pushDataRetryPool; private final ExecutorService partitionSplitPool; @@ -148,6 +151,7 @@ public class ShuffleClientImpl extends ShuffleClient { maxReviveTimes = conf.pushMaxReviveTimes(); testRetryRevive = conf.testRetryRevive(); pushBufferMaxSize = conf.pushBufferMaxSize(); + shuffleClientPushBlacklistEnabled = conf.shuffleClientPushBlacklistEnabled(); if (conf.pushReplicateEnabled()) { pushDataTimeout = conf.pushDataTimeoutMs() * 2; } else { @@ -175,6 +179,21 @@ public class ShuffleClientImpl extends ShuffleClient { "celeborn-shuffle-split", pushSplitPartitionThreads, 60); } + private boolean checkPushBlacklisted( + PartitionLocation location, RpcResponseCallback wrappedCallback) { + // If shuffleClientBlacklistEnabled = false, blacklist should be empty. + if (blacklist.contains(location.hostAndPushPort())) { + wrappedCallback.onFailure(new CelebornIOException(StatusCode.PUSH_DATA_MASTER_BLACKLISTED)); + return true; + } else if (location.getPeer() != null + && blacklist.contains(location.getPeer().hostAndPushPort())) { + wrappedCallback.onFailure(new CelebornIOException(StatusCode.PUSH_DATA_SLAVE_BLACKLISTED)); + return true; + } else { + return false; + } + } + private void submitRetryPushData( String applicationId, int shuffleId, @@ -213,18 +232,22 @@ public class ShuffleClientImpl extends ShuffleClient { batchId, newLoc); try { - if (!testRetryRevive || remainReviveTimes < 1) { - TransportClient client = - dataClientFactory.createClient(newLoc.getHost(), newLoc.getPushPort(), partitionId); - NettyManagedBuffer newBuffer = new NettyManagedBuffer(Unpooled.wrappedBuffer(body)); - String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); + if (!checkPushBlacklisted(newLoc, wrappedCallback)) { + if (!testRetryRevive || remainReviveTimes < 1) { + TransportClient client = + dataClientFactory.createClient(newLoc.getHost(), newLoc.getPushPort(), partitionId); + NettyManagedBuffer newBuffer = new NettyManagedBuffer(Unpooled.wrappedBuffer(body)); + String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); - PushData newPushData = - new PushData(MASTER_MODE, shuffleKey, newLoc.getUniqueId(), newBuffer); - client.pushData(newPushData, pushDataTimeout, wrappedCallback); - } else { - throw new RuntimeException( - "Mock push data submit retry failed. remainReviveTimes = " + remainReviveTimes + "."); + PushData newPushData = + new PushData(MASTER_MODE, shuffleKey, newLoc.getUniqueId(), newBuffer); + client.pushData(newPushData, pushDataTimeout, wrappedCallback); + } else { + throw new RuntimeException( + "Mock push data submit retry failed. remainReviveTimes = " + + remainReviveTimes + + "."); + } } } catch (Exception e) { logger.error( @@ -505,6 +528,23 @@ public class ShuffleClientImpl extends ShuffleClient { int epoch, PartitionLocation oldLocation, StatusCode cause) { + // Add ShuffleClient side blacklist + if (shuffleClientPushBlacklistEnabled) { + if (cause == StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_MASTER) { + blacklist.add(oldLocation.hostAndPushPort()); + } else if (cause == StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_MASTER) { + blacklist.add(oldLocation.hostAndPushPort()); + } else if (cause == StatusCode.PUSH_DATA_TIMEOUT_MASTER) { + blacklist.add(oldLocation.hostAndPushPort()); + } else if (cause == StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_SLAVE) { + blacklist.add(oldLocation.getPeer().hostAndPushPort()); + } else if (cause == StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_SLAVE) { + blacklist.add(oldLocation.getPeer().hostAndPushPort()); + } else if (cause == StatusCode.PUSH_DATA_TIMEOUT_SLAVE) { + blacklist.add(oldLocation.getPeer().hostAndPushPort()); + } + } + ConcurrentHashMap map = reducePartitionMap.get(shuffleId); if (waitRevivedLocation(map, partitionId, epoch)) { logger.debug( @@ -815,7 +855,10 @@ public class ShuffleClientImpl extends ShuffleClient { e); // async retry push data if (!mapperEnded(shuffleId, mapId, attemptId)) { - remainReviveTimes = remainReviveTimes - 1; + // For blacklisted partition location, Celeborn should not use retry quota. + if (!pushStatusIsBlacklisted(cause)) { + remainReviveTimes = remainReviveTimes - 1; + } pushDataRetryPool.submit( () -> submitRetryPushData( @@ -847,15 +890,17 @@ public class ShuffleClientImpl extends ShuffleClient { // do push data try { - if (!testRetryRevive) { - TransportClient client = - dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), partitionId); - client.pushData(pushData, pushDataTimeout, wrappedCallback); - } else { - wrappedCallback.onFailure( - new CelebornIOException( - StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE, - new RuntimeException("Mock push data first time failed."))); + if (!checkPushBlacklisted(loc, wrappedCallback)) { + if (!testRetryRevive) { + TransportClient client = + dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), partitionId); + client.pushData(pushData, pushDataTimeout, wrappedCallback); + } else { + wrappedCallback.onFailure( + new CelebornIOException( + StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE, + new RuntimeException("Mock push data first time failed."))); + } } } catch (Exception e) { logger.error( @@ -1196,6 +1241,12 @@ public class ShuffleClientImpl extends ShuffleClient { remainReviveTimes, e); if (!mapperEnded(shuffleId, mapId, attemptId)) { + int tmpRemainReviveTimes = remainReviveTimes; + // For blacklisted partition location, Celeborn should not use retry quota. + if (!pushStatusIsBlacklisted(cause)) { + tmpRemainReviveTimes = tmpRemainReviveTimes - 1; + } + int finalRemainReviveTimes = tmpRemainReviveTimes; pushDataRetryPool.submit( () -> submitRetryPushMergedData( @@ -1207,21 +1258,23 @@ public class ShuffleClientImpl extends ShuffleClient { batches, cause, groupedBatchId, - remainReviveTimes - 1)); + finalRemainReviveTimes)); } } }; // do push merged data try { - if (!testRetryRevive || remainReviveTimes < 1) { - TransportClient client = dataClientFactory.createClient(host, port); - client.pushMergedData(mergedData, pushDataTimeout, wrappedCallback); - } else { - wrappedCallback.onFailure( - new CelebornIOException( - StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE, - new RuntimeException("Mock push merge data failed."))); + if (!checkPushBlacklisted(batches.get(0).loc, wrappedCallback)) { + if (!testRetryRevive || remainReviveTimes < 1) { + TransportClient client = dataClientFactory.createClient(host, port); + client.pushMergedData(mergedData, pushDataTimeout, wrappedCallback); + } else { + wrappedCallback.onFailure( + new CelebornIOException( + StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE, + new RuntimeException("Mock push merge data failed."))); + } } } catch (Exception e) { logger.error( @@ -1460,6 +1513,11 @@ public class ShuffleClientImpl extends ShuffleClient { && mapperEndMap.get(shuffleId).contains(Utils.makeMapKey(shuffleId, mapId, attemptId)); } + private boolean pushStatusIsBlacklisted(StatusCode cause) { + return cause == StatusCode.PUSH_DATA_MASTER_BLACKLISTED + || cause == StatusCode.PUSH_DATA_SLAVE_BLACKLISTED; + } + private StatusCode getPushDataFailCause(String message) { logger.debug("Push data failed cause message: " + message); StatusCode cause; @@ -1481,6 +1539,10 @@ public class ShuffleClientImpl extends ShuffleClient { cause = StatusCode.PUSH_DATA_TIMEOUT_SLAVE; } else if (message.startsWith(StatusCode.REPLICATE_DATA_FAILED.name())) { cause = StatusCode.REPLICATE_DATA_FAILED; + } else if (message.startsWith(StatusCode.PUSH_DATA_MASTER_BLACKLISTED.name())) { + cause = StatusCode.PUSH_DATA_MASTER_BLACKLISTED; + } else if (message.startsWith(StatusCode.PUSH_DATA_SLAVE_BLACKLISTED.name())) { + cause = StatusCode.PUSH_DATA_SLAVE_BLACKLISTED; } else if (connectFail(message)) { // Throw when push to master worker connection causeException. cause = StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_MASTER; diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java index 407db99f3..670d1efea 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java @@ -73,7 +73,9 @@ public enum StatusCode { PUSH_DATA_CONNECTION_EXCEPTION_MASTER(40), PUSH_DATA_CONNECTION_EXCEPTION_SLAVE(41), PUSH_DATA_TIMEOUT_MASTER(42), - PUSH_DATA_TIMEOUT_SLAVE(43); + PUSH_DATA_TIMEOUT_SLAVE(43), + PUSH_DATA_MASTER_BLACKLISTED(44), + PUSH_DATA_SLAVE_BLACKLISTED(45); private final byte value; diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 9aae0fcbf..3d35a10ae 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -541,6 +541,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def shuffleRangeReadFilterEnabled: Boolean = get(SHUFFLE_RANGE_READ_FILTER_ENABLED) def shufflePartitionType: PartitionType = PartitionType.valueOf(get(SHUFFLE_PARTITION_TYPE)) def requestCommitFilesMaxRetries: Int = get(COMMIT_FILE_REQUEST_MAX_RETRY) + def shuffleClientPushBlacklistEnabled: Boolean = get(SHUFFLE_CLIENT_PUSH_BLACKLIST_ENABLED) // ////////////////////////////////////////////////////// // Shuffle Compression // @@ -1504,6 +1505,14 @@ object CelebornConf extends Logging { .booleanConf .createWithDefault(false) + val SHUFFLE_CLIENT_PUSH_BLACKLIST_ENABLED: ConfigEntry[Boolean] = + buildConf(" celeborn.client.push.blacklist.enabled") + .categories("client") + .doc("Whether to enable shuffle client-side push blacklist of workers.") + .version("0.3.0") + .booleanConf + .createWithDefault(false) + val MASTER_HOST: ConfigEntry[String] = buildConf("celeborn.master.host") .categories("master") diff --git a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala index 95401fab5..77034008d 100644 --- a/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala +++ b/common/src/main/scala/org/apache/celeborn/common/util/Utils.scala @@ -909,6 +909,10 @@ object Utils extends Logging { StatusCode.PUSH_DATA_TIMEOUT_MASTER case 43 => StatusCode.PUSH_DATA_TIMEOUT_SLAVE + case 44 => + StatusCode.PUSH_DATA_MASTER_BLACKLISTED + case 45 => + StatusCode.PUSH_DATA_SLAVE_BLACKLISTED case _ => null } diff --git a/docs/configuration/client.md b/docs/configuration/client.md index b3bc3a32c..9307ee4e0 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -19,6 +19,7 @@ license: | | Key | Default | Description | Since | | --- | ------- | ----------- | ----- | +| celeborn.client.push.blacklist.enabled | false | Whether to enable shuffle client-side push blacklist of workers. | 0.3.0 | | celeborn.application.heartbeatInterval | 10s | Interval for client to send heartbeat message to master. | 0.2.0 | | celeborn.client.blacklistSlave.enabled | true | When true, Celeborn will add partition's peer worker into blacklist when push data to slave failed. | 0.3.0 | | celeborn.client.closeIdleConnections | true | Whether client will close idle connections. | 0.3.0 |