diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/CongestionController.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/CongestionController.java index 2fbe8d637..3551c5a29 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/CongestionController.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/congestcontrol/CongestionController.java @@ -24,6 +24,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -291,13 +292,18 @@ public class CongestionController { public void close() { logger.info("Closing {}", this.getClass().getSimpleName()); - this.removeUserExecutorService.shutdownNow(); - this.checkService.shutdownNow(); + ThreadUtils.shutdown(this.removeUserExecutorService); + ThreadUtils.shutdown(this.checkService); this.userBufferStatuses.clear(); this.consumedBufferStatusHub.clear(); this.producedBufferStatusHub.clear(); } + @VisibleForTesting + public void shutDownCheckService() { + ThreadUtils.shutdown(this.checkService); + } + public static synchronized void destroy() { if (_INSTANCE != null) { _INSTANCE.close(); diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestCongestionController.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestCongestionController.java index 6cdb1ce9f..9fa3bfe2c 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestCongestionController.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/congestcontrol/TestCongestionController.java @@ -36,7 +36,6 @@ public class TestCongestionController { private long pendingBytes = 0L; private final long userInactiveTimeMills = 2000L; - private final long checkIntervalTimeMills = Integer.MAX_VALUE; @Before public void initialize() { @@ -56,8 +55,6 @@ public class TestCongestionController { CelebornConf.WORKER_CONGESTION_CONTROL_WORKER_PRODUCE_SPEED_LOW_WATERMARK().key(), "10000"); celebornConf.set( CelebornConf.WORKER_CONGESTION_CONTROL_USER_INACTIVE_INTERVAL(), userInactiveTimeMills); - celebornConf.set( - CelebornConf.WORKER_CONGESTION_CONTROL_CHECK_INTERVAL(), checkIntervalTimeMills); // Make sampleTimeWindow a bit larger in case the tests run time exceed this window. controller = new CongestionController(source, 10, celebornConf, null) { @@ -71,6 +68,7 @@ public class TestCongestionController { // No op } }; + controller.shutDownCheckService(); } @After @@ -176,8 +174,6 @@ public class TestCongestionController { celebornConf.set( CelebornConf.WORKER_CONGESTION_CONTROL_WORKER_PRODUCE_SPEED_LOW_WATERMARK().key(), "1000"); celebornConf.set(CelebornConf.WORKER_CONGESTION_CONTROL_USER_INACTIVE_INTERVAL(), 120L * 1000); - celebornConf.set( - CelebornConf.WORKER_CONGESTION_CONTROL_CHECK_INTERVAL(), checkIntervalTimeMills); CongestionController controller1 = new CongestionController(source, 10, celebornConf, null) { @Override @@ -190,6 +186,7 @@ public class TestCongestionController { // No op } }; + controller1.shutDownCheckService(); UserIdentifier user1 = new UserIdentifier("test1", "celeborn"); UserCongestionControlContext context1 = controller1.getUserCongestionContext(user1); @@ -252,8 +249,6 @@ public class TestCongestionController { celebornConf.set( CelebornConf.WORKER_CONGESTION_CONTROL_WORKER_PRODUCE_SPEED_LOW_WATERMARK().key(), "700"); celebornConf.set(CelebornConf.WORKER_CONGESTION_CONTROL_USER_INACTIVE_INTERVAL(), 120L * 1000); - celebornConf.set( - CelebornConf.WORKER_CONGESTION_CONTROL_CHECK_INTERVAL(), checkIntervalTimeMills); CongestionController controller1 = new CongestionController(source, 10, celebornConf, null) { @Override @@ -266,6 +261,7 @@ public class TestCongestionController { // No op } }; + controller1.shutDownCheckService(); UserIdentifier user1 = new UserIdentifier("test1", "celeborn"); UserCongestionControlContext context1 = controller1.getUserCongestionContext(user1);