diff --git a/common/src/main/java/org/apache/celeborn/common/network/server/ChunkStreamManager.java b/common/src/main/java/org/apache/celeborn/common/network/server/ChunkStreamManager.java index 83541c855..b57405955 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/server/ChunkStreamManager.java +++ b/common/src/main/java/org/apache/celeborn/common/network/server/ChunkStreamManager.java @@ -17,14 +17,13 @@ package org.apache.celeborn.common.network.server; -import java.util.Map; import java.util.Random; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import io.netty.channel.Channel; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; @@ -41,14 +40,15 @@ public class ChunkStreamManager { private static final Logger logger = LoggerFactory.getLogger(ChunkStreamManager.class); private final AtomicLong nextStreamId; + // StreamId -> StreamState protected final ConcurrentHashMap streams; + // ShuffleKey -> StreamId + protected final ConcurrentHashMap> shuffleStreamIds; /** State of a single stream. */ protected static class StreamState { final FileManagedBuffers buffers; - - // The channel associated to the stream - final Channel associatedChannel; + final String shuffleKey; // Used to keep track of the index of the buffer that the user has retrieved, just to ensure // that the caller only requests each chunk one at a time, in order. @@ -57,9 +57,9 @@ public class ChunkStreamManager { // Used to keep track of the number of chunks being transferred and not finished yet. volatile long chunksBeingTransferred = 0L; - StreamState(FileManagedBuffers buffers, Channel channel) { + StreamState(String shuffleKey, FileManagedBuffers buffers) { this.buffers = Preconditions.checkNotNull(buffers); - this.associatedChannel = channel; + this.shuffleKey = shuffleKey; } } @@ -68,6 +68,7 @@ public class ChunkStreamManager { // This does not need to be globally unique, only unique to this class. nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); streams = new ConcurrentHashMap<>(); + shuffleStreamIds = new ConcurrentHashMap<>(); } public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) { @@ -91,9 +92,13 @@ public class ChunkStreamManager { if (state.buffers.isFullyRead()) { // Normally, when all chunks are returned to the client, the stream should be removed here. // But if there is a switch on the client side, it will not go here at this time, so we need - // to remove the stream when the connection is terminated, and release the unused buffer. + // to remove the stream when the shuffle is expired, and release the unused buffer. logger.trace("Removing stream id {}", streamId); streams.remove(streamId); + Set streamIds = shuffleStreamIds.get(state.shuffleKey); + if (streamIds != null) { + streamIds.remove(streamId); + } } return nextChunk; @@ -113,16 +118,6 @@ public class ChunkStreamManager { return ImmutablePair.of(streamId, chunkIndex); } - public void connectionTerminated(Channel channel) { - // Close all streams which have been associated with the channel. - for (Map.Entry entry : streams.entrySet()) { - StreamState state = entry.getValue(); - if (state.associatedChannel == channel) { - streams.remove(entry.getKey()); - } - } - } - public void chunkBeingSent(long streamId) { StreamState streamState = streams.get(streamId); if (streamState != null) { @@ -154,18 +149,43 @@ public class ChunkStreamManager { *

If an app ID is provided, only callers who've authenticated with the given app ID will be * allowed to fetch from this stream. * - *

This method also associates the stream with a single client connection, which is guaranteed - * to be the only reader of the stream. Once the connection is closed, the stream will never be - * used again, enabling cleanup by `connectionTerminated`. + *

This stream could be reused again when other channel of the client is reconnected. If a + * stream is not properly closed, it will eventually be cleaned up by `cleanupExpiredShuffleKey`. */ - public long registerStream(FileManagedBuffers buffers, Channel channel) { + public long registerStream(String shuffleKey, FileManagedBuffers buffers) { long myStreamId = nextStreamId.getAndIncrement(); - streams.put(myStreamId, new StreamState(buffers, channel)); + streams.put(myStreamId, new StreamState(shuffleKey, buffers)); + shuffleStreamIds.compute( + shuffleKey, + (key, value) -> { + if (value == null) { + value = ConcurrentHashMap.newKeySet(); + } + value.add(myStreamId); + return value; + }); + return myStreamId; } + public void cleanupExpiredShuffleKey(Set expiredShuffleKeys) { + for (String expiredShuffleKey : expiredShuffleKeys) { + Set expiredStreamIds = shuffleStreamIds.remove(expiredShuffleKey); + + // normally expiredStreamIds set will be empty as streamId will be removed when be fully read + if (expiredStreamIds != null && !expiredStreamIds.isEmpty()) { + streams.keySet().removeAll(expiredStreamIds); + } + } + } + @VisibleForTesting public int numStreamStates() { return streams.size(); } + + @VisibleForTesting + public long numShuffleSteams() { + return shuffleStreamIds.values().stream().flatMap(Set::stream).count(); + } } diff --git a/common/src/test/java/org/apache/celeborn/common/network/server/ChunkStreamManagerSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/server/ChunkStreamManagerSuiteJ.java index 2fce016ee..ba32a06bf 100644 --- a/common/src/test/java/org/apache/celeborn/common/network/server/ChunkStreamManagerSuiteJ.java +++ b/common/src/test/java/org/apache/celeborn/common/network/server/ChunkStreamManagerSuiteJ.java @@ -17,7 +17,9 @@ package org.apache.celeborn.common.network.server; -import io.netty.channel.Channel; +import java.util.Arrays; +import java.util.HashSet; + import org.junit.Assert; import org.junit.Test; import org.mockito.Mockito; @@ -26,7 +28,7 @@ import org.apache.celeborn.common.meta.FileManagedBuffers; public class ChunkStreamManagerSuiteJ { @Test - public void streamStatesAreFreedWhenConnectionIsClosedEvenIfBufferIteratorThrowsException() { + public void testStreamRegisterAndCleanup() { ChunkStreamManager manager = new ChunkStreamManager(); @SuppressWarnings("unchecked") @@ -34,14 +36,30 @@ public class ChunkStreamManagerSuiteJ { @SuppressWarnings("unchecked") FileManagedBuffers buffers2 = Mockito.mock(FileManagedBuffers.class); + FileManagedBuffers buffers3 = Mockito.mock(FileManagedBuffers.class); + FileManagedBuffers buffers4 = Mockito.mock(FileManagedBuffers.class); - Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); - manager.registerStream(buffers, dummyChannel); - manager.registerStream(buffers2, dummyChannel); + manager.registerStream("shuffleKey1", buffers); + manager.registerStream("shuffleKey1", buffers2); + manager.registerStream("shuffleKey2", buffers3); + long stream3 = manager.registerStream("shuffleKey3", buffers4); + Assert.assertEquals(4, manager.numStreamStates()); + Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams()); - Assert.assertEquals(2, manager.numStreamStates()); + manager.cleanupExpiredShuffleKey(new HashSet<>(Arrays.asList("shuffleKey1", "shuffleKey2"))); + manager.cleanupExpiredShuffleKey(new HashSet<>(Arrays.asList("none_exit_shuffleKey"))); - manager.connectionTerminated(dummyChannel); - assert manager.streams.isEmpty(); + Assert.assertEquals(1, manager.numStreamStates()); + Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams()); + + // stream removed when buffer fully read + manager.streams.remove(stream3); + manager.shuffleStreamIds.get("shuffleKey3").remove(stream3); + Assert.assertEquals(0, manager.numStreamStates()); + Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams()); + + // cleanup shuffleKey3 + manager.cleanupExpiredShuffleKey(new HashSet<>(Arrays.asList("shuffleKey3"))); + Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams()); } } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index 26eeafbc0..6db905159 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -20,10 +20,10 @@ package org.apache.celeborn.service.deploy.worker import java.io.{FileNotFoundException, IOException} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets +import java.util import java.util.concurrent.atomic.AtomicBoolean import com.google.common.base.Throwables -import io.netty.buffer.{ByteBuf, Unpooled} import io.netty.util.concurrent.{Future, GenericFutureListener} import org.apache.celeborn.common.exception.CelebornException @@ -113,7 +113,7 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg new NioManagedBuffer(streamHandle.toByteBuffer))) } else { val buffers = new FileManagedBuffers(fileInfo, conf) - val streamId = chunkStreamManager.registerStream(buffers, client.getChannel) + val streamId = chunkStreamManager.registerStream(shuffleKey, buffers) val streamHandle = new StreamHandle(streamId, fileInfo.numChunks()) if (fileInfo.numChunks() == 0) logDebug(s"StreamId $streamId fileName $fileName startMapIndex" + @@ -205,7 +205,6 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg override def checkRegistered: Boolean = registered.get override def channelInactive(client: TransportClient): Unit = { - chunkStreamManager.connectionTerminated(client.getChannel) bufferStreamManager.connectionTerminated(client.getChannel) logDebug(s"channel inactive ${client.getSocketAddress}") } @@ -213,4 +212,8 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { logWarning(s"exception caught ${client.getSocketAddress}", cause) } + + def cleanupExpiredShuffleKey(expiredShuffleKeys: util.HashSet[String]): Unit = { + chunkStreamManager.cleanupExpiredShuffleKey(expiredShuffleKeys) + } } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala index 60afb40b8..e4258ddf6 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala @@ -436,6 +436,7 @@ private[celeborn] class Worker( } partitionsSorter.cleanup(expiredShuffleKeys) storageManager.cleanupExpiredShuffleKey(expiredShuffleKeys) + fetchHandler.cleanupExpiredShuffleKey(expiredShuffleKeys) } @VisibleForTesting