From 6a5e3ed794d034751bd3cd209a7447a7e94de041 Mon Sep 17 00:00:00 2001 From: "zky.zhoukeyong" Date: Thu, 20 Jul 2023 00:34:55 +0800 Subject: [PATCH] [CELEBORN-812] Cleanup SendBufferPool if idle for long ### What changes were proposed in this pull request? Cleans up the pooled send buffers and push tasks if the SendBufferPool has been idle for more than `celeborn.client.push.sendBufferPool.expireTimeout`. ### Why are the changes needed? Before this PR the SendBufferPool will cache the send buffers and push tasks forever. If they are large and will not be reused in the future, it wastes memory and causes GC. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Passes GA and manual tests. Closes #1735 from waitinfuture/812-1. Authored-by: zky.zhoukeyong Signed-off-by: zky.zhoukeyong --- .../shuffle/celeborn/SendBufferPool.java | 29 +++++++++++++++++-- .../celeborn/SortBasedPusherSuiteJ.java | 2 +- .../shuffle/celeborn/SparkShuffleManager.java | 14 +++++++-- .../HashBasedShuffleWriterSuiteJ.java | 2 +- .../SortBasedShuffleWriterSuiteJ.java | 8 ++++- .../shuffle/celeborn/SparkShuffleManager.java | 14 +++++++-- .../HashBasedShuffleWriterSuiteJ.java | 2 +- .../SortBasedShuffleWriterSuiteJ.java | 2 +- .../apache/celeborn/common/CelebornConf.scala | 21 ++++++++++++++ docs/configuration/client.md | 2 ++ 10 files changed, 84 insertions(+), 12 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java index 30989d75c..849694518 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SendBufferPool.java @@ -20,17 +20,20 @@ package org.apache.spark.shuffle.celeborn; import java.util.Iterator; import java.util.LinkedList; import java.util.concurrent.LinkedBlockingQueue; +import 
java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import org.apache.celeborn.client.write.PushTask; +import org.apache.celeborn.common.util.ThreadUtils; public class SendBufferPool { private static volatile SendBufferPool _instance; - public static SendBufferPool get(int capacity) { + public static SendBufferPool get(int capacity, long checkInterval, long timeout) { if (_instance == null) { synchronized (SendBufferPool.class) { if (_instance == null) { - _instance = new SendBufferPool(capacity); + _instance = new SendBufferPool(capacity, checkInterval, timeout); } } } @@ -41,16 +44,35 @@ public class SendBufferPool { // numPartitions -> buffers private final LinkedList buffers; + private long lastAquireTime; private final LinkedList> pushTaskQueues; - private SendBufferPool(int capacity) { + private ScheduledExecutorService cleaner = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("celeborn-sendBufferPool-cleaner"); + + private SendBufferPool(int capacity, long checkInterval, long timeout) { assert capacity > 0; this.capacity = capacity; buffers = new LinkedList<>(); pushTaskQueues = new LinkedList<>(); + + lastAquireTime = System.currentTimeMillis(); + cleaner.scheduleAtFixedRate( + () -> { + if (System.currentTimeMillis() - lastAquireTime > timeout) { + synchronized (this) { + buffers.clear(); + pushTaskQueues.clear(); + } + } + }, + checkInterval, + checkInterval, + TimeUnit.MILLISECONDS); } public synchronized byte[][] acquireBuffer(int numPartitions) { + lastAquireTime = System.currentTimeMillis(); Iterator iterator = buffers.iterator(); while (iterator.hasNext()) { byte[][] candidate = iterator.next(); @@ -66,6 +88,7 @@ public class SendBufferPool { } public synchronized LinkedBlockingQueue acquirePushTaskQueue() { + lastAquireTime = System.currentTimeMillis(); if (!pushTaskQueues.isEmpty()) { return pushTaskQueues.removeFirst(); } diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java index 58e8ad053..5bdd704d0 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java @@ -94,7 +94,7 @@ public class SortBasedPusherSuiteJ { /*pushSortMemoryThreshold=*/ Utils.byteStringAsBytes("1m"), /*sharedPushLock=*/ null, /*executorService=*/ null, - SendBufferPool.get(4)); + SendBufferPool.get(4, 30, 60)); // default page size == 2 MiB assertEquals(unifiedMemoryManager.pageSizeBytes(), Utils.byteStringAsBytes("2m")); diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index 3827e5891..477b3f505 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -63,6 +63,9 @@ public class SparkShuffleManager implements ShuffleManager { private final ExecutorService[] asyncPushers; private AtomicInteger pusherIdx = new AtomicInteger(0); + private long sendBufferPoolCheckInterval; + private long sendBufferPoolExpireTimeout; + public SparkShuffleManager(SparkConf conf, boolean isDriver) { this.conf = conf; this.isDriver = isDriver; @@ -78,6 +81,8 @@ public class SparkShuffleManager implements ShuffleManager { } else { asyncPushers = null; } + this.sendBufferPoolCheckInterval = celebornConf.clientPushSendBufferPoolExpireCheckInterval(); + this.sendBufferPoolExpireTimeout = celebornConf.clientPushSendBufferPoolExpireTimeout(); } private SortShuffleManager sortShuffleManager() { @@ -187,10 +192,15 @@ public class SparkShuffleManager 
implements ShuffleManager { celebornConf, client, pushThread, - SendBufferPool.get(cores)); + SendBufferPool.get(cores, sendBufferPoolCheckInterval, sendBufferPoolExpireTimeout)); } else if (ShuffleMode.HASH.equals(celebornConf.shuffleWriterMode())) { return new HashBasedShuffleWriter<>( - h, mapId, context, celebornConf, client, SendBufferPool.get(cores)); + h, + mapId, + context, + celebornConf, + client, + SendBufferPool.get(cores, sendBufferPoolCheckInterval, sendBufferPoolExpireTimeout)); } else { throw new UnsupportedOperationException( "Unrecognized shuffle write mode!" + celebornConf.shuffleWriterMode()); diff --git a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java index bd7c46eb8..475efee24 100644 --- a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java +++ b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java @@ -35,6 +35,6 @@ public class HashBasedShuffleWriterSuiteJ extends CelebornShuffleWriterSuiteBase throws IOException { // this test case is independent of the `mapId` value return new HashBasedShuffleWriter( - handle, /*mapId=*/ 0, context, conf, client, SendBufferPool.get(1)); + handle, /*mapId=*/ 0, context, conf, client, SendBufferPool.get(1, 30, 60)); } } diff --git a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java index a3638f4eb..efff3b97a 100644 --- a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java +++ b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java @@ -34,6 +34,12 @@ public class SortBasedShuffleWriterSuiteJ extends 
CelebornShuffleWriterSuiteBase CelebornShuffleHandle handle, TaskContext context, CelebornConf conf, ShuffleClient client) throws IOException { return new SortBasedShuffleWriter( - handle.dependency(), numPartitions, context, conf, client, null, SendBufferPool.get(4)); + handle.dependency(), + numPartitions, + context, + conf, + client, + null, + SendBufferPool.get(4, 30, 60)); } } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java index b876d48d6..63cab05e9 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkShuffleManager.java @@ -64,6 +64,9 @@ public class SparkShuffleManager implements ShuffleManager { private final ExecutorService[] asyncPushers; private AtomicInteger pusherIdx = new AtomicInteger(0); + private long sendBufferPoolCheckInterval; + private long sendBufferPoolExpireTimeout; + public SparkShuffleManager(SparkConf conf, boolean isDriver) { if (conf.getBoolean(LOCAL_SHUFFLE_READER_KEY, true)) { logger.warn( @@ -85,6 +88,8 @@ public class SparkShuffleManager implements ShuffleManager { } else { asyncPushers = null; } + this.sendBufferPoolCheckInterval = celebornConf.clientPushSendBufferPoolExpireCheckInterval(); + this.sendBufferPoolExpireTimeout = celebornConf.clientPushSendBufferPoolExpireTimeout(); } private SortShuffleManager sortShuffleManager() { @@ -198,10 +203,15 @@ public class SparkShuffleManager implements ShuffleManager { shuffleClient, metrics, pushThread, - SendBufferPool.get(cores)); + SendBufferPool.get(cores, sendBufferPoolCheckInterval, sendBufferPoolExpireTimeout)); } else if (ShuffleMode.HASH.equals(celebornConf.shuffleWriterMode())) { return new HashBasedShuffleWriter<>( - h, context, celebornConf, shuffleClient, metrics, 
SendBufferPool.get(cores)); + h, + context, + celebornConf, + shuffleClient, + metrics, + SendBufferPool.get(cores, sendBufferPoolCheckInterval, sendBufferPoolExpireTimeout)); } else { throw new UnsupportedOperationException( "Unrecognized shuffle write mode!" + celebornConf.shuffleWriterMode()); diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java index ed022b4a4..d8c514dc7 100644 --- a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java +++ b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java @@ -37,6 +37,6 @@ public class HashBasedShuffleWriterSuiteJ extends CelebornShuffleWriterSuiteBase ShuffleWriteMetricsReporter metrics) throws IOException { return new HashBasedShuffleWriter( - handle, context, conf, client, metrics, SendBufferPool.get(1)); + handle, context, conf, client, metrics, SendBufferPool.get(1, 30, 60)); } } diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java index db94d8441..2b1ea681b 100644 --- a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java +++ b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java @@ -36,6 +36,6 @@ public class SortBasedShuffleWriterSuiteJ extends CelebornShuffleWriterSuiteBase ShuffleWriteMetricsReporter metrics) throws IOException { return new SortBasedShuffleWriter( - handle, context, conf, client, metrics, null, SendBufferPool.get(4)); + handle, context, conf, client, metrics, null, SendBufferPool.get(4, 30, 60)); } } diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 693250515..a3aaf2281 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -783,6 +783,9 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def clientPushSplitPartitionThreads: Int = get(CLIENT_PUSH_SPLIT_PARTITION_THREADS) def clientPushTakeTaskWaitIntervalMs: Long = get(CLIENT_PUSH_TAKE_TASK_WAIT_INTERVAL) def clientPushTakeTaskMaxWaitAttempts: Int = get(CLIENT_PUSH_TAKE_TASK_MAX_WAIT_ATTEMPTS) + def clientPushSendBufferPoolExpireTimeout: Long = get(CLIENT_PUSH_SENDBUFFERPOOL_EXPIRETIMEOUT) + def clientPushSendBufferPoolExpireCheckInterval: Long = + get(CLIENT_PUSH_SENDBUFFERPOOL_CHECKEXPIREINTERVAL) // ////////////////////////////////////////////////////// // Client Shuffle // @@ -2883,6 +2886,24 @@ object CelebornConf extends Logging { .intConf .createWithDefault(1) + val CLIENT_PUSH_SENDBUFFERPOOL_EXPIRETIMEOUT: ConfigEntry[Long] = + buildConf("celeborn.client.push.sendBufferPool.expireTimeout") + .categories("client") + .doc("Timeout before clean up SendBufferPool. If SendBufferPool is idle for more than this time, " + + "the send buffers and push tasks will be cleaned up.") + .version("0.3.1") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("60s") + + val CLIENT_PUSH_SENDBUFFERPOOL_CHECKEXPIREINTERVAL: ConfigEntry[Long] = + buildConf("celeborn.client.push.sendBufferPool.checkExpireInterval") + .categories("client") + .doc("Interval to check expire for send buffer pool. 
If the pool has been idle " + + s"for more than `${CLIENT_PUSH_SENDBUFFERPOOL_EXPIRETIMEOUT.key}`, the pooled send buffers and push tasks will be cleaned up.") + .version("0.3.1") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("30s") + val TEST_CLIENT_RETRY_REVIVE: ConfigEntry[Boolean] = buildConf("celeborn.test.client.retryRevive") .withAlternative("celeborn.test.retryRevive") diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 91a146f01..9f7a96656 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -51,6 +51,8 @@ license: | | celeborn.client.push.revive.batchSize | 2048 | Max number of partitions in one Revive request. | 0.3.0 | | celeborn.client.push.revive.interval | 100ms | Interval for client to trigger Revive to LifecycleManager. The number of partitions in one Revive request is `celeborn.client.push.revive.batchSize`. | 0.3.0 | | celeborn.client.push.revive.maxRetries | 5 | Max retry times for reviving when celeborn push data failed. | 0.3.0 | +| celeborn.client.push.sendBufferPool.checkExpireInterval | 30s | Interval to check expire for send buffer pool. If the pool has been idle for more than `celeborn.client.push.sendBufferPool.expireTimeout`, the pooled send buffers and push tasks will be cleaned up. | 0.3.1 | +| celeborn.client.push.sendBufferPool.expireTimeout | 60s | Timeout before clean up SendBufferPool. If SendBufferPool is idle for more than this time, the send buffers and push tasks will be cleaned up. 
| 0.3.1 | | celeborn.client.push.slowStart.initialSleepTime | 500ms | The initial sleep time if the current max in flight requests is 0 | 0.3.0 | | celeborn.client.push.slowStart.maxSleepTime | 2s | If celeborn.client.push.limit.strategy is set to SLOWSTART, push side will take a sleep strategy for each batch of requests, this controls the max sleep time if the max in flight requests limit is 1 for a long time | 0.3.0 | | celeborn.client.push.sort.randomizePartitionId.enabled | false | Whether to randomize partitionId in push sorter. If true, partitionId will be randomized when sort data to avoid skew when push to worker | 0.3.0 |