[CELEBORN-278] Add openStreamWithCredit RPC. (#1214)

This commit is contained in:
Ethan Feng 2023-02-16 14:07:13 +08:00 committed by GitHub
parent 2c508dae0f
commit 534853bf8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 111 additions and 6 deletions

View File

@ -86,7 +86,8 @@ public abstract class Message implements Encodable {
REGION_FINISH(14),
PUSH_DATA_HAND_SHAKE(15),
READ_ADD_CREDIT(16),
READ_DATA(17);
READ_DATA(17),
OPEN_STREAM_WITH_CREDIT(18);
private final byte id;
@ -144,6 +145,8 @@ public abstract class Message implements Encodable {
return READ_ADD_CREDIT;
case 17:
return READ_DATA;
case 18:
return OPEN_STREAM_WITH_CREDIT;
case -1:
throw new IllegalArgumentException("User type messages cannot be decoded.");
default:
@ -206,6 +209,9 @@ public abstract class Message implements Encodable {
case READ_DATA:
return ReadData.decode(in);
case OPEN_STREAM_WITH_CREDIT:
return OpenStreamWithCredit.decode(in);
default:
throw new IllegalArgumentException("Unexpected message type: " + msgType);
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common.network.protocol;
import static org.apache.celeborn.common.network.protocol.Message.Type.OPEN_STREAM_WITH_CREDIT;
import java.nio.charset.StandardCharsets;
import io.netty.buffer.ByteBuf;
/** Buffer stream used in Map partition scenario. */
public final class OpenStreamWithCredit extends RequestMessage {
public final byte[] shuffleKey;
public final byte[] fileName;
public final int startIndex;
public final int endIndex;
public final int initialCredit;
public OpenStreamWithCredit(
byte[] shuffleKey, byte[] fileName, int startIndex, int endIndex, int initialCredit) {
this.shuffleKey = shuffleKey;
this.fileName = fileName;
this.startIndex = startIndex;
this.endIndex = endIndex;
this.initialCredit = initialCredit;
}
public OpenStreamWithCredit(
String shuffleKey, String fileName, int startIndex, int endIndex, int initialCredit) {
this(
shuffleKey.getBytes(StandardCharsets.UTF_8),
fileName.getBytes(StandardCharsets.UTF_8),
startIndex,
endIndex,
initialCredit);
}
@Override
public int encodedLength() {
return 4 + shuffleKey.length + 4 + fileName.length + 4 + 4 + 4;
}
@Override
public void encode(ByteBuf buf) {
buf.writeInt(shuffleKey.length);
buf.writeBytes(shuffleKey);
buf.writeInt(fileName.length);
buf.writeBytes(fileName);
buf.writeInt(startIndex);
buf.writeInt(endIndex);
buf.writeInt(initialCredit);
}
@Override
public Message.Type type() {
return OPEN_STREAM_WITH_CREDIT;
}
public static OpenStreamWithCredit decode(ByteBuf in) {
int shuffleKeyLength = in.readInt();
byte[] tmpShuffleKey = new byte[shuffleKeyLength];
in.readBytes(tmpShuffleKey);
int fileNameLength = in.readInt();
byte[] tmpFileName = new byte[fileNameLength];
in.readBytes(tmpFileName);
int startSubIndex = in.readInt();
int endSubIndex = in.readInt();
int initialCredit = in.readInt();
return new OpenStreamWithCredit(
tmpShuffleKey, tmpFileName, startSubIndex, endSubIndex, initialCredit);
}
}

View File

@ -33,6 +33,7 @@ import org.apache.celeborn.common.metrics.source.RPCSource
import org.apache.celeborn.common.network.buffer.NioManagedBuffer
import org.apache.celeborn.common.network.client.TransportClient
import org.apache.celeborn.common.network.protocol._
import org.apache.celeborn.common.network.protocol.Message.Type
import org.apache.celeborn.common.network.server.{BaseMessageHandler, BufferStreamManager, ChunkStreamManager}
import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
import org.apache.celeborn.common.protocol.PartitionType
@ -85,17 +86,26 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
def handleOpenStream(client: TransportClient, request: RpcRequest): Unit = {
val msg = Message.decode(request.body().nioByteBuffer())
val openBlocks = msg.asInstanceOf[OpenStream]
val shuffleKey = new String(openBlocks.shuffleKey, StandardCharsets.UTF_8)
val fileName = new String(openBlocks.fileName, StandardCharsets.UTF_8)
val startMapIndex = openBlocks.startMapIndex
val endMapIndex = openBlocks.endMapIndex
val (shuffleKey, fileName) =
if (msg.`type`() == Type.OPEN_STREAM) {
val openStream = msg.asInstanceOf[OpenStream]
(
new String(openStream.shuffleKey, StandardCharsets.UTF_8),
new String(openStream.fileName, StandardCharsets.UTF_8))
} else {
val openStreamWithCredit = msg.asInstanceOf[OpenStreamWithCredit]
(
new String(openStreamWithCredit.shuffleKey, StandardCharsets.UTF_8),
new String(openStreamWithCredit.fileName, StandardCharsets.UTF_8))
}
// metrics start
workerSource.startTimer(WorkerSource.OpenStreamTime, shuffleKey)
try {
var fileInfo = getRawFileInfo(shuffleKey, fileName)
try fileInfo.getPartitionType() match {
case PartitionType.REDUCE =>
val startMapIndex = msg.asInstanceOf[OpenStream].startMapIndex
val endMapIndex = msg.asInstanceOf[OpenStream].endMapIndex
if (endMapIndex != Integer.MAX_VALUE) {
fileInfo = partitionsSorter.getSortedFileInfo(
shuffleKey,
@ -127,6 +137,8 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
}
case PartitionType.MAP =>
// return stream id
val startIndex = msg.asInstanceOf[OpenStreamWithCredit].startIndex
val endIndex = msg.asInstanceOf[OpenStreamWithCredit].endIndex
val streamId =
bufferStreamManager.registerStream(client.getChannel, fileInfo.getBufferSize)
val res = ByteBuffer.allocate(8)