diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java index 5c100002c..b55e3a618 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java @@ -19,6 +19,7 @@ package org.apache.celeborn.plugin.flink.network; import static org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE; import static org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE; +import static org.apache.celeborn.common.protocol.MessageType.TRANSPORTABLE_ERROR_VALUE; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -103,6 +104,9 @@ public class ReadClientHandler extends BaseMessageHandler { case BUFFER_STREAM_END_VALUE: receive(client, BufferStreamEnd.fromProto(transportMessage.getParsedPayload())); break; + case TRANSPORTABLE_ERROR_VALUE: + receive(client, TransportableError.fromProto(transportMessage.getParsedPayload())); + break; } } catch (IOException e) { logger.warn("Failed to process RpcRequest message {}. ", msg, e); diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java index 151895607..7865732c3 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java +++ b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java @@ -36,8 +36,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.celeborn.common.network.buffer.NioManagedBuffer; -import org.apache.celeborn.common.network.protocol.*; +import org.apache.celeborn.common.network.protocol.OneWayMessage; +import org.apache.celeborn.common.network.protocol.PushData; +import org.apache.celeborn.common.network.protocol.PushMergedData; +import org.apache.celeborn.common.network.protocol.RpcRequest; +import org.apache.celeborn.common.network.protocol.StreamChunkSlice; +import org.apache.celeborn.common.network.protocol.TransportMessage; import org.apache.celeborn.common.network.util.NettyUtils; +import org.apache.celeborn.common.protocol.MessageType; +import org.apache.celeborn.common.protocol.PbChunkFetchRequest; import org.apache.celeborn.common.read.FetchRequestInfo; import org.apache.celeborn.common.write.PushRequestInfo; @@ -140,7 +147,19 @@ public class TransportClient implements Closeable { handler.addFetchRequest(streamChunkSlice, info); ChannelFuture channelFuture = - channel.writeAndFlush(new ChunkFetchRequest(streamChunkSlice)).addListener(listener); + channel + .writeAndFlush( + new RpcRequest( + TransportClient.requestId(), + new NioManagedBuffer( + new TransportMessage( + MessageType.CHUNK_FETCH_REQUEST, + PbChunkFetchRequest.newBuilder() + .setStreamChunkSlice(streamChunkSlice.toProto()) + .build() + .toByteArray()) + .toByteBuffer()))) + .addListener(listener); info.setChannelFuture(channelFuture); } diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchFailure.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchFailure.java index 3100824d6..532d0bab0 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchFailure.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchFailure.java @@ -20,7 +20,9 @@ package org.apache.celeborn.common.network.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; -/** Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. */ +import org.apache.celeborn.common.protocol.PbChunkFetchRequest; + +/** Response to {@link PbChunkFetchRequest} when there is an error fetching the chunk. */ public final class ChunkFetchFailure extends ResponseMessage { public final StreamChunkSlice streamChunkSlice; public final String errorString; diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchRequest.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchRequest.java index 2081000ca..28672eac1 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchRequest.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchRequest.java @@ -24,6 +24,7 @@ import io.netty.buffer.ByteBuf; * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single * {@link ResponseMessage} (either success or failure). */ +@Deprecated public final class ChunkFetchRequest extends RequestMessage { public final StreamChunkSlice streamChunkSlice; diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchSuccess.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchSuccess.java index baa663ce6..7d5992003 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchSuccess.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchSuccess.java @@ -22,9 +22,10 @@ import io.netty.buffer.ByteBuf; import org.apache.celeborn.common.network.buffer.ManagedBuffer; import org.apache.celeborn.common.network.buffer.NettyManagedBuffer; +import org.apache.celeborn.common.protocol.PbChunkFetchRequest; /** - * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched. + * Response to {@link PbChunkFetchRequest} when a chunk exists and has been successfully fetched. * *

Note that the server-side encoding of this message does NOT include the buffer itself, as this * may be written by Netty in a more efficient manner (i.e., zero-copy write). Similarly, the diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamChunkSlice.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamChunkSlice.java index a3918e4fd..4771faf87 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamChunkSlice.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamChunkSlice.java @@ -20,6 +20,8 @@ package org.apache.celeborn.common.network.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; +import org.apache.celeborn.common.protocol.PbStreamChunkSlice; + /** Encapsulates a request for a particular chunk of a stream. */ public final class StreamChunkSlice implements Encodable { public final long streamId; @@ -90,4 +92,17 @@ public final class StreamChunkSlice implements Encodable { .add("len", len) .toString(); } + + public PbStreamChunkSlice toProto() { + return PbStreamChunkSlice.newBuilder() + .setStreamId(streamId) + .setChunkIndex(chunkIndex) + .setOffset(offset) + .setLen(len) + .build(); + } + + public static StreamChunkSlice fromProto(PbStreamChunkSlice pb) { + return new StreamChunkSlice(pb.getStreamId(), pb.getChunkIndex(), pb.getOffset(), pb.getLen()); + } } diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java index 87e59151e..8fa07a145 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java @@ -31,12 +31,15 @@ import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.protocol.MessageType; import org.apache.celeborn.common.protocol.PbBacklogAnnouncement; import org.apache.celeborn.common.protocol.PbBufferStreamEnd; +import org.apache.celeborn.common.protocol.PbChunkFetchRequest; import org.apache.celeborn.common.protocol.PbOpenStream; import org.apache.celeborn.common.protocol.PbPushDataHandShake; import org.apache.celeborn.common.protocol.PbReadAddCredit; import org.apache.celeborn.common.protocol.PbRegionFinish; import org.apache.celeborn.common.protocol.PbRegionStart; +import org.apache.celeborn.common.protocol.PbStreamChunkSlice; import org.apache.celeborn.common.protocol.PbStreamHandler; +import org.apache.celeborn.common.protocol.PbTransportableError; public class TransportMessage implements Serializable { private static final long serialVersionUID = -3259000920699629773L; @@ -81,6 +84,12 @@ public class TransportMessage implements Serializable { return (T) PbBufferStreamEnd.parseFrom(payload); case READ_ADD_CREDIT_VALUE: return (T) PbReadAddCredit.parseFrom(payload); + case STREAM_CHUNK_SLICE_VALUE: + return (T) PbStreamChunkSlice.parseFrom(payload); + case CHUNK_FETCH_REQUEST_VALUE: + return (T) PbChunkFetchRequest.parseFrom(payload); + case TRANSPORTABLE_ERROR_VALUE: + return (T) PbTransportableError.parseFrom(payload); default: logger.error("Unexpected type {}", type); } diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java index 8762067c1..262c4ee66 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets; import io.netty.buffer.ByteBuf; +import org.apache.celeborn.common.protocol.PbTransportableError; import org.apache.celeborn.common.util.ExceptionUtils; public class TransportableError extends RequestMessage { @@ -70,4 +71,9 @@ public class TransportableError extends RequestMessage { public String getErrorMessage() { return new String(errorMessage, StandardCharsets.UTF_8); } + + public static TransportableError fromProto(PbTransportableError pb) { + return new TransportableError( + pb.getStreamId(), pb.getMessage().getBytes(StandardCharsets.UTF_8)); + } } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index ffeae3058..3e2e7a54c 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -82,6 +82,9 @@ enum MessageType { BACKLOG_ANNOUNCEMENT = 59; BUFFER_STREAM_END = 60; READ_ADD_CREDIT = 61; + STREAM_CHUNK_SLICE = 62; + CHUNK_FETCH_REQUEST = 63; + TRANSPORTABLE_ERROR = 64; } enum StreamType { @@ -551,3 +554,19 @@ message PbReadAddCredit { int64 streamId = 1; int32 credit = 2; } + +message PbStreamChunkSlice { + int64 streamId = 1; + int32 chunkIndex = 2; + int32 offset = 3; + int32 len = 4; +} + +message PbChunkFetchRequest { + PbStreamChunkSlice streamChunkSlice = 1; +} + +message PbTransportableError { + int64 streamId = 1; + string message = 2; +} diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index 1530e9440..4b299728b 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -37,7 +37,7 @@ import org.apache.celeborn.common.network.client.TransportClient import org.apache.celeborn.common.network.protocol._ import org.apache.celeborn.common.network.server.BaseMessageHandler import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf} -import org.apache.celeborn.common.protocol.{MessageType, PartitionType, PbBufferStreamEnd, PbOpenStream, PbReadAddCredit, PbStreamHandler, StreamType} +import org.apache.celeborn.common.protocol.{MessageType, PartitionType, PbBufferStreamEnd, PbChunkFetchRequest, PbOpenStream, PbReadAddCredit, PbStreamHandler, StreamType} import org.apache.celeborn.common.util.{ExceptionUtils, Utils} import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, CreditStreamManager, PartitionFilesSorter, StorageManager} @@ -93,7 +93,7 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf) case r: ReadAddCredit => handleReadAddCredit(r.getCredit, r.getStreamId) case r: ChunkFetchRequest => - handleChunkFetchRequest(client, r) + handleChunkFetchRequest(client, r.streamChunkSlice, r) case r: RpcRequest => handleRpcRequest(client, r) case unknown: RequestMessage => @@ -125,9 +125,14 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf) isLegacy = false, openStream.getReadLocalShuffle) case bufferStreamEnd: PbBufferStreamEnd => - handleEndStreamFromClient(bufferStreamEnd) + handleEndStreamFromClient(bufferStreamEnd.getStreamId, bufferStreamEnd.getStreamType) case readAddCredit: PbReadAddCredit => handleReadAddCredit(readAddCredit.getCredit, readAddCredit.getStreamId) + case chunkFetchRequest: PbChunkFetchRequest => + handleChunkFetchRequest( + client, + StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice), + rpcRequest) case message: GeneratedMessageV3 => logError(s"Unknown message $message") } @@ -318,18 +323,18 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf) } def handleEndStreamFromClient(streamId: Long): Unit = { - creditStreamManager.notifyStreamEndByClient(streamId) + handleEndStreamFromClient(streamId, StreamType.CreditStream) } - def handleEndStreamFromClient(req: PbBufferStreamEnd): Unit = { - req.getStreamType match { + def handleEndStreamFromClient(streamId: Long, streamType: StreamType): Unit = { + streamType match { case StreamType.ChunkStream => - val (shuffleKey, fileName) = chunkStreamManager.getShuffleKeyAndFileName(req.getStreamId) - getRawFileInfo(shuffleKey, fileName).closeStream(req.getStreamId) + val (shuffleKey, fileName) = chunkStreamManager.getShuffleKeyAndFileName(streamId) + getRawFileInfo(shuffleKey, fileName).closeStream(streamId) case StreamType.CreditStream => - creditStreamManager.notifyStreamEndByClient(req.getStreamId) + creditStreamManager.notifyStreamEndByClient(streamId) case _ => - logError(s"Received a PbBufferStreamEnd message with unknown type ${req.getStreamType}") + logError(s"Received a PbBufferStreamEnd message with unknown type $streamType") } } @@ -337,9 +342,12 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf) creditStreamManager.addCredit(credit, streamId) } - def handleChunkFetchRequest(client: TransportClient, req: ChunkFetchRequest): Unit = { + def handleChunkFetchRequest( + client: TransportClient, + streamChunkSlice: StreamChunkSlice, + req: RequestMessage): Unit = { logDebug(s"Received req from ${NettyUtils.getRemoteAddress(client.getChannel)}" + - s" to fetch block ${req.streamChunkSlice}") + s" to fetch block $streamChunkSlice") maxChunkBeingTransferred.foreach { threshold => val chunksBeingTransferred = chunkStreamManager.chunksBeingTransferred // take high cpu usage @@ -348,35 +356,35 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf) s"$chunksBeingTransferred exceeds ${MAX_CHUNKS_BEING_TRANSFERRED.key} " + s"${Utils.bytesToString(threshold)}." logError(message) - client.getChannel.writeAndFlush(new ChunkFetchFailure(req.streamChunkSlice, message)) + client.getChannel.writeAndFlush(new ChunkFetchFailure(streamChunkSlice, message)) return } } workerSource.startTimer(WorkerSource.FETCH_CHUNK_TIME, req.toString) - val fetchTimeMetric = chunkStreamManager.getFetchTimeMetric(req.streamChunkSlice.streamId) + val fetchTimeMetric = chunkStreamManager.getFetchTimeMetric(streamChunkSlice.streamId) val fetchBeginTime = System.nanoTime() try { val buf = chunkStreamManager.getChunk( - req.streamChunkSlice.streamId, - req.streamChunkSlice.chunkIndex, - req.streamChunkSlice.offset, - req.streamChunkSlice.len) - chunkStreamManager.chunkBeingSent(req.streamChunkSlice.streamId) - client.getChannel.writeAndFlush(new ChunkFetchSuccess(req.streamChunkSlice, buf)) + streamChunkSlice.streamId, + streamChunkSlice.chunkIndex, + streamChunkSlice.offset, + streamChunkSlice.len) + chunkStreamManager.chunkBeingSent(streamChunkSlice.streamId) + client.getChannel.writeAndFlush(new ChunkFetchSuccess(streamChunkSlice, buf)) .addListener(new GenericFutureListener[Future[_ >: Void]] { override def operationComplete(future: Future[_ >: Void]): Unit = { - if (future.isSuccess()) { + if (future.isSuccess) { if (log.isDebugEnabled) { logDebug( - s"Sending ChunkFetchSuccess operation succeeded, chunk ${req.streamChunkSlice}") + s"Sending ChunkFetchSuccess operation succeeded, chunk $streamChunkSlice") } } else { logError( - s"Sending ChunkFetchSuccess operation failed, chunk ${req.streamChunkSlice}", + s"Sending ChunkFetchSuccess operation failed, chunk $streamChunkSlice", future.cause()) } - chunkStreamManager.chunkSent(req.streamChunkSlice.streamId) + chunkStreamManager.chunkSent(streamChunkSlice.streamId) if (fetchTimeMetric != null) { fetchTimeMetric.update(System.nanoTime() - fetchBeginTime) } @@ -386,11 +394,11 @@ class FetchHandler(val conf: CelebornConf, val transportConf: TransportConf) } catch { case e: Exception => logError( - s"Error opening block ${req.streamChunkSlice} for request from " + + s"Error opening block $streamChunkSlice for request from " + NettyUtils.getRemoteAddress(client.getChannel), e) client.getChannel.writeAndFlush(new ChunkFetchFailure( - req.streamChunkSlice, + streamChunkSlice, Throwables.getStackTraceAsString(e))) workerSource.stopTimer(WorkerSource.FETCH_CHUNK_TIME, req.toString) } diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java index 99a188ad3..7234c21fa 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java @@ -47,19 +47,19 @@ import org.apache.celeborn.common.meta.FileInfo; import org.apache.celeborn.common.network.buffer.NioManagedBuffer; import org.apache.celeborn.common.network.client.TransportClient; import org.apache.celeborn.common.network.client.TransportResponseHandler; -import org.apache.celeborn.common.network.protocol.ChunkFetchRequest; import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess; import org.apache.celeborn.common.network.protocol.Message; import org.apache.celeborn.common.network.protocol.OpenStream; import org.apache.celeborn.common.network.protocol.RpcRequest; import org.apache.celeborn.common.network.protocol.RpcResponse; -import org.apache.celeborn.common.network.protocol.StreamChunkSlice; import org.apache.celeborn.common.network.protocol.StreamHandle; import org.apache.celeborn.common.network.protocol.TransportMessage; import org.apache.celeborn.common.network.util.TransportConf; import org.apache.celeborn.common.protocol.MessageType; import org.apache.celeborn.common.protocol.PbBufferStreamEnd; +import org.apache.celeborn.common.protocol.PbChunkFetchRequest; import org.apache.celeborn.common.protocol.PbOpenStream; +import org.apache.celeborn.common.protocol.PbStreamChunkSlice; import org.apache.celeborn.common.protocol.PbStreamHandler; import org.apache.celeborn.common.protocol.StreamType; import org.apache.celeborn.common.protocol.TransportModuleConstants; @@ -339,8 +339,21 @@ public class FetchHandlerSuiteJ { for (int chunkIndex = 0; chunkIndex < streamHandler.getNumChunks(); chunkIndex++) { fetchHandler.receive( client, - new ChunkFetchRequest( - new StreamChunkSlice(streamHandler.getStreamId(), chunkIndex, 0, Integer.MAX_VALUE))); + new RpcRequest( + TransportClient.requestId(), + new NioManagedBuffer( + new TransportMessage( + MessageType.CHUNK_FETCH_REQUEST, + PbChunkFetchRequest.newBuilder() + .setStreamChunkSlice( + PbStreamChunkSlice.newBuilder() + .setStreamId(streamHandler.getStreamId()) + .setChunkIndex(chunkIndex) + .setOffset(0) + .setLen(Integer.MAX_VALUE)) + .build() + .toByteArray()) + .toByteBuffer()))); ChunkFetchSuccess chunkFetchSuccess = channel.readOutbound(); chunkFetchSuccess.body().retain(); // chunk size 8m diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java index 11223a0d3..85ce86367 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java @@ -39,10 +39,16 @@ import org.apache.celeborn.common.network.client.ChunkReceivedCallback; import org.apache.celeborn.common.network.client.RpcResponseCallback; import org.apache.celeborn.common.network.client.TransportClient; import org.apache.celeborn.common.network.client.TransportClientFactory; -import org.apache.celeborn.common.network.protocol.*; +import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess; +import org.apache.celeborn.common.network.protocol.RequestMessage; +import org.apache.celeborn.common.network.protocol.RpcRequest; +import org.apache.celeborn.common.network.protocol.RpcResponse; +import org.apache.celeborn.common.network.protocol.StreamChunkSlice; +import org.apache.celeborn.common.network.protocol.TransportMessage; import org.apache.celeborn.common.network.server.BaseMessageHandler; import org.apache.celeborn.common.network.server.TransportServer; import org.apache.celeborn.common.network.util.TransportConf; +import org.apache.celeborn.common.protocol.PbChunkFetchRequest; import org.apache.celeborn.service.deploy.worker.storage.ChunkStreamManager; /** @@ -196,7 +202,15 @@ public class RequestTimeoutIntegrationSuiteJ { new BaseMessageHandler() { @Override public void receive(TransportClient client, RequestMessage msg) { - StreamChunkSlice slice = ((ChunkFetchRequest) msg).streamChunkSlice; + PbChunkFetchRequest chunkFetchRequest; + try { + chunkFetchRequest = + TransportMessage.fromByteBuffer(msg.body().nioByteBuffer()).getParsedPayload(); + } catch (IOException e) { + throw new RuntimeException(e); + } + StreamChunkSlice slice = + StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice()); ManagedBuffer buf = manager.getChunk(slice.streamId, slice.chunkIndex, slice.offset, slice.len); client.getChannel().writeAndFlush(new ChunkFetchSuccess(slice, buf)); diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java index 9c3fd7c32..6a1df6747 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java @@ -21,6 +21,7 @@ import static org.apache.celeborn.common.util.JavaUtils.getLocalHost; import static org.junit.Assert.*; import java.io.File; +import java.io.IOException; import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.util.*; @@ -41,13 +42,14 @@ import org.apache.celeborn.common.network.buffer.NioManagedBuffer; import org.apache.celeborn.common.network.client.ChunkReceivedCallback; import org.apache.celeborn.common.network.client.TransportClient; import org.apache.celeborn.common.network.client.TransportClientFactory; -import org.apache.celeborn.common.network.protocol.ChunkFetchRequest; import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess; import org.apache.celeborn.common.network.protocol.RequestMessage; import org.apache.celeborn.common.network.protocol.StreamChunkSlice; +import org.apache.celeborn.common.network.protocol.TransportMessage; import org.apache.celeborn.common.network.server.BaseMessageHandler; import org.apache.celeborn.common.network.server.TransportServer; import org.apache.celeborn.common.network.util.TransportConf; +import org.apache.celeborn.common.protocol.PbChunkFetchRequest; public class ChunkFetchIntegrationSuiteJ { static final long STREAM_ID = 1; @@ -106,7 +108,15 @@ public class ChunkFetchIntegrationSuiteJ { new BaseMessageHandler() { @Override public void receive(TransportClient client, RequestMessage msg) { - StreamChunkSlice slice = ((ChunkFetchRequest) msg).streamChunkSlice; + PbChunkFetchRequest chunkFetchRequest; + try { + chunkFetchRequest = + TransportMessage.fromByteBuffer(msg.body().nioByteBuffer()).getParsedPayload(); + } catch (IOException e) { + throw new RuntimeException(e); + } + StreamChunkSlice slice = + StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice()); ManagedBuffer buf = chunkStreamManager.getChunk( slice.streamId, slice.chunkIndex, slice.offset, slice.len);