[CELEBORN-124]Add buffer stream. (#1069)

This commit is contained in:
Ethan Feng 2023-01-06 15:54:52 +08:00 committed by GitHub
parent 3b2be25a50
commit 5595f2f4b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 300 additions and 2 deletions

View File

@ -182,4 +182,8 @@ public class FileInfo {
public void setBufferSize(int bufferSize) {
this.bufferSize = bufferSize;
}
public int getBufferSize() {
return bufferSize;
}
}

View File

@ -84,7 +84,9 @@ public abstract class Message implements Encodable {
PUSH_MERGED_DATA(12),
REGION_START(13),
REGION_FINISH(14),
PUSH_DATA_HAND_SHAKE(15);
PUSH_DATA_HAND_SHAKE(15),
READ_ADD_CREDIT(16),
READ_DATA(17);
private final byte id;
@ -138,6 +140,10 @@ public abstract class Message implements Encodable {
return REGION_FINISH;
case 15:
return PUSH_DATA_HAND_SHAKE;
case 16:
return READ_ADD_CREDIT;
case 17:
return READ_DATA;
case -1:
throw new IllegalArgumentException("User type messages cannot be decoded.");
default:
@ -193,6 +199,13 @@ public abstract class Message implements Encodable {
case PUSH_DATA_HAND_SHAKE:
return PushDataHandShake.decode(in);
case READ_ADD_CREDIT:
return ReadAddCredit.decode(in);
case READ_DATA:
return ReadData.decode(in);
default:
throw new IllegalArgumentException("Unexpected message type: " + msgType);
}

View File

@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common.network.protocol;
import java.util.Objects;
import io.netty.buffer.ByteBuf;
public class ReadAddCredit extends RequestMessage {
private long streamId;
private int credit;
public ReadAddCredit(long streamId, int credit) {
this.streamId = streamId;
this.credit = credit;
}
@Override
public int encodedLength() {
return 8 + 4;
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(streamId);
buf.writeInt(credit);
}
public static ReadAddCredit decode(ByteBuf buf) {
long streamId = buf.readLong();
int credit = buf.readInt();
return new ReadAddCredit(streamId, credit);
}
public long getStreamId() {
return streamId;
}
public int getCredit() {
return credit;
}
@Override
public Type type() {
return Type.READ_ADD_CREDIT;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ReadAddCredit that = (ReadAddCredit) o;
return streamId == that.streamId && credit == that.credit;
}
@Override
public int hashCode() {
return Objects.hash(streamId, credit);
}
@Override
public String toString() {
return "ReadAddCredit{" + "streamId=" + streamId + ", credit=" + credit + '}';
}
}

View File

@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common.network.protocol;
import java.util.Objects;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
public class ReadData extends RequestMessage {
private long streamId;
private int backlog;
private long offset;
private ByteBuf buf;
public ReadData(long streamId, int backlog, long offset, ByteBuf buf) {
this.streamId = streamId;
this.backlog = backlog;
this.offset = offset;
this.buf = buf;
}
@Override
public int encodedLength() {
return 8 + 4 + 4 + 8 + 4 + buf.readableBytes();
}
@Override
public void encode(ByteBuf buf) {
buf.writeLong(streamId);
buf.writeInt(backlog);
buf.writeLong(offset);
buf.writeInt(this.buf.readableBytes());
buf.writeBytes(this.buf);
}
public static ReadData decode(ByteBuf buf) {
long streamId = buf.readLong();
int backlog = buf.readInt();
long offset = buf.readLong();
int tmpBufSize = buf.readInt();
ByteBuf tmpBuf = Unpooled.buffer(tmpBufSize, tmpBufSize);
buf.readBytes(tmpBuf);
return new ReadData(streamId, backlog, offset, tmpBuf);
}
public long getStreamId() {
return streamId;
}
public int getBacklog() {
return backlog;
}
public long getOffset() {
return offset;
}
public ByteBuf getBuf() {
return buf;
}
@Override
public Type type() {
return Type.READ_DATA;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ReadData readData = (ReadData) o;
return streamId == readData.streamId
&& backlog == readData.backlog
&& offset == readData.offset
&& Objects.equals(buf, readData.buf);
}
@Override
public int hashCode() {
return Objects.hash(streamId, backlog, offset, buf);
}
@Override
public String toString() {
return "ReadData{"
+ "streamId="
+ streamId
+ ", backlog="
+ backlog
+ ", offset="
+ offset
+ ", buf="
+ buf
+ '}';
}
}

View File

@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common.network.server;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class BufferStreamManager {
private static final Logger logger = LoggerFactory.getLogger(BufferStreamManager.class);
private final AtomicLong nextStreamId;
protected final ConcurrentHashMap<Long, StreamState> streams;
protected class StreamState {
private Channel associatedChannel;
private int bufferSize;
public StreamState(Channel associatedChannel, int bufferSize) {
this.associatedChannel = associatedChannel;
this.bufferSize = bufferSize;
}
public Channel getAssociatedChannel() {
return associatedChannel;
}
public int getBufferSize() {
return bufferSize;
}
}
public BufferStreamManager() {
nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000);
streams = new ConcurrentHashMap<>();
}
public long registerStream(Channel channel, int bufferSize) {
long streamId = nextStreamId.getAndIncrement();
streams.put(streamId, new StreamState(channel, bufferSize));
return streamId;
}
public void addCredit(int numCredit, long streamId) {}
public void connectionTerminated(Channel channel) {
for (Map.Entry<Long, StreamState> entry : streams.entrySet()) {
if (entry.getValue().getAssociatedChannel() == channel) {
streams.remove(entry.getKey());
}
}
}
}

View File

@ -18,10 +18,12 @@
package org.apache.celeborn.service.deploy.worker
import java.io.{FileNotFoundException, IOException}
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.concurrent.atomic.AtomicBoolean
import com.google.common.base.Throwables
import io.netty.buffer.{ByteBuf, Unpooled}
import io.netty.util.concurrent.{Future, GenericFutureListener}
import org.apache.celeborn.common.exception.CelebornException
@ -31,13 +33,14 @@ import org.apache.celeborn.common.metrics.source.RPCSource
import org.apache.celeborn.common.network.buffer.NioManagedBuffer
import org.apache.celeborn.common.network.client.TransportClient
import org.apache.celeborn.common.network.protocol._
import org.apache.celeborn.common.network.server.{BaseMessageHandler, ChunkStreamManager}
import org.apache.celeborn.common.network.server.{BaseMessageHandler, BufferStreamManager, ChunkStreamManager}
import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
import org.apache.celeborn.common.protocol.PartitionType
import org.apache.celeborn.service.deploy.worker.storage.{PartitionFilesSorter, StorageManager}
class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logging {
var chunkStreamManager = new ChunkStreamManager()
val bufferStreamManager = new BufferStreamManager()
var workerSource: WorkerSource = _
var rpcSource: RPCSource = _
var storageManager: StorageManager = _
@ -67,6 +70,9 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
override def receive(client: TransportClient, msg: RequestMessage): Unit = {
msg match {
case r: ReadAddCredit =>
rpcSource.updateMessageMetrics(r, 0)
handleReadAddCredit(client, r)
case r: ChunkFetchRequest =>
rpcSource.updateMessageMetrics(r, 0)
handleChunkFetchRequest(client, r)
@ -120,6 +126,14 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
new NioManagedBuffer(streamHandle.toByteBuffer)))
}
case PartitionType.MAP =>
// return stream id
val streamId =
bufferStreamManager.registerStream(client.getChannel, fileInfo.getBufferSize)
val res = ByteBuffer.allocate(8)
res.putLong(streamId)
client.getChannel.writeAndFlush(new RpcResponse(
request.requestId,
new NioManagedBuffer(res)))
case PartitionType.MAPGROUP =>
} catch {
case e: IOException =>
@ -141,6 +155,10 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
}
}
def handleReadAddCredit(client: TransportClient, req: ReadAddCredit): Unit = {
bufferStreamManager.addCredit(req.getCredit, req.getStreamId)
}
def handleChunkFetchRequest(client: TransportClient, req: ChunkFetchRequest): Unit = {
workerSource.startTimer(WorkerSource.FetchChunkTime, req.toString)
logTrace(s"Received req from ${NettyUtils.getRemoteAddress(client.getChannel)}" +
@ -188,6 +206,7 @@ class FetchHandler(val conf: TransportConf) extends BaseMessageHandler with Logg
override def channelInactive(client: TransportClient): Unit = {
chunkStreamManager.connectionTerminated(client.getChannel)
bufferStreamManager.connectionTerminated(client.getChannel)
logDebug(s"channel inactive ${client.getSocketAddress}")
}