[CELEBORN-772] Convert StreamChunkSlice, ChunkFetchRequest, TransportableError to PB

### What changes were proposed in this pull request?

`StreamChunkSlice`, `ChunkFetchRequest` and `TransportableError` should be merged into transport messages to enhance Celeborn's compatibility.

### Why are the changes needed?

1. Improves the flexibility of Celeborn's transport layer to accommodate RPC changes.
2. Makes it compatible with the 0.2 client.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- `FetchHandlerSuiteJ`
- `RequestTimeoutIntegrationSuiteJ`
- `ChunkFetchIntegrationSuiteJ`

Closes #1982 from SteNicholas/CELEBORN-772.

Authored-by: SteNicholas <programgeek@163.com>
Signed-off-by: Shuang <lvshuang.tb@gmail.com>
This commit is contained in:
SteNicholas 2023-10-17 11:12:01 +08:00 committed by Shuang
parent bfa341c32f
commit 9244cf2cf2
13 changed files with 159 additions and 38 deletions

View File

@ -19,6 +19,7 @@ package org.apache.celeborn.plugin.flink.network;
import static org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.TRANSPORTABLE_ERROR_VALUE;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
@ -103,6 +104,9 @@ public class ReadClientHandler extends BaseMessageHandler {
case BUFFER_STREAM_END_VALUE:
receive(client, BufferStreamEnd.fromProto(transportMessage.getParsedPayload()));
break;
case TRANSPORTABLE_ERROR_VALUE:
receive(client, TransportableError.fromProto(transportMessage.getParsedPayload()));
break;
}
} catch (IOException e) {
logger.warn("Failed to process RpcRequest message {}. ", msg, e);

View File

@ -36,8 +36,15 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
import org.apache.celeborn.common.network.protocol.*;
import org.apache.celeborn.common.network.protocol.OneWayMessage;
import org.apache.celeborn.common.network.protocol.PushData;
import org.apache.celeborn.common.network.protocol.PushMergedData;
import org.apache.celeborn.common.network.protocol.RpcRequest;
import org.apache.celeborn.common.network.protocol.StreamChunkSlice;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
import org.apache.celeborn.common.read.FetchRequestInfo;
import org.apache.celeborn.common.write.PushRequestInfo;
@ -140,7 +147,19 @@ public class TransportClient implements Closeable {
handler.addFetchRequest(streamChunkSlice, info);
ChannelFuture channelFuture =
channel.writeAndFlush(new ChunkFetchRequest(streamChunkSlice)).addListener(listener);
channel
.writeAndFlush(
new RpcRequest(
TransportClient.requestId(),
new NioManagedBuffer(
new TransportMessage(
MessageType.CHUNK_FETCH_REQUEST,
PbChunkFetchRequest.newBuilder()
.setStreamChunkSlice(streamChunkSlice.toProto())
.build()
.toByteArray())
.toByteBuffer())))
.addListener(listener);
info.setChannelFuture(channelFuture);
}

View File

@ -20,7 +20,9 @@ package org.apache.celeborn.common.network.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
/** Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. */
import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
/** Response to {@link PbChunkFetchRequest} when there is an error fetching the chunk. */
public final class ChunkFetchFailure extends ResponseMessage {
public final StreamChunkSlice streamChunkSlice;
public final String errorString;

View File

@ -24,6 +24,7 @@ import io.netty.buffer.ByteBuf;
* Request to fetch a sequence of a single chunk of a stream. This will correspond to a single
* {@link ResponseMessage} (either success or failure).
*/
@Deprecated
public final class ChunkFetchRequest extends RequestMessage {
public final StreamChunkSlice streamChunkSlice;

View File

@ -22,9 +22,10 @@ import io.netty.buffer.ByteBuf;
import org.apache.celeborn.common.network.buffer.ManagedBuffer;
import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
/**
* Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched.
* Response to {@link PbChunkFetchRequest} when a chunk exists and has been successfully fetched.
*
* <p>Note that the server-side encoding of this message does NOT include the buffer itself, as this
* may be written by Netty in a more efficient manner (i.e., zero-copy write). Similarly, the

View File

@ -20,6 +20,8 @@ package org.apache.celeborn.common.network.protocol;
import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
import org.apache.celeborn.common.protocol.PbStreamChunkSlice;
/** Encapsulates a request for a particular chunk of a stream. */
public final class StreamChunkSlice implements Encodable {
public final long streamId;
@ -90,4 +92,17 @@ public final class StreamChunkSlice implements Encodable {
.add("len", len)
.toString();
}
/**
 * Converts this slice into its protobuf form.
 *
 * @return a {@link PbStreamChunkSlice} carrying this slice's stream id, chunk index, offset and
 *     length
 */
public PbStreamChunkSlice toProto() {
  PbStreamChunkSlice.Builder builder = PbStreamChunkSlice.newBuilder();
  builder.setStreamId(streamId);
  builder.setChunkIndex(chunkIndex);
  builder.setOffset(offset);
  builder.setLen(len);
  return builder.build();
}
/**
 * Rebuilds a {@link StreamChunkSlice} from its protobuf form.
 *
 * @param pb the protobuf message to read the stream id, chunk index, offset and length from
 * @return a new slice holding the same four values
 */
public static StreamChunkSlice fromProto(PbStreamChunkSlice pb) {
  long streamId = pb.getStreamId();
  int chunkIndex = pb.getChunkIndex();
  int offset = pb.getOffset();
  int len = pb.getLen();
  return new StreamChunkSlice(streamId, chunkIndex, offset, len);
}
}

View File

@ -31,12 +31,15 @@ import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbPushDataHandShake;
import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.common.protocol.PbRegionFinish;
import org.apache.celeborn.common.protocol.PbRegionStart;
import org.apache.celeborn.common.protocol.PbStreamChunkSlice;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.protocol.PbTransportableError;
public class TransportMessage implements Serializable {
private static final long serialVersionUID = -3259000920699629773L;
@ -81,6 +84,12 @@ public class TransportMessage implements Serializable {
return (T) PbBufferStreamEnd.parseFrom(payload);
case READ_ADD_CREDIT_VALUE:
return (T) PbReadAddCredit.parseFrom(payload);
case STREAM_CHUNK_SLICE_VALUE:
return (T) PbStreamChunkSlice.parseFrom(payload);
case CHUNK_FETCH_REQUEST_VALUE:
return (T) PbChunkFetchRequest.parseFrom(payload);
case TRANSPORTABLE_ERROR_VALUE:
return (T) PbTransportableError.parseFrom(payload);
default:
logger.error("Unexpected type {}", type);
}

View File

@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets;
import io.netty.buffer.ByteBuf;
import org.apache.celeborn.common.protocol.PbTransportableError;
import org.apache.celeborn.common.util.ExceptionUtils;
public class TransportableError extends RequestMessage {
@ -70,4 +71,9 @@ public class TransportableError extends RequestMessage {
/**
 * Returns the stored error message bytes decoded as UTF-8 text.
 *
 * @return the error message as a string
 */
public String getErrorMessage() {
  final String decoded = new String(errorMessage, StandardCharsets.UTF_8);
  return decoded;
}
/**
 * Builds a {@link TransportableError} from its protobuf form.
 *
 * @param pb the protobuf message supplying the stream id and the error text
 * @return a new error holding the stream id and the message encoded as UTF-8 bytes
 */
public static TransportableError fromProto(PbTransportableError pb) {
  long streamId = pb.getStreamId();
  byte[] messageBytes = pb.getMessage().getBytes(StandardCharsets.UTF_8);
  return new TransportableError(streamId, messageBytes);
}
}

View File

@ -82,6 +82,9 @@ enum MessageType {
BACKLOG_ANNOUNCEMENT = 59;
BUFFER_STREAM_END = 60;
READ_ADD_CREDIT = 61;
STREAM_CHUNK_SLICE = 62;
CHUNK_FETCH_REQUEST = 63;
TRANSPORTABLE_ERROR = 64;
}
enum StreamType {
@ -551,3 +554,19 @@ message PbReadAddCredit {
int64 streamId = 1;
int32 credit = 2;
}
// Protobuf form of StreamChunkSlice: identifies one slice of one chunk of a stream.
message PbStreamChunkSlice {
// Id of the stream the chunk belongs to.
int64 streamId = 1;
// Index of the chunk within the stream.
int32 chunkIndex = 2;
// Offset of the slice (unit not shown here; presumably bytes within the chunk, TODO confirm).
int32 offset = 3;
// Length of the slice (unit not shown here; presumably bytes, TODO confirm).
int32 len = 4;
}
// Request to fetch a chunk slice; carried as the payload of a CHUNK_FETCH_REQUEST message.
message PbChunkFetchRequest {
// The slice being requested.
PbStreamChunkSlice streamChunkSlice = 1;
}
// Protobuf form of TransportableError; carried as the payload of a TRANSPORTABLE_ERROR message.
message PbTransportableError {
// Id of the stream the error refers to.
int64 streamId = 1;
// Error text; converted to UTF-8 bytes when turned into a TransportableError.
string message = 2;
}

View File

@ -37,7 +37,7 @@ import org.apache.celeborn.common.network.client.TransportClient
import org.apache.celeborn.common.network.protocol._
import org.apache.celeborn.common.network.server.BaseMessageHandler
import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
import org.apache.celeborn.common.protocol.{MessageType, PartitionType, PbBufferStreamEnd, PbOpenStream, PbReadAddCredit, PbStreamHandler, StreamType}
import org.apache.celeborn.common.protocol.{MessageType, PartitionType, PbBufferStreamEnd, PbChunkFetchRequest, PbOpenStream, PbReadAddCredit, PbStreamHandler, StreamType}
import org.apache.celeborn.common.util.{ExceptionUtils, Utils}
import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, CreditStreamManager, PartitionFilesSorter, StorageManager}
@ -93,7 +93,7 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf)
case r: ReadAddCredit =>
handleReadAddCredit(r.getCredit, r.getStreamId)
case r: ChunkFetchRequest =>
handleChunkFetchRequest(client, r)
handleChunkFetchRequest(client, r.streamChunkSlice, r)
case r: RpcRequest =>
handleRpcRequest(client, r)
case unknown: RequestMessage =>
@ -125,9 +125,14 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf)
isLegacy = false,
openStream.getReadLocalShuffle)
case bufferStreamEnd: PbBufferStreamEnd =>
handleEndStreamFromClient(bufferStreamEnd)
handleEndStreamFromClient(bufferStreamEnd.getStreamId, bufferStreamEnd.getStreamType)
case readAddCredit: PbReadAddCredit =>
handleReadAddCredit(readAddCredit.getCredit, readAddCredit.getStreamId)
case chunkFetchRequest: PbChunkFetchRequest =>
handleChunkFetchRequest(
client,
StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice),
rpcRequest)
case message: GeneratedMessageV3 =>
logError(s"Unknown message $message")
}
@ -318,18 +323,18 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf)
}
def handleEndStreamFromClient(streamId: Long): Unit = {
creditStreamManager.notifyStreamEndByClient(streamId)
handleEndStreamFromClient(streamId, StreamType.CreditStream)
}
def handleEndStreamFromClient(req: PbBufferStreamEnd): Unit = {
req.getStreamType match {
def handleEndStreamFromClient(streamId: Long, streamType: StreamType): Unit = {
streamType match {
case StreamType.ChunkStream =>
val (shuffleKey, fileName) = chunkStreamManager.getShuffleKeyAndFileName(req.getStreamId)
getRawFileInfo(shuffleKey, fileName).closeStream(req.getStreamId)
val (shuffleKey, fileName) = chunkStreamManager.getShuffleKeyAndFileName(streamId)
getRawFileInfo(shuffleKey, fileName).closeStream(streamId)
case StreamType.CreditStream =>
creditStreamManager.notifyStreamEndByClient(req.getStreamId)
creditStreamManager.notifyStreamEndByClient(streamId)
case _ =>
logError(s"Received a PbBufferStreamEnd message with unknown type ${req.getStreamType}")
logError(s"Received a PbBufferStreamEnd message with unknown type $streamType")
}
}
@ -337,9 +342,12 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf)
creditStreamManager.addCredit(credit, streamId)
}
def handleChunkFetchRequest(client: TransportClient, req: ChunkFetchRequest): Unit = {
def handleChunkFetchRequest(
client: TransportClient,
streamChunkSlice: StreamChunkSlice,
req: RequestMessage): Unit = {
logDebug(s"Received req from ${NettyUtils.getRemoteAddress(client.getChannel)}" +
s" to fetch block ${req.streamChunkSlice}")
s" to fetch block $streamChunkSlice")
maxChunkBeingTransferred.foreach { threshold =>
val chunksBeingTransferred = chunkStreamManager.chunksBeingTransferred // take high cpu usage
@ -348,35 +356,35 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf)
s"$chunksBeingTransferred exceeds ${MAX_CHUNKS_BEING_TRANSFERRED.key} " +
s"${Utils.bytesToString(threshold)}."
logError(message)
client.getChannel.writeAndFlush(new ChunkFetchFailure(req.streamChunkSlice, message))
client.getChannel.writeAndFlush(new ChunkFetchFailure(streamChunkSlice, message))
return
}
}
workerSource.startTimer(WorkerSource.FETCH_CHUNK_TIME, req.toString)
val fetchTimeMetric = chunkStreamManager.getFetchTimeMetric(req.streamChunkSlice.streamId)
val fetchTimeMetric = chunkStreamManager.getFetchTimeMetric(streamChunkSlice.streamId)
val fetchBeginTime = System.nanoTime()
try {
val buf = chunkStreamManager.getChunk(
req.streamChunkSlice.streamId,
req.streamChunkSlice.chunkIndex,
req.streamChunkSlice.offset,
req.streamChunkSlice.len)
chunkStreamManager.chunkBeingSent(req.streamChunkSlice.streamId)
client.getChannel.writeAndFlush(new ChunkFetchSuccess(req.streamChunkSlice, buf))
streamChunkSlice.streamId,
streamChunkSlice.chunkIndex,
streamChunkSlice.offset,
streamChunkSlice.len)
chunkStreamManager.chunkBeingSent(streamChunkSlice.streamId)
client.getChannel.writeAndFlush(new ChunkFetchSuccess(streamChunkSlice, buf))
.addListener(new GenericFutureListener[Future[_ >: Void]] {
override def operationComplete(future: Future[_ >: Void]): Unit = {
if (future.isSuccess()) {
if (future.isSuccess) {
if (log.isDebugEnabled) {
logDebug(
s"Sending ChunkFetchSuccess operation succeeded, chunk ${req.streamChunkSlice}")
s"Sending ChunkFetchSuccess operation succeeded, chunk $streamChunkSlice")
}
} else {
logError(
s"Sending ChunkFetchSuccess operation failed, chunk ${req.streamChunkSlice}",
s"Sending ChunkFetchSuccess operation failed, chunk $streamChunkSlice",
future.cause())
}
chunkStreamManager.chunkSent(req.streamChunkSlice.streamId)
chunkStreamManager.chunkSent(streamChunkSlice.streamId)
if (fetchTimeMetric != null) {
fetchTimeMetric.update(System.nanoTime() - fetchBeginTime)
}
@ -386,11 +394,11 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf)
} catch {
case e: Exception =>
logError(
s"Error opening block ${req.streamChunkSlice} for request from " +
s"Error opening block $streamChunkSlice for request from " +
NettyUtils.getRemoteAddress(client.getChannel),
e)
client.getChannel.writeAndFlush(new ChunkFetchFailure(
req.streamChunkSlice,
streamChunkSlice,
Throwables.getStackTraceAsString(e)))
workerSource.stopTimer(WorkerSource.FETCH_CHUNK_TIME, req.toString)
}

View File

@ -47,19 +47,19 @@ import org.apache.celeborn.common.meta.FileInfo;
import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportResponseHandler;
import org.apache.celeborn.common.network.protocol.ChunkFetchRequest;
import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess;
import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.protocol.OpenStream;
import org.apache.celeborn.common.network.protocol.RpcRequest;
import org.apache.celeborn.common.network.protocol.RpcResponse;
import org.apache.celeborn.common.network.protocol.StreamChunkSlice;
import org.apache.celeborn.common.network.protocol.StreamHandle;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbStreamChunkSlice;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.protocol.StreamType;
import org.apache.celeborn.common.protocol.TransportModuleConstants;
@ -339,8 +339,21 @@ public class FetchHandlerSuiteJ {
for (int chunkIndex = 0; chunkIndex < streamHandler.getNumChunks(); chunkIndex++) {
fetchHandler.receive(
client,
new ChunkFetchRequest(
new StreamChunkSlice(streamHandler.getStreamId(), chunkIndex, 0, Integer.MAX_VALUE)));
new RpcRequest(
TransportClient.requestId(),
new NioManagedBuffer(
new TransportMessage(
MessageType.CHUNK_FETCH_REQUEST,
PbChunkFetchRequest.newBuilder()
.setStreamChunkSlice(
PbStreamChunkSlice.newBuilder()
.setStreamId(streamHandler.getStreamId())
.setChunkIndex(chunkIndex)
.setOffset(0)
.setLen(Integer.MAX_VALUE))
.build()
.toByteArray())
.toByteBuffer())));
ChunkFetchSuccess chunkFetchSuccess = channel.readOutbound();
chunkFetchSuccess.body().retain();
// chunk size 8m

View File

@ -39,10 +39,16 @@ import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.*;
import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.RpcRequest;
import org.apache.celeborn.common.network.protocol.RpcResponse;
import org.apache.celeborn.common.network.protocol.StreamChunkSlice;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.network.server.TransportServer;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
import org.apache.celeborn.service.deploy.worker.storage.ChunkStreamManager;
/**
@ -196,7 +202,15 @@ public class RequestTimeoutIntegrationSuiteJ {
new BaseMessageHandler() {
@Override
public void receive(TransportClient client, RequestMessage msg) {
StreamChunkSlice slice = ((ChunkFetchRequest) msg).streamChunkSlice;
PbChunkFetchRequest chunkFetchRequest;
try {
chunkFetchRequest =
TransportMessage.fromByteBuffer(msg.body().nioByteBuffer()).getParsedPayload();
} catch (IOException e) {
throw new RuntimeException(e);
}
StreamChunkSlice slice =
StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice());
ManagedBuffer buf =
manager.getChunk(slice.streamId, slice.chunkIndex, slice.offset, slice.len);
client.getChannel().writeAndFlush(new ChunkFetchSuccess(slice, buf));

View File

@ -21,6 +21,7 @@ import static org.apache.celeborn.common.util.JavaUtils.getLocalHost;
import static org.junit.Assert.*;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.util.*;
@ -41,13 +42,14 @@ import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.ChunkFetchRequest;
import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.StreamChunkSlice;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.network.server.TransportServer;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
public class ChunkFetchIntegrationSuiteJ {
static final long STREAM_ID = 1;
@ -106,7 +108,15 @@ public class ChunkFetchIntegrationSuiteJ {
new BaseMessageHandler() {
@Override
public void receive(TransportClient client, RequestMessage msg) {
StreamChunkSlice slice = ((ChunkFetchRequest) msg).streamChunkSlice;
PbChunkFetchRequest chunkFetchRequest;
try {
chunkFetchRequest =
TransportMessage.fromByteBuffer(msg.body().nioByteBuffer()).getParsedPayload();
} catch (IOException e) {
throw new RuntimeException(e);
}
StreamChunkSlice slice =
StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice());
ManagedBuffer buf =
chunkStreamManager.getChunk(
slice.streamId, slice.chunkIndex, slice.offset, slice.len);