diff --git a/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java b/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
index 50d3d8bb9..6eefb83c3 100644
--- a/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
+++ b/common/src/main/java/org/apache/celeborn/common/meta/FileInfo.java
@@ -182,4 +182,9 @@ public class FileInfo {
   public void setBufferSize(int bufferSize) {
     this.bufferSize = bufferSize;
   }
+
+  /** Returns the current value of the {@code bufferSize} field. */
+  public int getBufferSize() {
+    return bufferSize;
+  }
 }
diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
index 020614264..7ba54cbd4 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java
@@ -84,7 +84,9 @@ public abstract class Message implements Encodable {
     PUSH_MERGED_DATA(12),
     REGION_START(13),
     REGION_FINISH(14),
-    PUSH_DATA_HAND_SHAKE(15);
+    PUSH_DATA_HAND_SHAKE(15),
+    READ_ADD_CREDIT(16),
+    READ_DATA(17);
 
     private final byte id;
 
@@ -138,6 +140,10 @@ public abstract class Message implements Encodable {
           return REGION_FINISH;
         case 15:
           return PUSH_DATA_HAND_SHAKE;
+        case 16:
+          return READ_ADD_CREDIT;
+        case 17:
+          return READ_DATA;
         case -1:
           throw new IllegalArgumentException("User type messages cannot be decoded.");
         default:
@@ -193,6 +199,13 @@ public abstract class Message implements Encodable {
 
       case PUSH_DATA_HAND_SHAKE:
         return PushDataHandShake.decode(in);
+
+      case READ_ADD_CREDIT:
+        return ReadAddCredit.decode(in);
+
+      case READ_DATA:
+        return ReadData.decode(in);
+
       default:
         throw new IllegalArgumentException("Unexpected message type: " + msgType);
     }
diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
new file mode 100644
index 000000000..ca34a5c17
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadAddCredit.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.common.network.protocol;
+
+import java.util.Objects;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request message that carries a stream id and a credit value; encoded as an 8-byte stream id
+ * followed by a 4-byte credit.
+ */
+public class ReadAddCredit extends RequestMessage {
+  private final long streamId;
+  private final int credit;
+
+  public ReadAddCredit(long streamId, int credit) {
+    this.streamId = streamId;
+    this.credit = credit;
+  }
+
+  /** Returns the encoded size: 8 bytes for the stream id plus 4 bytes for the credit. */
+  @Override
+  public int encodedLength() {
+    return 8 + 4;
+  }
+
+  /** Writes the stream id and then the credit, the order {@link #decode(ByteBuf)} reads. */
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(streamId);
+    buf.writeInt(credit);
+  }
+
+  /** Decodes the layout written by {@link #encode(ByteBuf)}. */
+  public static ReadAddCredit decode(ByteBuf buf) {
+    long streamId = buf.readLong();
+    int credit = buf.readInt();
+    return new ReadAddCredit(streamId, credit);
+  }
+
+  public long getStreamId() {
+    return streamId;
+  }
+
+  public int getCredit() {
+    return credit;
+  }
+
+  @Override
+  public Type type() {
+    return Type.READ_ADD_CREDIT;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+    ReadAddCredit that = (ReadAddCredit) o;
+    return streamId == that.streamId && credit == that.credit;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(streamId, credit);
+  }
+
+  @Override
+  public String toString() {
+    return "ReadAddCredit{" + "streamId=" + streamId + ", credit=" + credit + '}';
+  }
+}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
new file mode 100644
index 000000000..92e107ccf
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/ReadData.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.celeborn.common.network.protocol;
+
+import java.util.Objects;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/**
+ * Message carrying a payload buffer together with a stream id, a backlog value and an offset.
+ * Wire layout: streamId (8 bytes), backlog (4), offset (8), payload length (4), payload bytes.
+ */
+public class ReadData extends RequestMessage {
+  private final long streamId;
+  private final int backlog;
+  private final long offset;
+  private final ByteBuf buf;
+
+  public ReadData(long streamId, int backlog, long offset, ByteBuf buf) {
+    this.streamId = streamId;
+    this.backlog = backlog;
+    this.offset = offset;
+    this.buf = buf;
+  }
+
+  /**
+   * Returns the number of bytes written by {@link #encode(ByteBuf)}: the 24-byte fixed part
+   * (8 + 4 + 8 + 4) plus the readable payload bytes.
+   */
+  @Override
+  public int encodedLength() {
+    return 8 + 4 + 8 + 4 + buf.readableBytes();
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(streamId);
+    buf.writeInt(backlog);
+    buf.writeLong(offset);
+    // Copy by absolute index so the payload's reader index is left untouched and
+    // encodedLength()/equals() give the same answers after encoding.
+    int payloadLength = this.buf.readableBytes();
+    buf.writeInt(payloadLength);
+    buf.writeBytes(this.buf, this.buf.readerIndex(), payloadLength);
+  }
+
+  /** Decodes the layout written by {@link #encode(ByteBuf)}, copying the payload bytes. */
+  public static ReadData decode(ByteBuf buf) {
+    long streamId = buf.readLong();
+    int backlog = buf.readInt();
+    long offset = buf.readLong();
+    int tmpBufSize = buf.readInt();
+    ByteBuf tmpBuf = Unpooled.buffer(tmpBufSize, tmpBufSize);
+    buf.readBytes(tmpBuf);
+    return new ReadData(streamId, backlog, offset, tmpBuf);
+  }
+
+  public long getStreamId() {
+    return streamId;
+  }
+
+  public int getBacklog() {
+    return backlog;
+  }
+
+  public long getOffset() {
+    return offset;
+  }
+
+  public ByteBuf getBuf() {
+    return buf;
+  }
+
+  @Override
+  public Type type() {
+    return Type.READ_DATA;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+    ReadData readData = (ReadData) o;
+    return streamId == readData.streamId
+        && backlog == readData.backlog
+        && offset == readData.offset
+        && Objects.equals(buf, readData.buf);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(streamId, backlog, offset, buf);
+  }
+
+  @Override
+  public String toString() {
+    return "ReadData{"
+        + "streamId="
+        + streamId
+        + ", backlog="
+        + backlog
+        + ", offset="
+        + offset
+        + ", buf="
+        + buf
+        + '}';
+  }
+}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java b/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
new file mode 100644
index 000000000..bba500b8d
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/server/BufferStreamManager.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.server;
+
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Tracks buffer streams by id. Each registered stream records the channel it was registered
+ * with and a buffer size; {@link #connectionTerminated(Channel)} removes a channel's streams.
+ */
+public class BufferStreamManager {
+  private static final Logger logger = LoggerFactory.getLogger(BufferStreamManager.class);
+  private final AtomicLong nextStreamId;
+  protected final ConcurrentHashMap<Long, StreamState> streams;
+
+  /** Per-stream record: the channel the stream is associated with and its buffer size. */
+  protected static class StreamState {
+    private final Channel associatedChannel;
+    private final int bufferSize;
+
+    public StreamState(Channel associatedChannel, int bufferSize) {
+      this.associatedChannel = associatedChannel;
+      this.bufferSize = bufferSize;
+    }
+
+    public Channel getAssociatedChannel() {
+      return associatedChannel;
+    }
+
+    public int getBufferSize() {
+      return bufferSize;
+    }
+  }
+
+  public BufferStreamManager() {
+    nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
+    streams = new ConcurrentHashMap<>();
+  }
+
+  /** Registers a new stream for {@code channel} and returns its newly assigned stream id. */
+  public long registerStream(Channel channel, int bufferSize) {
+    long streamId = nextStreamId.getAndIncrement();
+    streams.put(streamId, new StreamState(channel, bufferSize));
+    return streamId;
+  }
+
+  /** Placeholder for credit handling; currently a no-op. */
+  public void addCredit(int numCredit, long streamId) {}
+
+  /** Removes every stream whose associated channel is {@code channel}. */
+  public void connectionTerminated(Channel channel) {
+    for (Map.Entry<Long, StreamState> entry : streams.entrySet()) {
+      if (entry.getValue().getAssociatedChannel() == channel) {
+        streams.remove(entry.getKey());
+      }
+    }
+  }
+}
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 62b846deb..26eeafbc0 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -18,10 +18,12 @@
 package org.apache.celeborn.service.deploy.worker
 
 import java.io.{FileNotFoundException, IOException}
+import java.nio.ByteBuffer
 import java.nio.charset.StandardCharsets
 import java.util.concurrent.atomic.AtomicBoolean
 
 import com.google.common.base.Throwables
+import io.netty.buffer.{ByteBuf, Unpooled}
 import io.netty.util.concurrent.{Future, GenericFutureListener}
 
 import org.apache.celeborn.common.exception.CelebornException
@@ -31,13 +33,14 @@ import org.apache.celeborn.common.metrics.source.RPCSource
 import org.apache.celeborn.common.network.buffer.NioManagedBuffer
 import org.apache.celeborn.common.network.client.TransportClient
 import org.apache.celeborn.common.network.protocol._
-import org.apache.celeborn.common.network.server.{BaseMessageHandler, ChunkStreamManager}
+import org.apache.celeborn.common.network.server.{BaseMessageHandler, BufferStreamManager, ChunkStreamManager}
 import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
 import org.apache.celeborn.common.protocol.PartitionType
 import org.apache.celeborn.service.deploy.worker.storage.{PartitionFilesSorter, StorageManager}
 
 class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logging {
   var chunkStreamManager = new ChunkStreamManager()
+  val bufferStreamManager = new BufferStreamManager()
   var workerSource: WorkerSource = _
   var rpcSource: RPCSource = _
   var storageManager: StorageManager = _
@@ -67,6 +70,9 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
 
   override def receive(client: TransportClient, msg: RequestMessage): Unit = {
     msg match {
+      case r: ReadAddCredit =>
+        rpcSource.updateMessageMetrics(r, 0)
+        handleReadAddCredit(client, r)
       case r: ChunkFetchRequest =>
         rpcSource.updateMessageMetrics(r, 0)
         handleChunkFetchRequest(client, r)
@@ -120,6 +126,17 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
               new NioManagedBuffer(streamHandle.toByteBuffer)))
           }
         case PartitionType.MAP =>
+          // return stream id
+          val streamId =
+            bufferStreamManager.registerStream(client.getChannel, fileInfo.getBufferSize)
+          val res = ByteBuffer.allocate(8)
+          res.putLong(streamId)
+          // Flip so position/limit delimit the 8 bytes just written; right after putLong the
+          // buffer has no remaining bytes to read.
+          res.flip()
+          client.getChannel.writeAndFlush(new RpcResponse(
+            request.requestId,
+            new NioManagedBuffer(res)))
         case PartitionType.MAPGROUP =>
       } catch {
         case e: IOException =>
@@ -141,6 +158,11 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
     }
   }
 
+  /** Passes the credit and stream id carried by `req` to the buffer stream manager. */
+  def handleReadAddCredit(client: TransportClient, req: ReadAddCredit): Unit = {
+    bufferStreamManager.addCredit(req.getCredit, req.getStreamId)
+  }
+
   def handleChunkFetchRequest(client: TransportClient, req: ChunkFetchRequest): Unit = {
     workerSource.startTimer(WorkerSource.FetchChunkTime, req.toString)
     logTrace(s"Received req from ${NettyUtils.getRemoteAddress(client.getChannel)}" +
@@ -188,6 +210,7 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
 
   override def channelInactive(client: TransportClient): Unit = {
     chunkStreamManager.connectionTerminated(client.getChannel)
+    bufferStreamManager.connectionTerminated(client.getChannel)
     logDebug(s"channel inactive ${client.getSocketAddress}")
   }
 