[CELEBORN-267] reuse stream when client channel reconnected (#1200)

This commit is contained in:
Shuang 2023-02-03 15:12:45 +08:00 committed by GitHub
parent 4b6f7e4593
commit 2634476758
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 34 deletions

View File

@ -17,14 +17,13 @@
package org.apache.celeborn.common.network.server;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
@ -41,14 +40,15 @@ public class ChunkStreamManager {
private static final Logger logger = LoggerFactory.getLogger(ChunkStreamManager.class);
private final AtomicLong nextStreamId;
// StreamId -> StreamState
protected final ConcurrentHashMap<Long, StreamState> streams;
// ShuffleKey -> StreamId
protected final ConcurrentHashMap<String, Set<Long>> shuffleStreamIds;
/** State of a single stream. */
protected static class StreamState {
final FileManagedBuffers buffers;
// The channel associated to the stream
final Channel associatedChannel;
final String shuffleKey;
// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
// that the caller only requests each chunk one at a time, in order.
@ -57,9 +57,9 @@ public class ChunkStreamManager {
// Used to keep track of the number of chunks being transferred and not finished yet.
volatile long chunksBeingTransferred = 0L;
StreamState(FileManagedBuffers buffers, Channel channel) {
StreamState(String shuffleKey, FileManagedBuffers buffers) {
this.buffers = Preconditions.checkNotNull(buffers);
this.associatedChannel = channel;
this.shuffleKey = shuffleKey;
}
}
@ -68,6 +68,7 @@ public class ChunkStreamManager {
// This does not need to be globally unique, only unique to this class.
nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
streams = new ConcurrentHashMap<>();
shuffleStreamIds = new ConcurrentHashMap<>();
}
public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) {
@ -91,9 +92,13 @@ public class ChunkStreamManager {
if (state.buffers.isFullyRead()) {
// Normally, when all chunks are returned to the client, the stream should be removed here.
// But if there is a switch on the client side, it will not go here at this time, so we need
// to remove the stream when the connection is terminated, and release the unused buffer.
// to remove the stream when the shuffle is expired, and release the unused buffer.
logger.trace("Removing stream id {}", streamId);
streams.remove(streamId);
Set<Long> streamIds = shuffleStreamIds.get(state.shuffleKey);
if (streamIds != null) {
streamIds.remove(streamId);
}
}
return nextChunk;
@ -113,16 +118,6 @@ public class ChunkStreamManager {
return ImmutablePair.of(streamId, chunkIndex);
}
public void connectionTerminated(Channel channel) {
// Close all streams which have been associated with the channel.
for (Map.Entry<Long, StreamState> entry : streams.entrySet()) {
StreamState state = entry.getValue();
if (state.associatedChannel == channel) {
streams.remove(entry.getKey());
}
}
}
public void chunkBeingSent(long streamId) {
StreamState streamState = streams.get(streamId);
if (streamState != null) {
@ -154,18 +149,43 @@ public class ChunkStreamManager {
* <p>If an app ID is provided, only callers who've authenticated with the given app ID will be
* allowed to fetch from this stream.
*
* <p>This method also associates the stream with a single client connection, which is guaranteed
* to be the only reader of the stream. Once the connection is closed, the stream will never be
* used again, enabling cleanup by `connectionTerminated`.
* <p>This stream could be reused again when other channel of the client is reconnected. If a
* stream is not properly closed, it will eventually be cleaned up by `cleanupExpiredShuffleKey`.
*/
public long registerStream(FileManagedBuffers buffers, Channel channel) {
public long registerStream(String shuffleKey, FileManagedBuffers buffers) {
long myStreamId = nextStreamId.getAndIncrement();
streams.put(myStreamId, new StreamState(buffers, channel));
streams.put(myStreamId, new StreamState(shuffleKey, buffers));
shuffleStreamIds.compute(
shuffleKey,
(key, value) -> {
if (value == null) {
value = ConcurrentHashMap.newKeySet();
}
value.add(myStreamId);
return value;
});
return myStreamId;
}
public void cleanupExpiredShuffleKey(Set<String> expiredShuffleKeys) {
for (String expiredShuffleKey : expiredShuffleKeys) {
Set<Long> expiredStreamIds = shuffleStreamIds.remove(expiredShuffleKey);
// normally expiredStreamIds set will be empty as streamId will be removed when be fully read
if (expiredStreamIds != null && !expiredStreamIds.isEmpty()) {
streams.keySet().removeAll(expiredStreamIds);
}
}
}
@VisibleForTesting
public int numStreamStates() {
return streams.size();
}
@VisibleForTesting
public long numShuffleSteams() {
return shuffleStreamIds.values().stream().flatMap(Set::stream).count();
}
}

View File

@ -17,7 +17,9 @@
package org.apache.celeborn.common.network.server;
import io.netty.channel.Channel;
import java.util.Arrays;
import java.util.HashSet;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;
@ -26,7 +28,7 @@ import org.apache.celeborn.common.meta.FileManagedBuffers;
public class ChunkStreamManagerSuiteJ {
@Test
public void streamStatesAreFreedWhenConnectionIsClosedEvenIfBufferIteratorThrowsException() {
public void testStreamRegisterAndCleanup() {
ChunkStreamManager manager = new ChunkStreamManager();
@SuppressWarnings("unchecked")
@ -34,14 +36,30 @@ public class ChunkStreamManagerSuiteJ {
@SuppressWarnings("unchecked")
FileManagedBuffers buffers2 = Mockito.mock(FileManagedBuffers.class);
FileManagedBuffers buffers3 = Mockito.mock(FileManagedBuffers.class);
FileManagedBuffers buffers4 = Mockito.mock(FileManagedBuffers.class);
Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
manager.registerStream(buffers, dummyChannel);
manager.registerStream(buffers2, dummyChannel);
manager.registerStream("shuffleKey1", buffers);
manager.registerStream("shuffleKey1", buffers2);
manager.registerStream("shuffleKey2", buffers3);
long stream3 = manager.registerStream("shuffleKey3", buffers4);
Assert.assertEquals(4, manager.numStreamStates());
Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams());
Assert.assertEquals(2, manager.numStreamStates());
manager.cleanupExpiredShuffleKey(new HashSet<>(Arrays.asList("shuffleKey1", "shuffleKey2")));
manager.cleanupExpiredShuffleKey(new HashSet<>(Arrays.asList("none_exit_shuffleKey")));
manager.connectionTerminated(dummyChannel);
assert manager.streams.isEmpty();
Assert.assertEquals(1, manager.numStreamStates());
Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams());
// stream removed when buffer fully read
manager.streams.remove(stream3);
manager.shuffleStreamIds.get("shuffleKey3").remove(stream3);
Assert.assertEquals(0, manager.numStreamStates());
Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams());
// cleanup shuffleKey3
manager.cleanupExpiredShuffleKey(new HashSet<>(Arrays.asList("shuffleKey3")));
Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams());
}
}

View File

@ -20,10 +20,10 @@ package org.apache.celeborn.service.deploy.worker
import java.io.{FileNotFoundException, IOException}
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util
import java.util.concurrent.atomic.AtomicBoolean
import com.google.common.base.Throwables
import io.netty.buffer.{ByteBuf, Unpooled}
import io.netty.util.concurrent.{Future, GenericFutureListener}
import org.apache.celeborn.common.exception.CelebornException
@ -113,7 +113,7 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
new NioManagedBuffer(streamHandle.toByteBuffer)))
} else {
val buffers = new FileManagedBuffers(fileInfo, conf)
val streamId = chunkStreamManager.registerStream(buffers, client.getChannel)
val streamId = chunkStreamManager.registerStream(shuffleKey, buffers)
val streamHandle = new StreamHandle(streamId, fileInfo.numChunks())
if (fileInfo.numChunks() == 0)
logDebug(s"StreamId $streamId fileName $fileName startMapIndex" +
@ -205,7 +205,6 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
override def checkRegistered: Boolean = registered.get
override def channelInactive(client: TransportClient): Unit = {
chunkStreamManager.connectionTerminated(client.getChannel)
bufferStreamManager.connectionTerminated(client.getChannel)
logDebug(s"channel inactive ${client.getSocketAddress}")
}
@ -213,4 +212,8 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
logWarning(s"exception caught ${client.getSocketAddress}", cause)
}
def cleanupExpiredShuffleKey(expiredShuffleKeys: util.HashSet[String]): Unit = {
chunkStreamManager.cleanupExpiredShuffleKey(expiredShuffleKeys)
}
}

View File

@ -436,6 +436,7 @@ private[celeborn] class Worker(
}
partitionsSorter.cleanup(expiredShuffleKeys)
storageManager.cleanupExpiredShuffleKey(expiredShuffleKeys)
fetchHandler.cleanupExpiredShuffleKey(expiredShuffleKeys)
}
@VisibleForTesting