diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java index 7ba54cbd4..c82feb6d1 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java @@ -86,7 +86,8 @@ public abstract class Message implements Encodable { REGION_FINISH(14), PUSH_DATA_HAND_SHAKE(15), READ_ADD_CREDIT(16), - READ_DATA(17); + READ_DATA(17), + OPEN_STREAM_WITH_CREDIT(18); private final byte id; @@ -144,6 +145,8 @@ public abstract class Message implements Encodable { return READ_ADD_CREDIT; case 17: return READ_DATA; + case 18: + return OPEN_STREAM_WITH_CREDIT; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: @@ -206,6 +209,9 @@ public abstract class Message implements Encodable { case READ_DATA: return ReadData.decode(in); + case OPEN_STREAM_WITH_CREDIT: + return OpenStreamWithCredit.decode(in); + default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java new file mode 100644 index 000000000..d59d86dfc --- /dev/null +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/OpenStreamWithCredit.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.network.protocol; + +import static org.apache.celeborn.common.network.protocol.Message.Type.OPEN_STREAM_WITH_CREDIT; + +import java.nio.charset.StandardCharsets; + +import io.netty.buffer.ByteBuf; + +/** Buffer stream used in Map partition scenario. */ +public final class OpenStreamWithCredit extends RequestMessage { + public final byte[] shuffleKey; + public final byte[] fileName; + public final int startIndex; + public final int endIndex; + public final int initialCredit; + + public OpenStreamWithCredit( + byte[] shuffleKey, byte[] fileName, int startIndex, int endIndex, int initialCredit) { + this.shuffleKey = shuffleKey; + this.fileName = fileName; + this.startIndex = startIndex; + this.endIndex = endIndex; + this.initialCredit = initialCredit; + } + + public OpenStreamWithCredit( + String shuffleKey, String fileName, int startIndex, int endIndex, int initialCredit) { + this( + shuffleKey.getBytes(StandardCharsets.UTF_8), + fileName.getBytes(StandardCharsets.UTF_8), + startIndex, + endIndex, + initialCredit); + } + + @Override + public int encodedLength() { + return 4 + shuffleKey.length + 4 + fileName.length + 4 + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeInt(shuffleKey.length); + buf.writeBytes(shuffleKey); + buf.writeInt(fileName.length); + buf.writeBytes(fileName); + buf.writeInt(startIndex); + buf.writeInt(endIndex); + buf.writeInt(initialCredit); + } + + @Override + public Message.Type type() { + return OPEN_STREAM_WITH_CREDIT; + } + + public static OpenStreamWithCredit decode(ByteBuf in) { + int shuffleKeyLength = in.readInt(); + byte[] tmpShuffleKey = new byte[shuffleKeyLength]; + in.readBytes(tmpShuffleKey); + int fileNameLength = in.readInt(); + byte[] tmpFileName = new byte[fileNameLength]; + in.readBytes(tmpFileName); + int startSubIndex = in.readInt(); + int endSubIndex = in.readInt(); + int initialCredit = in.readInt(); + return new OpenStreamWithCredit( + tmpShuffleKey, tmpFileName, startSubIndex, endSubIndex, initialCredit); + } +} diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala index 6db905159..36f1629a4 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala @@ -33,6 +33,7 @@ import org.apache.celeborn.common.metrics.source.RPCSource import org.apache.celeborn.common.network.buffer.NioManagedBuffer import org.apache.celeborn.common.network.client.TransportClient import org.apache.celeborn.common.network.protocol._ +import org.apache.celeborn.common.network.protocol.Message.Type import org.apache.celeborn.common.network.server.{BaseMessageHandler, BufferStreamManager, ChunkStreamManager} import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf} import org.apache.celeborn.common.protocol.PartitionType @@ -85,17 +86,26 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg def handleOpenStream(client: TransportClient, request: RpcRequest): Unit = { val msg = Message.decode(request.body().nioByteBuffer()) - val openBlocks = msg.asInstanceOf[OpenStream] - val shuffleKey = new String(openBlocks.shuffleKey, StandardCharsets.UTF_8) - val fileName = new String(openBlocks.fileName, StandardCharsets.UTF_8) - val startMapIndex = openBlocks.startMapIndex - val endMapIndex = openBlocks.endMapIndex + val (shuffleKey, fileName) = + if (msg.`type`() == Type.OPEN_STREAM) { + val openStream = msg.asInstanceOf[OpenStream] + ( + new String(openStream.shuffleKey, StandardCharsets.UTF_8), + new String(openStream.fileName, StandardCharsets.UTF_8)) + } else { + val openStreamWithCredit = msg.asInstanceOf[OpenStreamWithCredit] + ( + new String(openStreamWithCredit.shuffleKey, StandardCharsets.UTF_8), + new String(openStreamWithCredit.fileName, StandardCharsets.UTF_8)) + } // metrics start workerSource.startTimer(WorkerSource.OpenStreamTime, shuffleKey) try { var fileInfo = getRawFileInfo(shuffleKey, fileName) try fileInfo.getPartitionType() match { case PartitionType.REDUCE => + val startMapIndex = msg.asInstanceOf[OpenStream].startMapIndex + val endMapIndex = msg.asInstanceOf[OpenStream].endMapIndex if (endMapIndex != Integer.MAX_VALUE) { fileInfo = partitionsSorter.getSortedFileInfo( shuffleKey, @@ -127,6 +137,8 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg } case PartitionType.MAP => // return stream id + val startIndex = msg.asInstanceOf[OpenStreamWithCredit].startIndex + val endIndex = msg.asInstanceOf[OpenStreamWithCredit].endIndex val streamId = bufferStreamManager.registerStream(client.getChannel, fileInfo.getBufferSize) val res = ByteBuffer.allocate(8)