diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java index 30989d75c..849694518 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java @@ -20,17 +20,20 @@ package org.apache.spark.shuffle.celeborn; import java.util.Iterator; import java.util.LinkedList; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import org.apache.celeborn.client.write.PushTask; +import org.apache.celeborn.common.util.ThreadUtils; public class SendBufferPool { private static volatile SendBufferPool _instance; - public static SendBufferPool get(int capacity) { + public static SendBufferPool get(int capacity, long checkInterval, long timeout) { if (_instance == null) { synchronized (SendBufferPool.class) { if (_instance == null) { - _instance = new SendBufferPool(capacity); + _instance = new SendBufferPool(capacity, checkInterval, timeout); } } } @@ -41,16 +44,35 @@ public class SendBufferPool { // numPartitions -> buffers private final LinkedList<byte[][]> buffers; + private volatile long lastAcquireTime; private final LinkedList<LinkedBlockingQueue<PushTask>> pushTaskQueues; - private SendBufferPool(int capacity) { + private ScheduledExecutorService cleaner = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("celeborn-sendBufferPool-cleaner"); + + private SendBufferPool(int capacity, long checkInterval, long timeout) { assert capacity > 0; this.capacity = capacity; buffers = new LinkedList<>(); pushTaskQueues = new LinkedList<>(); + + lastAcquireTime = System.currentTimeMillis(); + cleaner.scheduleAtFixedRate( + () -> { + if (System.currentTimeMillis() - lastAcquireTime > timeout) { + synchronized (this) { + buffers.clear(); + pushTaskQueues.clear(); + } + } + }, + 
checkInterval, + checkInterval, + TimeUnit.MILLISECONDS); } public synchronized byte[][] acquireBuffer(int numPartitions) { + lastAcquireTime = System.currentTimeMillis(); Iterator<byte[][]> iterator = buffers.iterator(); while (iterator.hasNext()) { byte[][] candidate = iterator.next(); @@ -66,6 +88,7 @@ } public synchronized LinkedBlockingQueue<PushTask> acquirePushTaskQueue() { + lastAcquireTime = System.currentTimeMillis(); if (!pushTaskQueues.isEmpty()) { return pushTaskQueues.removeFirst(); } diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java index 58e8ad053..5bdd704d0 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java @@ -94,7 +94,7 @@ public class SortBasedPusherSuiteJ { /*pushSortMemoryThreshold=*/ Utils.byteStringAsBytes("1m"), /*sharedPushLock=*/ null, /*executorService=*/ null, - SendBufferPool.get(4)); + SendBufferPool.get(4, 30, 60)); // default page size == 2 MiB assertEquals(unifiedMemoryManager.pageSizeBytes(), Utils.byteStringAsBytes("2m")); diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 3827e5891..477b3f505 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -63,6 +63,9 @@ public class SparkShuffleManager implements ShuffleManager { private final ExecutorService[] asyncPushers; private AtomicInteger pusherIdx = new AtomicInteger(0); + private long sendBufferPoolCheckInterval; + private long sendBufferPoolExpireTimeout; 
+ public SparkShuffleManager(SparkConf conf, boolean isDriver) { this.conf = conf; this.isDriver = isDriver; @@ -78,6 +81,8 @@ public class SparkShuffleManager implements ShuffleManager { } else { asyncPushers = null; } + this.sendBufferPoolCheckInterval = celebornConf.clientPushSendBufferPoolExpireCheckInterval(); + this.sendBufferPoolExpireTimeout = celebornConf.clientPushSendBufferPoolExpireTimeout(); } private SortShuffleManager sortShuffleManager() { @@ -187,10 +192,15 @@ public class SparkShuffleManager implements ShuffleManager { celebornConf, client, pushThread, - SendBufferPool.get(cores)); + SendBufferPool.get(cores, sendBufferPoolCheckInterval, sendBufferPoolExpireTimeout)); } else if (ShuffleMode.HASH.equals(celebornConf.shuffleWriterMode())) { return new HashBasedShuffleWriter<>( - h, mapId, context, celebornConf, client, SendBufferPool.get(cores)); + h, + mapId, + context, + celebornConf, + client, + SendBufferPool.get(cores, sendBufferPoolCheckInterval, sendBufferPoolExpireTimeout)); } else { throw new UnsupportedOperationException( "Unrecognized shuffle write mode!" 
+ celebornConf.shuffleWriterMode()); diff --git a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java index bd7c46eb8..475efee24 100644 --- a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java +++ b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java @@ -35,6 +35,6 @@ public class HashBasedShuffleWriterSuiteJ extends CelebornShuffleWriterSuiteBase throws IOException { // this test case is independent of the `mapId` value return new HashBasedShuffleWriter( - handle, /*mapId=*/ 0, context, conf, client, SendBufferPool.get(1)); + handle, /*mapId=*/ 0, context, conf, client, SendBufferPool.get(1, 30, 60)); } } diff --git a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java index a3638f4eb..efff3b97a 100644 --- a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java +++ b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java @@ -34,6 +34,12 @@ public class SortBasedShuffleWriterSuiteJ extends CelebornShuffleWriterSuiteBase CelebornShuffleHandle handle, TaskContext context, CelebornConf conf, ShuffleClient client) throws IOException { return new SortBasedShuffleWriter( - handle.dependency(), numPartitions, context, conf, client, null, SendBufferPool.get(4)); + handle.dependency(), + numPartitions, + context, + conf, + client, + null, + SendBufferPool.get(4, 30, 60)); } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index b876d48d6..63cab05e9 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -64,6 +64,9 @@ public class SparkShuffleManager implements ShuffleManager { private final ExecutorService[] asyncPushers; private AtomicInteger pusherIdx = new AtomicInteger(0); + private long sendBufferPoolCheckInterval; + private long sendBufferPoolExpireTimeout; + public SparkShuffleManager(SparkConf conf, boolean isDriver) { if (conf.getBoolean(LOCAL_SHUFFLE_READER_KEY, true)) { logger.warn( @@ -85,6 +88,8 @@ public class SparkShuffleManager implements ShuffleManager { } else { asyncPushers = null; } + this.sendBufferPoolCheckInterval = celebornConf.clientPushSendBufferPoolExpireCheckInterval(); + this.sendBufferPoolExpireTimeout = celebornConf.clientPushSendBufferPoolExpireTimeout(); } private SortShuffleManager sortShuffleManager() { @@ -198,10 +203,15 @@ public class SparkShuffleManager implements ShuffleManager { shuffleClient, metrics, pushThread, - SendBufferPool.get(cores)); + SendBufferPool.get(cores, sendBufferPoolCheckInterval, sendBufferPoolExpireTimeout)); } else if (ShuffleMode.HASH.equals(celebornConf.shuffleWriterMode())) { return new HashBasedShuffleWriter<>( - h, context, celebornConf, shuffleClient, metrics, SendBufferPool.get(cores)); + h, + context, + celebornConf, + shuffleClient, + metrics, + SendBufferPool.get(cores, sendBufferPoolCheckInterval, sendBufferPoolExpireTimeout)); } else { throw new UnsupportedOperationException( "Unrecognized shuffle write mode!" 
+ celebornConf.shuffleWriterMode()); diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java index ed022b4a4..d8c514dc7 100644 --- a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java +++ b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java @@ -37,6 +37,6 @@ public class HashBasedShuffleWriterSuiteJ extends CelebornShuffleWriterSuiteBase ShuffleWriteMetricsReporter metrics) throws IOException { return new HashBasedShuffleWriter( - handle, context, conf, client, metrics, SendBufferPool.get(1)); + handle, context, conf, client, metrics, SendBufferPool.get(1, 30, 60)); } } diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java index db94d8441..2b1ea681b 100644 --- a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java +++ b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java @@ -36,6 +36,6 @@ public class SortBasedShuffleWriterSuiteJ extends CelebornShuffleWriterSuiteBase ShuffleWriteMetricsReporter metrics) throws IOException { return new SortBasedShuffleWriter( - handle, context, conf, client, metrics, null, SendBufferPool.get(4)); + handle, context, conf, client, metrics, null, SendBufferPool.get(4, 30, 60)); } } diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 693250515..a3aaf2281 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
@@ -783,6 +783,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientPushSplitPartitionThreads: Int = get(CLIENT_PUSH_SPLIT_PARTITION_THREADS) def clientPushTakeTaskWaitIntervalMs: Long = get(CLIENT_PUSH_TAKE_TASK_WAIT_INTERVAL) def clientPushTakeTaskMaxWaitAttempts: Int = get(CLIENT_PUSH_TAKE_TASK_MAX_WAIT_ATTEMPTS) + def clientPushSendBufferPoolExpireTimeout: Long = get(CLIENT_PUSH_SENDBUFFERPOOL_EXPIRETIMEOUT) + def clientPushSendBufferPoolExpireCheckInterval: Long = + get(CLIENT_PUSH_SENDBUFFERPOOL_CHECKEXPIREINTERVAL) // ////////////////////////////////////////////////////// // Client Shuffle // @@ -2883,6 +2886,24 @@ object CelebornConf extends Logging { .intConf .createWithDefault(1) + val CLIENT_PUSH_SENDBUFFERPOOL_EXPIRETIMEOUT: ConfigEntry[Long] = + buildConf("celeborn.client.push.sendBufferPool.expireTimeout") + .categories("client") + .doc("Timeout before cleaning up SendBufferPool. If SendBufferPool is idle for more than this time, " + + "the send buffers and push tasks will be cleaned up.") + .version("0.3.1") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("60s") + + val CLIENT_PUSH_SENDBUFFERPOOL_CHECKEXPIREINTERVAL: ConfigEntry[Long] = + buildConf("celeborn.client.push.sendBufferPool.checkExpireInterval") + .categories("client") + .doc("Interval to check expiration for send buffer pool. 
If the pool has been idle " + s"for more than `${CLIENT_PUSH_SENDBUFFERPOOL_EXPIRETIMEOUT.key}`, the pooled send buffers and push tasks will be cleaned up.") + .version("0.3.1") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("30s") + val TEST_CLIENT_RETRY_REVIVE: ConfigEntry[Boolean] = buildConf("celeborn.test.client.retryRevive") .withAlternative("celeborn.test.retryRevive") diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 91a146f01..9f7a96656 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -51,6 +51,8 @@ license: | | celeborn.client.push.revive.batchSize | 2048 | Max number of partitions in one Revive request. | 0.3.0 | | celeborn.client.push.revive.interval | 100ms | Interval for client to trigger Revive to LifecycleManager. The number of partitions in one Revive request is `celeborn.client.push.revive.batchSize`. | 0.3.0 | | celeborn.client.push.revive.maxRetries | 5 | Max retry times for reviving when celeborn push data failed. | 0.3.0 | +| celeborn.client.push.sendBufferPool.checkExpireInterval | 30s | Interval to check expiration for send buffer pool. If the pool has been idle for more than `celeborn.client.push.sendBufferPool.expireTimeout`, the pooled send buffers and push tasks will be cleaned up. | 0.3.1 | +| celeborn.client.push.sendBufferPool.expireTimeout | 60s | Timeout before cleaning up SendBufferPool. If SendBufferPool is idle for more than this time, the send buffers and push tasks will be cleaned up. 
| 0.3.1 | | celeborn.client.push.slowStart.initialSleepTime | 500ms | The initial sleep time if the current max in flight requests is 0 | 0.3.0 | | celeborn.client.push.slowStart.maxSleepTime | 2s | If celeborn.client.push.limit.strategy is set to SLOWSTART, push side will take a sleep strategy for each batch of requests, this controls the max sleep time if the max in flight requests limit is 1 for a long time | 0.3.0 | | celeborn.client.push.sort.randomizePartitionId.enabled | false | Whether to randomize partitionId in push sorter. If true, partitionId will be randomized when sort data to avoid skew when push to worker | 0.3.0 |