[CELEBORN-267] reuse stream when client channel reconnected (#1200)
This commit is contained in:
parent
4b6f7e4593
commit
2634476758
@ -17,14 +17,13 @@
|
||||
|
||||
package org.apache.celeborn.common.network.server;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.base.Preconditions;
|
||||
import io.netty.channel.Channel;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
@ -41,14 +40,15 @@ public class ChunkStreamManager {
|
||||
private static final Logger logger = LoggerFactory.getLogger(ChunkStreamManager.class);
|
||||
|
||||
private final AtomicLong nextStreamId;
|
||||
// StreamId -> StreamState
|
||||
protected final ConcurrentHashMap<Long, StreamState> streams;
|
||||
// ShuffleKey -> set of StreamIds registered for that shuffle
|
||||
protected final ConcurrentHashMap<String, Set<Long>> shuffleStreamIds;
|
||||
|
||||
/** State of a single stream. */
|
||||
protected static class StreamState {
|
||||
final FileManagedBuffers buffers;
|
||||
|
||||
// The channel associated to the stream
|
||||
final Channel associatedChannel;
|
||||
final String shuffleKey;
|
||||
|
||||
// Used to keep track of the index of the buffer that the user has retrieved, just to ensure
|
||||
// that the caller only requests each chunk one at a time, in order.
|
||||
@ -57,9 +57,9 @@ public class ChunkStreamManager {
|
||||
// Used to keep track of the number of chunks being transferred and not finished yet.
|
||||
volatile long chunksBeingTransferred = 0L;
|
||||
|
||||
StreamState(FileManagedBuffers buffers, Channel channel) {
|
||||
StreamState(String shuffleKey, FileManagedBuffers buffers) {
|
||||
this.buffers = Preconditions.checkNotNull(buffers);
|
||||
this.associatedChannel = channel;
|
||||
this.shuffleKey = shuffleKey;
|
||||
}
|
||||
}
|
||||
|
||||
@ -68,6 +68,7 @@ public class ChunkStreamManager {
|
||||
// This does not need to be globally unique, only unique to this class.
|
||||
nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
|
||||
streams = new ConcurrentHashMap<>();
|
||||
shuffleStreamIds = new ConcurrentHashMap<>();
|
||||
}
|
||||
|
||||
public ManagedBuffer getChunk(long streamId, int chunkIndex, int offset, int len) {
|
||||
@ -91,9 +92,13 @@ public class ChunkStreamManager {
|
||||
if (state.buffers.isFullyRead()) {
|
||||
// Normally, when all chunks are returned to the client, the stream should be removed here.
|
||||
// But if there is a switch on the client side, it will not go here at this time, so we need
|
||||
// to remove the stream when the connection is terminated, and release the unused buffer.
|
||||
// to remove the stream when the shuffle is expired, and release the unused buffer.
|
||||
logger.trace("Removing stream id {}", streamId);
|
||||
streams.remove(streamId);
|
||||
Set<Long> streamIds = shuffleStreamIds.get(state.shuffleKey);
|
||||
if (streamIds != null) {
|
||||
streamIds.remove(streamId);
|
||||
}
|
||||
}
|
||||
|
||||
return nextChunk;
|
||||
@ -113,16 +118,6 @@ public class ChunkStreamManager {
|
||||
return ImmutablePair.of(streamId, chunkIndex);
|
||||
}
|
||||
|
||||
public void connectionTerminated(Channel channel) {
|
||||
// Close all streams which have been associated with the channel.
|
||||
for (Map.Entry<Long, StreamState> entry : streams.entrySet()) {
|
||||
StreamState state = entry.getValue();
|
||||
if (state.associatedChannel == channel) {
|
||||
streams.remove(entry.getKey());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void chunkBeingSent(long streamId) {
|
||||
StreamState streamState = streams.get(streamId);
|
||||
if (streamState != null) {
|
||||
@ -154,18 +149,43 @@ public class ChunkStreamManager {
|
||||
* <p>If an app ID is provided, only callers who've authenticated with the given app ID will be
|
||||
* allowed to fetch from this stream.
|
||||
*
|
||||
* <p>This method also associates the stream with a single client connection, which is guaranteed
|
||||
* to be the only reader of the stream. Once the connection is closed, the stream will never be
|
||||
* used again, enabling cleanup by `connectionTerminated`.
|
||||
* <p>This stream can be reused when another channel of the client reconnects. If a
|
||||
* stream is not properly closed, it will eventually be cleaned up by `cleanupExpiredShuffleKey`.
|
||||
*/
|
||||
public long registerStream(FileManagedBuffers buffers, Channel channel) {
|
||||
public long registerStream(String shuffleKey, FileManagedBuffers buffers) {
|
||||
long myStreamId = nextStreamId.getAndIncrement();
|
||||
streams.put(myStreamId, new StreamState(buffers, channel));
|
||||
streams.put(myStreamId, new StreamState(shuffleKey, buffers));
|
||||
shuffleStreamIds.compute(
|
||||
shuffleKey,
|
||||
(key, value) -> {
|
||||
if (value == null) {
|
||||
value = ConcurrentHashMap.newKeySet();
|
||||
}
|
||||
value.add(myStreamId);
|
||||
return value;
|
||||
});
|
||||
|
||||
return myStreamId;
|
||||
}
|
||||
|
||||
public void cleanupExpiredShuffleKey(Set<String> expiredShuffleKeys) {
|
||||
for (String expiredShuffleKey : expiredShuffleKeys) {
|
||||
Set<Long> expiredStreamIds = shuffleStreamIds.remove(expiredShuffleKey);
|
||||
|
||||
// normally the expiredStreamIds set is empty, since a streamId is removed once fully read
|
||||
if (expiredStreamIds != null && !expiredStreamIds.isEmpty()) {
|
||||
streams.keySet().removeAll(expiredStreamIds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public int numStreamStates() {
|
||||
return streams.size();
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public long numShuffleSteams() {
|
||||
return shuffleStreamIds.values().stream().flatMap(Set::stream).count();
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,7 +17,9 @@
|
||||
|
||||
package org.apache.celeborn.common.network.server;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.mockito.Mockito;
|
||||
@ -26,7 +28,7 @@ import org.apache.celeborn.common.meta.FileManagedBuffers;
|
||||
|
||||
public class ChunkStreamManagerSuiteJ {
|
||||
@Test
|
||||
public void streamStatesAreFreedWhenConnectionIsClosedEvenIfBufferIteratorThrowsException() {
|
||||
public void testStreamRegisterAndCleanup() {
|
||||
ChunkStreamManager manager = new ChunkStreamManager();
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@ -34,14 +36,30 @@ public class ChunkStreamManagerSuiteJ {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
FileManagedBuffers buffers2 = Mockito.mock(FileManagedBuffers.class);
|
||||
FileManagedBuffers buffers3 = Mockito.mock(FileManagedBuffers.class);
|
||||
FileManagedBuffers buffers4 = Mockito.mock(FileManagedBuffers.class);
|
||||
|
||||
Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
|
||||
manager.registerStream(buffers, dummyChannel);
|
||||
manager.registerStream(buffers2, dummyChannel);
|
||||
manager.registerStream("shuffleKey1", buffers);
|
||||
manager.registerStream("shuffleKey1", buffers2);
|
||||
manager.registerStream("shuffleKey2", buffers3);
|
||||
long stream3 = manager.registerStream("shuffleKey3", buffers4);
|
||||
Assert.assertEquals(4, manager.numStreamStates());
|
||||
Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams());
|
||||
|
||||
Assert.assertEquals(2, manager.numStreamStates());
|
||||
manager.cleanupExpiredShuffleKey(new HashSet<>(Arrays.asList("shuffleKey1", "shuffleKey2")));
|
||||
manager.cleanupExpiredShuffleKey(new HashSet<>(Arrays.asList("none_exit_shuffleKey")));
|
||||
|
||||
manager.connectionTerminated(dummyChannel);
|
||||
assert manager.streams.isEmpty();
|
||||
Assert.assertEquals(1, manager.numStreamStates());
|
||||
Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams());
|
||||
|
||||
// simulate the removal that happens when a stream's buffer is fully read
|
||||
manager.streams.remove(stream3);
|
||||
manager.shuffleStreamIds.get("shuffleKey3").remove(stream3);
|
||||
Assert.assertEquals(0, manager.numStreamStates());
|
||||
Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams());
|
||||
|
||||
// cleanup shuffleKey3
|
||||
manager.cleanupExpiredShuffleKey(new HashSet<>(Arrays.asList("shuffleKey3")));
|
||||
Assert.assertEquals(manager.numStreamStates(), manager.numShuffleSteams());
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,10 +20,10 @@ package org.apache.celeborn.service.deploy.worker
|
||||
import java.io.{FileNotFoundException, IOException}
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.util
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
import com.google.common.base.Throwables
|
||||
import io.netty.buffer.{ByteBuf, Unpooled}
|
||||
import io.netty.util.concurrent.{Future, GenericFutureListener}
|
||||
|
||||
import org.apache.celeborn.common.exception.CelebornException
|
||||
@ -113,7 +113,7 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
|
||||
new NioManagedBuffer(streamHandle.toByteBuffer)))
|
||||
} else {
|
||||
val buffers = new FileManagedBuffers(fileInfo, conf)
|
||||
val streamId = chunkStreamManager.registerStream(buffers, client.getChannel)
|
||||
val streamId = chunkStreamManager.registerStream(shuffleKey, buffers)
|
||||
val streamHandle = new StreamHandle(streamId, fileInfo.numChunks())
|
||||
if (fileInfo.numChunks() == 0)
|
||||
logDebug(s"StreamId $streamId fileName $fileName startMapIndex" +
|
||||
@ -205,7 +205,6 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
|
||||
override def checkRegistered: Boolean = registered.get
|
||||
|
||||
override def channelInactive(client: TransportClient): Unit = {
|
||||
chunkStreamManager.connectionTerminated(client.getChannel)
|
||||
bufferStreamManager.connectionTerminated(client.getChannel)
|
||||
logDebug(s"channel inactive ${client.getSocketAddress}")
|
||||
}
|
||||
@ -213,4 +212,8 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
|
||||
override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
|
||||
logWarning(s"exception caught ${client.getSocketAddress}", cause)
|
||||
}
|
||||
|
||||
def cleanupExpiredShuffleKey(expiredShuffleKeys: util.HashSet[String]): Unit = {
|
||||
chunkStreamManager.cleanupExpiredShuffleKey(expiredShuffleKeys)
|
||||
}
|
||||
}
|
||||
|
||||
@ -436,6 +436,7 @@ private[celeborn] class Worker(
|
||||
}
|
||||
partitionsSorter.cleanup(expiredShuffleKeys)
|
||||
storageManager.cleanupExpiredShuffleKey(expiredShuffleKeys)
|
||||
fetchHandler.cleanupExpiredShuffleKey(expiredShuffleKeys)
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
|
||||
Loading…
Reference in New Issue
Block a user