diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index bf61500bd..32c0da758 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -25,7 +25,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BooleanSupplier; import scala.reflect.ClassTag$; @@ -84,11 +83,8 @@ public class ShuffleClientImpl extends ShuffleClient { private final int registerShuffleMaxRetries; private final long registerShuffleRetryWaitMs; - private int maxInFlight; private int maxReviveTimes; private boolean testRetryRevive; - private final AtomicInteger currentMaxReqsInFlight; - private int congestionAvoidanceFlag = 0; private final int pushBufferMaxSize; private final long pushDataTimeout; @@ -143,16 +139,8 @@ public class ShuffleClientImpl extends ShuffleClient { this.userIdentifier = userIdentifier; registerShuffleMaxRetries = conf.registerShuffleMaxRetry(); registerShuffleRetryWaitMs = conf.registerShuffleRetryWaitMs(); - maxInFlight = conf.pushMaxReqsInFlight(); maxReviveTimes = conf.pushMaxReviveTimes(); testRetryRevive = conf.testRetryRevive(); - - if (conf.pushDataSlowStart()) { - currentMaxReqsInFlight = new AtomicInteger(1); - } else { - currentMaxReqsInFlight = new AtomicInteger(maxInFlight); - } - pushBufferMaxSize = conf.pushBufferMaxSize(); if (conf.pushReplicateEnabled()) { pushDataTimeout = conf.pushDataTimeoutMs() * 2; @@ -423,9 +411,9 @@ public class ShuffleClientImpl extends ShuffleClient { return null; } - private void limitMaxInFlight( - String mapKey, PushState pushState, int limit, String hostAndPushPort) throws IOException { - boolean reachLimit = 
pushState.limitMaxInFlight(hostAndPushPort, limit); + private void limitMaxInFlight(String mapKey, PushState pushState, String hostAndPushPort) + throws IOException { + boolean reachLimit = pushState.limitMaxInFlight(hostAndPushPort); if (reachLimit) { throw new IOException("wait timeout for task " + mapKey, pushState.exception.get()); @@ -633,7 +621,7 @@ public class ShuffleClientImpl extends ShuffleClient { partitionId, nextBatchId); // check limit - limitMaxInFlight(mapKey, pushState, currentMaxReqsInFlight.get(), loc.hostAndPushPort()); + limitMaxInFlight(mapKey, pushState, loc.hostAndPushPort()); // add inFlight requests pushState.addBatch(nextBatchId, loc.hostAndPushPort()); @@ -694,7 +682,7 @@ public class ShuffleClientImpl extends ShuffleClient { attemptId, nextBatchId); splitPartition(shuffleId, partitionId, applicationId, loc); - slowStart(); + pushState.onSuccess(nextBatchId, loc.hostAndPushPort()); callback.onSuccess(response); } else if (reason == StatusCode.HARD_SPLIT.getValue()) { logger.debug( @@ -722,7 +710,7 @@ public class ShuffleClientImpl extends ShuffleClient { mapId, attemptId, nextBatchId); - congestionControl(); + pushState.onCongestControl(nextBatchId, loc.hostAndPushPort()); callback.onSuccess(response); } else if (reason == StatusCode.PUSH_DATA_SUCCESS_SLAVE_CONGESTED.getValue()) { logger.debug( @@ -730,15 +718,15 @@ public class ShuffleClientImpl extends ShuffleClient { mapId, attemptId, nextBatchId); - congestionControl(); + pushState.onCongestControl(nextBatchId, loc.hostAndPushPort()); callback.onSuccess(response); } else { response.rewind(); - slowStart(); + pushState.onSuccess(nextBatchId, loc.hostAndPushPort()); callback.onSuccess(response); } } else { - slowStart(); + pushState.onSuccess(nextBatchId, loc.hostAndPushPort()); callback.onSuccess(response); } } @@ -831,7 +819,7 @@ public class ShuffleClientImpl extends ShuffleClient { String addressPair = genAddressPair(loc); boolean shouldPush = pushState.addBatchData(addressPair, 
loc, nextBatchId, body); if (shouldPush) { - limitMaxInFlight(mapKey, pushState, maxInFlight, loc.hostAndPushPort()); + limitMaxInFlight(mapKey, pushState, loc.hostAndPushPort()); DataBatches dataBatches = pushState.takeDataBatches(addressPair); doPushMergedData( addressPair.split("-")[0], @@ -951,7 +939,7 @@ public class ShuffleClientImpl extends ShuffleClient { while (!batchesArr.isEmpty()) { Map.Entry entry = batchesArr.get(RND.nextInt(batchesArr.size())); String[] tokens = entry.getKey().split("-"); - limitMaxInFlight(mapKey, pushState, currentMaxReqsInFlight.get(), tokens[0]); + limitMaxInFlight(mapKey, pushState, tokens[0]); ArrayList batches = entry.getValue().requireBatches(pushBufferMaxSize); if (entry.getValue().getTotalSize() == 0) { batchesArr.remove(entry); @@ -1079,7 +1067,7 @@ public class ShuffleClientImpl extends ShuffleClient { mapId, attemptId, Arrays.toString(batchIds)); - congestionControl(); + pushState.onCongestControl(groupedBatchId, hostPort); callback.onSuccess(response); } else if (reason == StatusCode.PUSH_DATA_SUCCESS_SLAVE_CONGESTED.getValue()) { logger.debug( @@ -1087,17 +1075,17 @@ public class ShuffleClientImpl extends ShuffleClient { mapId, attemptId, Arrays.toString(batchIds)); - congestionControl(); + pushState.onCongestControl(groupedBatchId, hostPort); callback.onSuccess(response); } else { // Should not happen in current architecture. response.rewind(); logger.error("Push merged data should not receive this response"); - slowStart(); + pushState.onSuccess(groupedBatchId, hostPort); callback.onSuccess(response); } } else { - slowStart(); + pushState.onSuccess(groupedBatchId, hostPort); callback.onSuccess(response); } } @@ -1373,47 +1361,6 @@ public class ShuffleClientImpl extends ShuffleClient { driverRssMetaService = endpointRef; } - /** - * If `pushDataSlowStart` is enabled, will increase `currentMaxReqsInFlight` gradually to meet the - * max push speed. - * - *

1. slow start period: every RTT period, `currentMaxReqsInFlight` is multiplied. - * - * <p>
2. congestion avoidance: every RTT period, `currentMaxReqsInFlight` plus 1. - * - * <p>
Note that here we define one RTT period: one batch(currentMaxReqsInFlight) of push data - * requests. - */ - private void slowStart() { - if (conf.pushDataSlowStart()) { - synchronized (currentMaxReqsInFlight) { - if (currentMaxReqsInFlight.get() > maxInFlight) { - // Congestion avoidance - congestionAvoidanceFlag++; - if (congestionAvoidanceFlag >= currentMaxReqsInFlight.get()) { - currentMaxReqsInFlight.incrementAndGet(); - congestionAvoidanceFlag = 0; - } - } else { - // Slow start - currentMaxReqsInFlight.incrementAndGet(); - } - } - } - } - - private void congestionControl() { - synchronized (currentMaxReqsInFlight) { - if (currentMaxReqsInFlight.get() <= 1) { - currentMaxReqsInFlight.set(1); - } else { - currentMaxReqsInFlight.updateAndGet(pre -> pre / 2); - } - maxInFlight = currentMaxReqsInFlight.get(); - congestionAvoidanceFlag = 0; - } - } - private boolean mapperEnded(int shuffleId, int mapId, int attemptId) { return mapperEndMap.containsKey(shuffleId) && mapperEndMap.get(shuffleId).contains(Utils.makeMapKey(shuffleId, mapId, attemptId)); @@ -1503,7 +1450,7 @@ public class ShuffleClientImpl extends ShuffleClient { partitionId, nextBatchId); // check limit - limitMaxInFlight(mapKey, pushState, maxInFlight, location.hostAndPushPort()); + limitMaxInFlight(mapKey, pushState, location.hostAndPushPort()); // add inFlight requests pushState.addBatch(nextBatchId, location.hostAndPushPort()); diff --git a/common/src/main/java/org/apache/celeborn/common/write/InFlightRequestTracker.java b/common/src/main/java/org/apache/celeborn/common/write/InFlightRequestTracker.java index 26ce3c41a..336bba9c5 100644 --- a/common/src/main/java/org/apache/celeborn/common/write/InFlightRequestTracker.java +++ b/common/src/main/java/org/apache/celeborn/common/write/InFlightRequestTracker.java @@ -40,6 +40,7 @@ public class InFlightRequestTracker { private final long waitInflightTimeoutMs; private final long delta; private final PushState pushState; + private final PushStrategy 
pushStrategy; private final AtomicInteger batchId = new AtomicInteger(); private final ConcurrentHashMap> @@ -49,6 +50,7 @@ public class InFlightRequestTracker { this.waitInflightTimeoutMs = conf.pushLimitInFlightTimeoutMs(); this.delta = conf.pushLimitInFlightSleepDeltaMs(); this.pushState = pushState; + this.pushStrategy = PushStrategy.getStrategy(conf); } public void addBatch(int batchId, String hostAndPushPort) { @@ -57,6 +59,14 @@ public class InFlightRequestTracker { batchIdSetPerPair.computeIfAbsent(batchId, id -> new BatchInfo()); } + public void onSuccess(int batchId, String hostAndPushPort) { + pushStrategy.onSuccess(hostAndPushPort); + } + + public void onCongestControl(int batchId, String hostAndPushPort) { + pushStrategy.onCongestControl(hostAndPushPort); + } + public void removeBatch(int batchId, String hostAndPushPort) { ConcurrentHashMap batchIdMap = inflightBatchesPerAddress.get(hostAndPushPort); @@ -74,16 +84,19 @@ public class InFlightRequestTracker { hostAndPort, pair -> new ConcurrentHashMap<>()); } - public boolean limitMaxInFlight(String hostAndPushPort, int maxInFlight) throws IOException { + public boolean limitMaxInFlight(String hostAndPushPort) throws IOException { if (pushState.exception.get() != null) { throw pushState.exception.get(); } + pushStrategy.limitPushSpeed(pushState, hostAndPushPort); + int currentMaxReqsInFlight = pushStrategy.getCurrentMaxReqsInFlight(hostAndPushPort); + ConcurrentHashMap batchIdMap = getBatchIdSetByAddressPair(hostAndPushPort); long times = waitInflightTimeoutMs / delta; try { while (times > 0) { - if (batchIdMap.size() <= maxInFlight) { + if (batchIdMap.size() <= currentMaxReqsInFlight) { break; } if (pushState.exception.get() != null) { @@ -102,11 +115,11 @@ public class InFlightRequestTracker { "After waiting for {} ms, " + "there are still {} batches in flight " + "for hostAndPushPort {}, " - + "which exceeds the limit {}.", + + "which exceeds the current limit {}.", waitInflightTimeoutMs, 
batchIdMap.size(), hostAndPushPort, - maxInFlight); + currentMaxReqsInFlight); } if (pushState.exception.get() != null) { @@ -224,6 +237,7 @@ public class InFlightRequestTracker { })); inflightBatchesPerAddress.clear(); } + pushStrategy.clear(); } static class BatchInfo { diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushState.java b/common/src/main/java/org/apache/celeborn/common/write/PushState.java index 5397e0655..dd342f2cf 100644 --- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java +++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java @@ -101,12 +101,20 @@ public class PushState { inFlightRequestTracker.addBatch(batchId, hostAndPushPort); } + public void onSuccess(int batchId, String hostAndPushPort) { + inFlightRequestTracker.onSuccess(batchId, hostAndPushPort); + } + + public void onCongestControl(int batchId, String hostAndPushPort) { + inFlightRequestTracker.onCongestControl(batchId, hostAndPushPort); + } + public void removeBatch(int batchId, String hostAndPushPort) { inFlightRequestTracker.removeBatch(batchId, hostAndPushPort); } - public boolean limitMaxInFlight(String hostAndPushPort, int maxInFlight) throws IOException { - return inFlightRequestTracker.limitMaxInFlight(hostAndPushPort, maxInFlight); + public boolean limitMaxInFlight(String hostAndPushPort) throws IOException { + return inFlightRequestTracker.limitMaxInFlight(hostAndPushPort); } public boolean limitZeroInFlight() throws IOException { diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushStrategy.java b/common/src/main/java/org/apache/celeborn/common/write/PushStrategy.java new file mode 100644 index 000000000..e8e13f1bb --- /dev/null +++ b/common/src/main/java/org/apache/celeborn/common/write/PushStrategy.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.write; + +import java.io.IOException; + +import org.apache.celeborn.common.CelebornConf; + +public abstract class PushStrategy { + + protected final CelebornConf conf; + + public static PushStrategy getStrategy(CelebornConf conf) { + String strategyName = conf.pushLimitStrategy(); + switch (strategyName) { + case "SIMPLE": + return new SimplePushStrategy(conf); + case "SLOWSTART": + return new SlowStartPushStrategy(conf); + default: + throw new IllegalArgumentException("The strategy " + strategyName + " is not supported!"); + } + } + + public PushStrategy(CelebornConf conf) { + this.conf = conf; + } + + /** Handle the response is successful. */ + public abstract void onSuccess(String hostAndPushPort); + + /** Handle the response is congested controlled. */ + public abstract void onCongestControl(String hostAndPushPort); + + public abstract void clear(); + + /** Control the push speed to meet the requirement. 
*/ + public abstract void limitPushSpeed(PushState pushState, String hostAndPushPort) + throws IOException; + + public abstract int getCurrentMaxReqsInFlight(String hostAndPushPort); +} diff --git a/common/src/main/java/org/apache/celeborn/common/write/SimplePushStrategy.java b/common/src/main/java/org/apache/celeborn/common/write/SimplePushStrategy.java new file mode 100644 index 000000000..c659fff04 --- /dev/null +++ b/common/src/main/java/org/apache/celeborn/common/write/SimplePushStrategy.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.write; + +import java.io.IOException; + +import org.apache.celeborn.common.CelebornConf; + +/** A Simple strategy that control the push speed by a solid configure, pushMaxReqsInFlight. 
*/ +public class SimplePushStrategy extends PushStrategy { + + private final int maxInFlight; + + public SimplePushStrategy(CelebornConf conf) { + super(conf); + this.maxInFlight = conf.pushMaxReqsInFlight(); + } + + @Override + public void onSuccess(String hostAndPushPort) { + // No op + } + + @Override + public void onCongestControl(String hostAndPushPort) { + // No op + } + + @Override + public void clear() { + // No op + } + + @Override + public void limitPushSpeed(PushState pushState, String hostAndPushPort) throws IOException { + // No op + } + + @Override + public int getCurrentMaxReqsInFlight(String hostAndPushPort) { + return maxInFlight; + } +} diff --git a/common/src/main/java/org/apache/celeborn/common/write/SlowStartPushStrategy.java b/common/src/main/java/org/apache/celeborn/common/write/SlowStartPushStrategy.java new file mode 100644 index 000000000..5d29b7643 --- /dev/null +++ b/common/src/main/java/org/apache/celeborn/common/write/SlowStartPushStrategy.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.celeborn.common.write; + +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.celeborn.common.CelebornConf; + +public class SlowStartPushStrategy extends PushStrategy { + + protected static class CongestControlContext { + private final AtomicInteger currentMaxReqsInFlight; + + // Indicate the number of congested times even after the in flight requests reduced to 1 + private final AtomicInteger continueCongestedNumber; + private int congestionAvoidanceFlag; + private int reqsInFlightBlockThreshold; + + public CongestControlContext(int reqsInFlightBlockThreshold) { + this.currentMaxReqsInFlight = new AtomicInteger(1); + this.continueCongestedNumber = new AtomicInteger(0); + this.congestionAvoidanceFlag = 0; + this.reqsInFlightBlockThreshold = reqsInFlightBlockThreshold; + } + + public synchronized void increaseCurrentMaxReqs() { + continueCongestedNumber.set(0); + if (currentMaxReqsInFlight.get() >= reqsInFlightBlockThreshold) { + // Congestion avoidance + congestionAvoidanceFlag++; + if (congestionAvoidanceFlag >= currentMaxReqsInFlight.get()) { + currentMaxReqsInFlight.incrementAndGet(); + congestionAvoidanceFlag = 0; + } + } else { + // Slow start + currentMaxReqsInFlight.incrementAndGet(); + } + } + + public synchronized void decreaseCurrentMaxReqs() { + if (currentMaxReqsInFlight.get() <= 1) { + currentMaxReqsInFlight.set(1); + continueCongestedNumber.incrementAndGet(); + } else { + currentMaxReqsInFlight.updateAndGet(pre -> pre / 2); + } + reqsInFlightBlockThreshold = currentMaxReqsInFlight.get(); + congestionAvoidanceFlag = 0; + } + + public int getCurrentMaxReqsInFlight() { + return currentMaxReqsInFlight.get(); + } + + public int getContinueCongestedNumber() { + return continueCongestedNumber.get(); + } + } + + private 
static final Logger logger = LoggerFactory.getLogger(SlowStartPushStrategy.class); + + private final int maxInFlight; + private final long initialSleepMills; + private final long maxSleepMills; + private final ConcurrentHashMap congestControlInfoPerAddress; + + public SlowStartPushStrategy(CelebornConf conf) { + super(conf); + this.maxInFlight = conf.pushMaxReqsInFlight(); + this.initialSleepMills = conf.pushSlowStartInitialSleepTime(); + this.maxSleepMills = conf.pushSlowStartMaxSleepMills(); + this.congestControlInfoPerAddress = new ConcurrentHashMap<>(); + } + + @VisibleForTesting + protected CongestControlContext getCongestControlContextByAddress(String hostAndPushPort) { + return congestControlInfoPerAddress.computeIfAbsent( + hostAndPushPort, host -> new CongestControlContext(maxInFlight)); + } + + /** + * When `celeborn.push.limit.strategy` is SLOWSTART, will increase `currentMaxReqsInFlight` + * gradually to meet the max push speed. + * + * <p>
1. slow start period: every RTT period, `currentMaxReqsInFlight` is doubled. + * + * <p>
2. congestion avoidance: every RTT period, `currentMaxReqsInFlight` plus 1. + * + * <p>
Note that here we define one RTT period: one batch(currentMaxReqsInFlight) of push data + * requests. + */ + @Override + public void onSuccess(String hostAndPushPort) { + CongestControlContext congestControlContext = + getCongestControlContextByAddress(hostAndPushPort); + congestControlContext.increaseCurrentMaxReqs(); + } + + @Override + public void onCongestControl(String hostAndPushPort) { + CongestControlContext congestControlContext = + getCongestControlContextByAddress(hostAndPushPort); + congestControlContext.decreaseCurrentMaxReqs(); + } + + protected long getSleepTime(CongestControlContext context) { + int currentMaxReqs = context.getCurrentMaxReqsInFlight(); + if (currentMaxReqs >= conf.pushMaxReqsInFlight()) { + return 0; + } + + long sleepInterval = initialSleepMills - 60L * currentMaxReqs; + + if (currentMaxReqs == 1) { + return Math.min(sleepInterval + context.getContinueCongestedNumber() * 1000L, maxSleepMills); + } + + return Math.max(sleepInterval, 0); + } + + @Override + public void limitPushSpeed(PushState pushState, String hostAndPushPort) throws IOException { + if (pushState.exception.get() != null) { + throw pushState.exception.get(); + } + CongestControlContext congestControlContext = + getCongestControlContextByAddress(hostAndPushPort); + long sleepInterval = getSleepTime(congestControlContext); + if (sleepInterval > 0L) { + try { + logger.debug("Will sleep {} ms to control the push speed.", sleepInterval); + Thread.sleep(sleepInterval); + } catch (InterruptedException e) { + pushState.exception.set(new IOException(e)); + } + } + } + + @Override + public int getCurrentMaxReqsInFlight(String hostAndPushPort) { + return getCongestControlContextByAddress(hostAndPushPort).getCurrentMaxReqsInFlight(); + } + + @Override + public void clear() { + congestControlInfoPerAddress.clear(); + } +} diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 
71bb0cda4..933331691 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -667,6 +667,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def pushRetryThreads: Int = get(PUSH_RETRY_THREADS) def pushStageEndTimeout: Long = get(PUSH_STAGE_END_TIMEOUT).getOrElse(get(RPC_ASK_TIMEOUT) * (requestCommitFilesMaxRetries + 1)) + def pushLimitStrategy: String = get(PUSH_LIMIT_STRATEGY) + def pushSlowStartInitialSleepTime: Long = get(PUSH_SLOW_START_INITIAL_SLEEP_TIME) + def pushSlowStartMaxSleepMills: Long = get(PUSH_SLOW_START_MAX_SLEEP_TIME) def pushLimitInFlightTimeoutMs: Long = if (pushReplicateEnabled) { get(PUSH_LIMIT_IN_FLIGHT_TIMEOUT).getOrElse(pushDataTimeoutMs * 4) @@ -704,8 +707,6 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se .getOrElse(rpcAskTimeout.duration * (requestCommitFilesMaxRetries + 2)), GET_REDUCER_FILE_GROUP_RPC_ASK_TIMEOUT.key) - def pushDataSlowStart: Boolean = get(PUSH_DATA_SLOW_START) - // ////////////////////////////////////////////////////// // Graceful Shutdown & Recover // // ////////////////////////////////////////////////////// @@ -2158,6 +2159,35 @@ object CelebornConf extends Logging { .version("0.2.0") .timeConf(TimeUnit.MILLISECONDS) .createOptional + val PUSH_LIMIT_STRATEGY: ConfigEntry[String] = + buildConf("celeborn.push.limit.strategy") + .categories("client") + .doc("The strategy used to control the push speed. " + + "Valid strategies are SIMPLE and SLOWSTART. 
the SLOWSTART strategy is usually cooperate with " + + "congest control mechanism in the worker side.") + .version("0.3.0") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(Set("SIMPLE", "SLOWSTART")) + .createWithDefaultString("SIMPLE") + + val PUSH_SLOW_START_INITIAL_SLEEP_TIME: ConfigEntry[Long] = + buildConf("celeborn.push.slowStart.initialSleepTime") + .categories("client") + .version("0.3.0") + .doc(s"The initial sleep time if the current max in flight requests is 0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("500ms") + + val PUSH_SLOW_START_MAX_SLEEP_TIME: ConfigEntry[Long] = + buildConf("celeborn.push.slowStart.maxSleepTime") + .categories("client") + .version("0.3.0") + .doc(s"If ${PUSH_LIMIT_STRATEGY.key} is set to SLOWSTART, push side will " + + "take a sleep strategy for each batch of requests, this controls " + + "the max sleep time if the max in flight requests limit is 1 for a long time") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("2s") val PUSH_DATA_TIMEOUT: ConfigEntry[Long] = buildConf("celeborn.push.data.timeout") @@ -2340,15 +2370,6 @@ object CelebornConf extends Logging { .timeConf(TimeUnit.MILLISECONDS) .createOptional - val PUSH_DATA_SLOW_START: ConfigEntry[Boolean] = - buildConf("celeborn.push.data.slowStart") - .categories("client") - .version("0.3.0") - .doc("Whether to allow to slow increasing maxReqs to meet the max push capacity, " + - "worked when worker side enables rate limit mechanism") - .booleanConf - .createWithDefault(false) - val PORT_MAX_RETRY: ConfigEntry[Int] = buildConf("celeborn.port.maxRetries") .withAlternative("rss.master.port.maxretry") diff --git a/common/src/test/java/org/apache/celeborn/common/write/SlowStartPushStrategyTest.java b/common/src/test/java/org/apache/celeborn/common/write/SlowStartPushStrategyTest.java new file mode 100644 index 000000000..fdabd71b2 --- /dev/null +++ 
b/common/src/test/java/org/apache/celeborn/common/write/SlowStartPushStrategyTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.write; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.celeborn.common.CelebornConf; + +public class SlowStartPushStrategyTest { + + private final CelebornConf conf = new CelebornConf(); + + @Test + public void testSleepTime() { + conf.set("celeborn.push.maxReqsInFlight", "32"); + conf.set("celeborn.push.limit.strategy", "slowstart"); + conf.set("celeborn.push.slowStart.maxSleepTime", "3s"); + SlowStartPushStrategy strategy = (SlowStartPushStrategy) PushStrategy.getStrategy(conf); + String dummyHostPort = "test:9087"; + SlowStartPushStrategy.CongestControlContext context = + strategy.getCongestControlContextByAddress(dummyHostPort); + + // A fresh context starts with currentReq = 1; getSleepTime must not throw + strategy.getSleepTime(context); + + // If the currentReq is 1, should sleep 440 ms + Assert.assertEquals(440, strategy.getSleepTime(context)); + + // If the currentReq is 8, should sleep 20 ms + for (int i = 0; i < 7; i++) { + strategy.onSuccess(dummyHostPort); + } + Assert.assertEquals(20, strategy.getSleepTime(context)); + + 
// If the currentReq is 16, should sleep 0 ms + for (int i = 0; i < 8; i++) { + strategy.onSuccess(dummyHostPort); + } + Assert.assertEquals(0, strategy.getSleepTime(context)); + + // Congest the requests, the currentReq reduced to 8, should sleep 20 ms + strategy.onCongestControl(dummyHostPort); + Assert.assertEquals(20, strategy.getSleepTime(context)); + + // Congest the requests, the currentReq reduced to 4, should sleep 260 ms + strategy.onCongestControl(dummyHostPort); + Assert.assertEquals(260, strategy.getSleepTime(context)); + // Congest the requests, the currentReq reduced to 2, should sleep 380 ms + strategy.onCongestControl(dummyHostPort); + Assert.assertEquals(380, strategy.getSleepTime(context)); + // Congest the requests, the currentReq reduced to 1, should sleep 440 ms + strategy.onCongestControl(dummyHostPort); + Assert.assertEquals(440, strategy.getSleepTime(context)); + // Keep congesting after the currentReq is reduced to 1: sleep time increases by 1s each time + strategy.onCongestControl(dummyHostPort); + Assert.assertEquals(1440, strategy.getSleepTime(context)); + strategy.onCongestControl(dummyHostPort); + Assert.assertEquals(2440, strategy.getSleepTime(context)); + + // Cannot exceed the max sleep time + strategy.onCongestControl(dummyHostPort); + Assert.assertEquals(3000, strategy.getSleepTime(context)); + + // If start to return success, the currentReq is increased to 2, should sleep 380 ms + strategy.onSuccess(dummyHostPort); + Assert.assertEquals(380, strategy.getSleepTime(context)); + } + + @Test + public void testCongestStrategy() { + conf.set("celeborn.push.maxReqsInFlight", "5"); + conf.set("celeborn.push.limit.strategy", "slowstart"); + conf.set("celeborn.push.slowStart.maxSleepTime", "4s"); + SlowStartPushStrategy strategy = (SlowStartPushStrategy) PushStrategy.getStrategy(conf); + String dummyHostPort = "test:9087"; + // Slow start, should exponentially increase the currentReq + strategy.onSuccess(dummyHostPort); + 
Assert.assertEquals(2, strategy.getCurrentMaxReqsInFlight(dummyHostPort)); + strategy.onSuccess(dummyHostPort); + strategy.onSuccess(dummyHostPort); + Assert.assertEquals(4, strategy.getCurrentMaxReqsInFlight(dummyHostPort)); + strategy.onSuccess(dummyHostPort); + strategy.onSuccess(dummyHostPort); + strategy.onSuccess(dummyHostPort); + strategy.onSuccess(dummyHostPort); + + // Will linearly increase the currentReq if meet the maxReqsInFlight + Assert.assertEquals(5, strategy.getCurrentMaxReqsInFlight(dummyHostPort)); + strategy.onSuccess(dummyHostPort); + strategy.onSuccess(dummyHostPort); + strategy.onSuccess(dummyHostPort); + strategy.onSuccess(dummyHostPort); + strategy.onSuccess(dummyHostPort); + Assert.assertEquals(6, strategy.getCurrentMaxReqsInFlight(dummyHostPort)); + + // Congest controlled, should half the currentReq + strategy.onCongestControl(dummyHostPort); + Assert.assertEquals(3, strategy.getCurrentMaxReqsInFlight(dummyHostPort)); + strategy.onCongestControl(dummyHostPort); + Assert.assertEquals(1, strategy.getCurrentMaxReqsInFlight(dummyHostPort)); + + // Cannot lower than 1 + strategy.onCongestControl(dummyHostPort); + Assert.assertEquals(1, strategy.getCurrentMaxReqsInFlight(dummyHostPort)); + } + + @Test + public void testMultiHosts() { + conf.set("celeborn.push.maxReqsInFlight", "3"); + conf.set("celeborn.push.limit.strategy", "slowstart"); + conf.set("celeborn.push.slowStart.maxSleepTime", "3s"); + SlowStartPushStrategy strategy = (SlowStartPushStrategy) PushStrategy.getStrategy(conf); + String dummyHostPort1 = "test1:9087"; + String dummyHostPort2 = "test2:9087"; + SlowStartPushStrategy.CongestControlContext context1 = + strategy.getCongestControlContextByAddress(dummyHostPort1); + SlowStartPushStrategy.CongestControlContext context2 = + strategy.getCongestControlContextByAddress(dummyHostPort2); + + Assert.assertEquals(440, strategy.getSleepTime(context1)); + Assert.assertEquals(440, strategy.getSleepTime(context2)); + + // Control the 
dummyHostPort1, should not affect dummyHostPort2 + for (int i = 0; i < 3; i++) { + strategy.onSuccess(dummyHostPort1); + } + Assert.assertEquals(0, strategy.getSleepTime(context1)); + Assert.assertEquals(440, strategy.getSleepTime(context2)); + } +} diff --git a/docs/configuration/client.md b/docs/configuration/client.md index d26896b98..600c19020 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -30,15 +30,17 @@ license: | | celeborn.master.endpoints | <localhost>:9097 | Endpoints of master nodes for celeborn client to connect, allowed pattern is: `:[,:]*`, e.g. `clb1:9097,clb2:9098,clb3:9099`. If the port is omitted, 9097 will be used. | 0.2.0 | | celeborn.push.buffer.initial.size | 8k | | 0.2.0 | | celeborn.push.buffer.max.size | 64k | Max size of reducer partition buffer memory for shuffle hash writer. The pushed data will be buffered in memory before sending to Celeborn worker. For performance consideration keep this buffer size higher than 32K. Example: If reducer amount is 2000, buffer size is 64K, then each task will consume up to `64KiB * 2000 = 125MiB` heap memory. | 0.2.0 | -| celeborn.push.data.slowStart | false | Whether to allow to slow increasing maxReqs to meet the max push capacity, worked when worker side enables rate limit mechanism | 0.3.0 | | celeborn.push.data.timeout | 120s | Timeout for a task to push data rpc message. This value should better be more than twice of `celeborn.push.timeoutCheck.interval` | 0.2.0 | | celeborn.push.limit.inFlight.sleepInterval | 50ms | Sleep interval when check netty in-flight requests to be done. | 0.2.0 | | celeborn.push.limit.inFlight.timeout | <undefined> | Timeout for netty in-flight requests to be done.Default value should be `celeborn.push.data.timeout * 2`. | 0.2.0 | +| celeborn.push.limit.strategy | SIMPLE | The strategy used to control the push speed. Valid strategies are SIMPLE and SLOWSTART. 
the SLOWSTART strategy is usually cooperate with congest control mechanism in the worker side. | 0.3.0 | | celeborn.push.maxReqsInFlight | 4 | Amount of Netty in-flight requests per worker. The maximum memory is `celeborn.push.maxReqsInFlight` * `celeborn.push.buffer.max.size` * compression ratio(1 in worst case), default: 64Kib * 32 = 2Mib | 0.2.0 | | celeborn.push.queue.capacity | 512 | Push buffer queue size for a task. The maximum memory is `celeborn.push.buffer.max.size` * `celeborn.push.queue.capacity`, default: 64KiB * 512 = 32MiB | 0.2.0 | | celeborn.push.replicate.enabled | true | When true, Celeborn worker will replicate shuffle data to another Celeborn worker asynchronously to ensure the pushed shuffle data won't be lost after the node failure. | 0.2.0 | | celeborn.push.retry.threads | 8 | Thread number to process shuffle re-send push data requests. | 0.2.0 | | celeborn.push.revive.maxRetries | 5 | Max retry times for reviving when celeborn push data failed. | 0.3.0 | +| celeborn.push.slowStart.initialSleepTime | 500ms | The initial sleep time if the current max in flight requests is 0 | 0.3.0 | +| celeborn.push.slowStart.maxSleepTime | 2s | If celeborn.push.limit.strategy is set to SLOWSTART, push side will take a sleep strategy for each batch of requests, this controls the max sleep time if the max in flight requests limit is 1 for a long time | 0.3.0 | | celeborn.push.sortMemory.threshold | 64m | When SortBasedPusher use memory over the threshold, will trigger push data. | 0.2.0 | | celeborn.push.splitPartition.threads | 8 | Thread number to process shuffle split request in shuffle client. | 0.2.0 | | celeborn.push.stageEnd.timeout | <undefined> | Timeout for waiting StageEnd. Default value should be `celeborn.rpc.askTimeout * (celeborn.rpc.requestCommitFiles.maxRetries + 1)`. | 0.2.0 |