[CELEBORN-278] Add openStreamWithCredit RPC. (#1214)
This commit is contained in:
parent
2c508dae0f
commit
534853bf8a
@ -86,7 +86,8 @@ public abstract class Message implements Encodable {
|
||||
REGION_FINISH(14),
|
||||
PUSH_DATA_HAND_SHAKE(15),
|
||||
READ_ADD_CREDIT(16),
|
||||
READ_DATA(17);
|
||||
READ_DATA(17),
|
||||
OPEN_STREAM_WITH_CREDIT(18);
|
||||
|
||||
private final byte id;
|
||||
|
||||
@ -144,6 +145,8 @@ public abstract class Message implements Encodable {
|
||||
return READ_ADD_CREDIT;
|
||||
case 17:
|
||||
return READ_DATA;
|
||||
case 18:
|
||||
return OPEN_STREAM_WITH_CREDIT;
|
||||
case -1:
|
||||
throw new IllegalArgumentException("User type messages cannot be decoded.");
|
||||
default:
|
||||
@ -206,6 +209,9 @@ public abstract class Message implements Encodable {
|
||||
case READ_DATA:
|
||||
return ReadData.decode(in);
|
||||
|
||||
case OPEN_STREAM_WITH_CREDIT:
|
||||
return OpenStreamWithCredit.decode(in);
|
||||
|
||||
default:
|
||||
throw new IllegalArgumentException("Unexpected message type: " + msgType);
|
||||
}
|
||||
|
||||
@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.common.network.protocol;
|
||||
|
||||
import static org.apache.celeborn.common.network.protocol.Message.Type.OPEN_STREAM_WITH_CREDIT;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
|
||||
/** Buffer stream used in Map partition scenario. */
|
||||
public final class OpenStreamWithCredit extends RequestMessage {
|
||||
public final byte[] shuffleKey;
|
||||
public final byte[] fileName;
|
||||
public final int startIndex;
|
||||
public final int endIndex;
|
||||
public final int initialCredit;
|
||||
|
||||
public OpenStreamWithCredit(
|
||||
byte[] shuffleKey, byte[] fileName, int startIndex, int endIndex, int initialCredit) {
|
||||
this.shuffleKey = shuffleKey;
|
||||
this.fileName = fileName;
|
||||
this.startIndex = startIndex;
|
||||
this.endIndex = endIndex;
|
||||
this.initialCredit = initialCredit;
|
||||
}
|
||||
|
||||
public OpenStreamWithCredit(
|
||||
String shuffleKey, String fileName, int startIndex, int endIndex, int initialCredit) {
|
||||
this(
|
||||
shuffleKey.getBytes(StandardCharsets.UTF_8),
|
||||
fileName.getBytes(StandardCharsets.UTF_8),
|
||||
startIndex,
|
||||
endIndex,
|
||||
initialCredit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int encodedLength() {
|
||||
return 4 + shuffleKey.length + 4 + fileName.length + 4 + 4 + 4;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void encode(ByteBuf buf) {
|
||||
buf.writeInt(shuffleKey.length);
|
||||
buf.writeBytes(shuffleKey);
|
||||
buf.writeInt(fileName.length);
|
||||
buf.writeBytes(fileName);
|
||||
buf.writeInt(startIndex);
|
||||
buf.writeInt(endIndex);
|
||||
buf.writeInt(initialCredit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Message.Type type() {
|
||||
return OPEN_STREAM_WITH_CREDIT;
|
||||
}
|
||||
|
||||
public static OpenStreamWithCredit decode(ByteBuf in) {
|
||||
int shuffleKeyLength = in.readInt();
|
||||
byte[] tmpShuffleKey = new byte[shuffleKeyLength];
|
||||
in.readBytes(tmpShuffleKey);
|
||||
int fileNameLength = in.readInt();
|
||||
byte[] tmpFileName = new byte[fileNameLength];
|
||||
in.readBytes(tmpFileName);
|
||||
int startSubIndex = in.readInt();
|
||||
int endSubIndex = in.readInt();
|
||||
int initialCredit = in.readInt();
|
||||
return new OpenStreamWithCredit(
|
||||
tmpShuffleKey, tmpFileName, startSubIndex, endSubIndex, initialCredit);
|
||||
}
|
||||
}
|
||||
@ -33,6 +33,7 @@ import org.apache.celeborn.common.metrics.source.RPCSource
|
||||
import org.apache.celeborn.common.network.buffer.NioManagedBuffer
|
||||
import org.apache.celeborn.common.network.client.TransportClient
|
||||
import org.apache.celeborn.common.network.protocol._
|
||||
import org.apache.celeborn.common.network.protocol.Message.Type
|
||||
import org.apache.celeborn.common.network.server.{BaseMessageHandler, BufferStreamManager, ChunkStreamManager}
|
||||
import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
|
||||
import org.apache.celeborn.common.protocol.PartitionType
|
||||
@ -85,17 +86,26 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
|
||||
|
||||
def handleOpenStream(client: TransportClient, request: RpcRequest): Unit = {
|
||||
val msg = Message.decode(request.body().nioByteBuffer())
|
||||
val openBlocks = msg.asInstanceOf[OpenStream]
|
||||
val shuffleKey = new String(openBlocks.shuffleKey, StandardCharsets.UTF_8)
|
||||
val fileName = new String(openBlocks.fileName, StandardCharsets.UTF_8)
|
||||
val startMapIndex = openBlocks.startMapIndex
|
||||
val endMapIndex = openBlocks.endMapIndex
|
||||
val (shuffleKey, fileName) =
|
||||
if (msg.`type`() == Type.OPEN_STREAM) {
|
||||
val openStream = msg.asInstanceOf[OpenStream]
|
||||
(
|
||||
new String(openStream.shuffleKey, StandardCharsets.UTF_8),
|
||||
new String(openStream.fileName, StandardCharsets.UTF_8))
|
||||
} else {
|
||||
val openStreamWithCredit = msg.asInstanceOf[OpenStreamWithCredit]
|
||||
(
|
||||
new String(openStreamWithCredit.shuffleKey, StandardCharsets.UTF_8),
|
||||
new String(openStreamWithCredit.fileName, StandardCharsets.UTF_8))
|
||||
}
|
||||
// metrics start
|
||||
workerSource.startTimer(WorkerSource.OpenStreamTime, shuffleKey)
|
||||
try {
|
||||
var fileInfo = getRawFileInfo(shuffleKey, fileName)
|
||||
try fileInfo.getPartitionType() match {
|
||||
case PartitionType.REDUCE =>
|
||||
val startMapIndex = msg.asInstanceOf[OpenStream].startMapIndex
|
||||
val endMapIndex = msg.asInstanceOf[OpenStream].endMapIndex
|
||||
if (endMapIndex != Integer.MAX_VALUE) {
|
||||
fileInfo = partitionsSorter.getSortedFileInfo(
|
||||
shuffleKey,
|
||||
@ -127,6 +137,8 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
|
||||
}
|
||||
case PartitionType.MAP =>
|
||||
// return stream id
|
||||
val startIndex = msg.asInstanceOf[OpenStreamWithCredit].startIndex
|
||||
val endIndex = msg.asInstanceOf[OpenStreamWithCredit].endIndex
|
||||
val streamId =
|
||||
bufferStreamManager.registerStream(client.getChannel, fileInfo.getBufferSize)
|
||||
val res = ByteBuffer.allocate(8)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user