From 3aacede5f8f2cab9bcd9ecc3fe7902ca7d204c5d Mon Sep 17 00:00:00 2001 From: Ethan Feng Date: Fri, 17 Feb 2023 14:12:54 +0800 Subject: [PATCH] [CELEBORN-283] Derive network layer for flink plugin. (#1222) --- .../flink/buffer/FlinkNettyManagedBuffer.java | 67 +++++++++++ .../network/FlinkTransportClientFactory.java | 40 +++++++ .../flink/network/MessageDecoderExt.java | 88 +++++++++++++++ .../flink/network/ReadClientHandler.java | 104 ++++++++++++++++++ ...ansportFrameDecoderWithBufferSupplier.java | 88 +++++++-------- .../plugin/flink/protocol/ReadData.java | 102 +++++++++++++++++ .../common/network/TransportContext.java | 14 ++- .../network/buffer/NettyManagedBuffer.java | 4 + .../client/TransportClientFactory.java | 17 ++- .../common/network/protocol/Message.java | 10 +- .../common/network/util/NettyUtils.java | 15 --- .../common/network/util/TransportConf.java | 4 - .../apache/celeborn/common/CelebornConf.scala | 14 --- docs/configuration/network.md | 1 - 14 files changed, 481 insertions(+), 87 deletions(-) create mode 100644 client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java create mode 100644 client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java create mode 100644 client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java create mode 100644 client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java rename {common/src/main/java/org/apache/celeborn/common/network/util => client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network}/TransportFrameDecoderWithBufferSupplier.java (65%) create mode 100644 client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java new file mode 100644 index 000000000..aeb736a39 --- /dev/null +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.plugin.flink.buffer; + +import java.io.InputStream; +import java.nio.ByteBuffer; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufInputStream; + +import org.apache.celeborn.common.network.buffer.ManagedBuffer; + +public class FlinkNettyManagedBuffer extends ManagedBuffer { + private final ByteBuf buf; + + public FlinkNettyManagedBuffer(ByteBuf buf) { + super(); + this.buf = buf; + } + + @Override + public long size() { + return buf.readableBytes(); + } + + @Override + public ByteBuffer nioByteBuffer() { + return buf.nioBuffer(); + } + + @Override + public InputStream createInputStream() { + return new ByteBufInputStream(buf); + } + + @Override + public ManagedBuffer retain() { + buf.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + buf.release(); + return this; + } + + @Override + public Object convertToNetty() { + return buf.duplicate().retain(); + } +} diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java new file mode 100644 index 000000000..61abcb6bc --- /dev/null +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/FlinkTransportClientFactory.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.plugin.flink.network; + +import java.io.IOException; +import java.util.function.Supplier; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import org.apache.celeborn.common.network.TransportContext; +import org.apache.celeborn.common.network.client.TransportClient; +import org.apache.celeborn.common.network.client.TransportClientFactory; + +public class FlinkTransportClientFactory extends TransportClientFactory { + public FlinkTransportClientFactory(TransportContext context) { + super(context); + } + + public TransportClient createClient( + String remoteHost, int remotePort, int partitionId, Supplier supplier) + throws IOException, InterruptedException { + return createClient( + remoteHost, remotePort, partitionId, new TransportFrameDecoderWithBufferSupplier(supplier)); + } +} diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java new file mode 100644 index 000000000..fbee2e4fd --- /dev/null +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/MessageDecoderExt.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.plugin.flink.network; + +import java.nio.charset.StandardCharsets; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import org.apache.celeborn.common.network.buffer.NettyManagedBuffer; +import org.apache.celeborn.common.network.protocol.*; +import org.apache.celeborn.plugin.flink.buffer.FlinkNettyManagedBuffer; +import org.apache.celeborn.plugin.flink.protocol.ReadData; + +public class MessageDecoderExt { + public static Message decode(Message.Type type, ByteBuf in, boolean decodeBody) { + long requestId; + // cannot use actual class decode method because common module cannot refer flink shaded netty. + switch (type) { + case RPC_REQUEST: + requestId = in.readLong(); + in.readInt(); + if (decodeBody) { + return new RpcRequest(requestId, new FlinkNettyManagedBuffer(in)); + } else { + return new RpcRequest(requestId, NettyManagedBuffer.EmptyBuffer); + } + + case RPC_RESPONSE: + requestId = in.readLong(); + in.readInt(); + if (decodeBody) { + return new RpcResponse(requestId, new FlinkNettyManagedBuffer(in)); + } else { + return new RpcResponse(requestId, NettyManagedBuffer.EmptyBuffer); + } + + case RPC_FAILURE: + requestId = in.readLong(); + int length = in.readInt(); + byte[] bytes = new byte[length]; + in.readBytes(bytes); + String errorString = new String(bytes, StandardCharsets.UTF_8); + return new RpcFailure(requestId, errorString); + + case ONE_WAY_MESSAGE: + in.readInt(); + if (decodeBody) { + return new OneWayMessage(new FlinkNettyManagedBuffer(in)); + } else { + return new OneWayMessage(NettyManagedBuffer.EmptyBuffer); + } + + case READ_ADD_CREDIT: + long streamId = in.readLong(); + int credit = in.readInt(); + return new ReadAddCredit(streamId, credit); + + case READ_DATA: + streamId = in.readLong(); + int backlog = in.readInt(); + long offset = in.readLong(); + return new ReadData(streamId, backlog, offset); + + case BACKLOG_ANNOUNCEMENT: + streamId = in.readLong(); + backlog = in.readInt(); + return new BacklogAnnouncement(streamId, backlog); + + default: + throw new IllegalArgumentException("Unexpected message type: " + type); + } + } +} diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java new file mode 100644 index 000000000..fc0dfaccb --- /dev/null +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.plugin.flink.network; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.celeborn.common.network.client.TransportClient; +import org.apache.celeborn.common.network.protocol.BacklogAnnouncement; +import org.apache.celeborn.common.network.protocol.RequestMessage; +import org.apache.celeborn.common.network.server.BaseMessageHandler; +import org.apache.celeborn.plugin.flink.protocol.ReadData; + +public class ReadClientHandler extends BaseMessageHandler { + private static Logger logger = LoggerFactory.getLogger(ReadClientHandler.class); + private ConcurrentHashMap> streamHandlers = + new ConcurrentHashMap<>(); + private ConcurrentHashMap streamClients = new ConcurrentHashMap<>(); + + public void registerHandler( + long streamId, Consumer handle, TransportClient client) { + streamHandlers.put(streamId, handle); + streamClients.put(streamId, client); + } + + public void removeHandler(long streamId) { + streamHandlers.remove(streamId); + streamClients.remove(streamId); + } + + @Override + public void receive(TransportClient client, RequestMessage msg) { + long streamId = 0; + switch (msg.type()) { + case READ_DATA: + ReadData readData = (ReadData) msg; + streamId = readData.getStreamId(); + if (streamHandlers.containsKey(streamId)) { + logger.debug( + "received streamId: {}, readData size:{}", + streamId, + readData.getFlinkBuffer().readableBytes()); + streamHandlers.get(streamId).accept(msg); + } else { + logger.warn("Unexpected streamId received: {}", streamId); + } + break; + case BACKLOG_ANNOUNCEMENT: + BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement) msg; + streamId = backlogAnnouncement.getStreamId(); + if (streamHandlers.containsKey(streamId)) { + logger.debug( + "received streamId: {}, backlog: {}", streamId, backlogAnnouncement.getBacklog()); + streamHandlers.get(streamId).accept(msg); + } else { + logger.warn("Unexpected streamId received: {}", streamId); + } + break; + case ONE_WAY_MESSAGE: + // ignore it. + break; + default: + logger.error("Unexpected msg type {} content {}", msg.type(), msg); + } + } + + @Override + public boolean checkRegistered() { + return true; + } + + @Override + public void channelInactive(TransportClient client) { + streamClients.forEach( + (streamId, savedClient) -> { + if (savedClient == client) { + logger.warn("Client {} is lost, remove related stream {}", savedClient, streamId); + } + }); + } + + @Override + public void exceptionCaught(Throwable cause, TransportClient client) { + logger.warn("exception caught {}", client.getSocketAddress(), cause); + } +} diff --git a/common/src/main/java/org/apache/celeborn/common/network/util/TransportFrameDecoderWithBufferSupplier.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java similarity index 65% rename from common/src/main/java/org/apache/celeborn/common/network/util/TransportFrameDecoderWithBufferSupplier.java rename to client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java index 9d6d6bac1..604366ebf 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/util/TransportFrameDecoderWithBufferSupplier.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/TransportFrameDecoderWithBufferSupplier.java @@ -15,57 +15,43 @@ * limitations under the License. */ -package org.apache.celeborn.common.network.util; +package org.apache.celeborn.plugin.flink.network; -import java.util.function.Function; import java.util.function.Supplier; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; -import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; +import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled; import org.apache.celeborn.common.network.protocol.Message; +import org.apache.celeborn.common.network.util.FrameDecoder; +import org.apache.celeborn.plugin.flink.protocol.ReadData; public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandlerAdapter implements FrameDecoder { - private final Function> bufferSuppliers; + private final Supplier bufferSupplier; private int msgSize = -1; private int bodySize = -1; private Message.Type curType = Message.Type.UNKNOWN_TYPE; private ByteBuf headerBuf = Unpooled.buffer(HEADER_SIZE, HEADER_SIZE); - private CompositeByteBuf bodyBuf = null; + private io.netty.buffer.CompositeByteBuf bodyBuf = null; private ByteBuf externalBuf = null; private final ByteBuf msgBuf = Unpooled.buffer(8); private Message curMsg = null; - public TransportFrameDecoderWithBufferSupplier() { - this.bufferSuppliers = - new Function>() { - @Override - public Supplier apply(Integer size) { - return new Supplier() { - @Override - public ByteBuf get() { - return Unpooled.buffer(size); - } - }; - } - }; + public TransportFrameDecoderWithBufferSupplier(Supplier bufferSupplier) { + this.bufferSupplier = bufferSupplier; } - public TransportFrameDecoderWithBufferSupplier( - Function> bufferSuppliers) { - this.bufferSuppliers = bufferSuppliers; - } - - private void copyByteBuf(ByteBuf source, ByteBuf target, int targetSize) { + private void copyByteBuf(io.netty.buffer.ByteBuf source, ByteBuf target, int targetSize) { int bytes = Math.min(source.readableBytes(), targetSize - target.readableBytes()); - target.writeBytes(source, bytes); + for (int i = 0; i < bytes; i++) { + target.writeByte(source.readByte()); + } } - private void decodeHeader(ByteBuf buf, ChannelHandlerContext ctx) { + private void decodeHeader(io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) { copyByteBuf(buf, headerBuf, HEADER_SIZE); if (!headerBuf.isWritable()) { msgSize = headerBuf.readInt(); @@ -73,18 +59,20 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl msgBuf.capacity(msgSize); } msgBuf.clear(); - curType = Message.Type.decode(headerBuf); + curType = Message.Type.decode(headerBuf.nioBuffer()); + // type byte is read + headerBuf.readByte(); bodySize = headerBuf.readInt(); decodeMsg(buf, ctx); } } - private void decodeMsg(ByteBuf buf, ChannelHandlerContext ctx) { + private void decodeMsg(io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) { if (msgBuf.readableBytes() < msgSize) { copyByteBuf(buf, msgBuf, msgSize); } if (msgBuf.readableBytes() == msgSize) { - curMsg = Message.decode(curType, msgBuf, false); + curMsg = MessageDecoderExt.decode(curType, msgBuf, false); if (bodySize <= 0) { ctx.fireChannelRead(curMsg); clear(); @@ -92,11 +80,12 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl } } - private ByteBuf decodeBody(ByteBuf buf, ChannelHandlerContext ctx) { + private io.netty.buffer.ByteBuf decodeBody( + io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) { if (bodyBuf == null) { if (buf.readableBytes() >= bodySize) { - ByteBuf body = buf.retain().readSlice(bodySize); - curMsg.setBody(body); + io.netty.buffer.ByteBuf body = buf.retain().readSlice(bodySize); + curMsg.setBody(body.nioBuffer()); ctx.fireChannelRead(curMsg); clear(); return buf; @@ -105,7 +94,7 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl } } int remaining = bodySize - bodyBuf.readableBytes(); - ByteBuf next; + io.netty.buffer.ByteBuf next; if (remaining >= buf.readableBytes()) { next = buf; buf = null; @@ -114,20 +103,25 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl } bodyBuf.addComponent(next).writerIndex(bodyBuf.writerIndex() + next.readableBytes()); if (bodyBuf.readableBytes() == bodySize) { - curMsg.setBody(bodyBuf); + curMsg.setBody(bodyBuf.nioBuffer()); ctx.fireChannelRead(curMsg); clear(); } return buf; } - private ByteBuf decodeBodyCopyOut(ByteBuf buf, ChannelHandlerContext ctx) { + private io.netty.buffer.ByteBuf decodeBodyCopyOut( + io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) { if (externalBuf == null) { - externalBuf = bufferSuppliers.apply(bodySize).get(); + externalBuf = bufferSupplier.get(); } copyByteBuf(buf, externalBuf, bodySize); if (externalBuf.readableBytes() == bodySize) { - curMsg.setBody(externalBuf); + if (curMsg instanceof ReadData) { + ((ReadData) curMsg).setFlinkBuffer(externalBuf); + } else { + curMsg.setBody(externalBuf.nioBuffer()); + } ctx.fireChannelRead(curMsg); clear(); } @@ -135,24 +129,24 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl } public void channelRead(ChannelHandlerContext ctx, Object data) { - ByteBuf buf = (ByteBuf) data; + io.netty.buffer.ByteBuf nettyBuf = (io.netty.buffer.ByteBuf) data; try { - while (buf != null && buf.isReadable()) { + while (nettyBuf != null && nettyBuf.isReadable()) { if (headerBuf.isWritable()) { - decodeHeader(buf, ctx); + decodeHeader(nettyBuf, ctx); } else if (curMsg == null) { - decodeMsg(buf, ctx); + decodeMsg(nettyBuf, ctx); } else if (bodySize > 0) { if (curMsg.needCopyOut()) { - buf = decodeBodyCopyOut(buf, ctx); + nettyBuf = decodeBodyCopyOut(nettyBuf, ctx); } else { - buf = decodeBody(buf, ctx); + nettyBuf = decodeBody(nettyBuf, ctx); } } } } finally { - if (buf != null) { - buf.release(); + if (nettyBuf != null) { + nettyBuf.release(); } } } diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java new file mode 100644 index 000000000..0932b0949 --- /dev/null +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/protocol/ReadData.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.plugin.flink.protocol; + +import java.util.Objects; + +import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; + +import org.apache.celeborn.common.network.protocol.RequestMessage; + +public final class ReadData extends RequestMessage { + private final long streamId; + private final int backlog; + private final long offset; + private ByteBuf flinkBuffer; + + public ReadData(long streamId, int backlog, long offset) { + this.streamId = streamId; + this.backlog = backlog; + this.offset = offset; + } + + @Override + public int encodedLength() { + return 8 + 4 + 8; + } + + // This method will not be called because ReadData won't be created at flink client. + @Override + public void encode(io.netty.buffer.ByteBuf buf) { + buf.writeLong(streamId); + buf.writeInt(backlog); + buf.writeLong(offset); + } + + public long getStreamId() { + return streamId; + } + + public int getBacklog() { + return backlog; + } + + public long getOffset() { + return offset; + } + + @Override + public Type type() { + return Type.READ_DATA; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ReadData readData = (ReadData) o; + return streamId == readData.streamId + && backlog == readData.backlog + && offset == readData.offset; + } + + @Override + public int hashCode() { + return Objects.hash(streamId, backlog, offset); + } + + @Override + public String toString() { + return "ReadData{" + + "streamId=" + + streamId + + ", backlog=" + + backlog + + ", offset=" + + offset + + '}'; + } + + public ByteBuf getFlinkBuffer() { + return flinkBuffer; + } + + public void setFlinkBuffer(ByteBuf flinkBuffer) { + this.flinkBuffer = flinkBuffer; + } +} diff --git a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java index 69268de80..126432933 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java +++ b/common/src/main/java/org/apache/celeborn/common/network/TransportContext.java @@ -18,6 +18,7 @@ package org.apache.celeborn.common.network; import io.netty.channel.Channel; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.socket.SocketChannel; import io.netty.handler.timeout.IdleStateHandler; import org.slf4j.Logger; @@ -29,8 +30,8 @@ import org.apache.celeborn.common.network.client.TransportResponseHandler; import org.apache.celeborn.common.network.protocol.MessageEncoder; import org.apache.celeborn.common.network.server.*; import org.apache.celeborn.common.network.util.FrameDecoder; -import org.apache.celeborn.common.network.util.NettyUtils; import org.apache.celeborn.common.network.util.TransportConf; +import org.apache.celeborn.common.network.util.TransportFrameDecoder; /** * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to @@ -94,6 +95,11 @@ public class TransportContext { } public TransportChannelHandler initializePipeline(SocketChannel channel) { + return initializePipeline(channel, new TransportFrameDecoder()); + } + + public TransportChannelHandler initializePipeline( + SocketChannel channel, ChannelInboundHandlerAdapter decoder) { try { if (channelsLimiter != null) { channel.pipeline().addLast("limiter", channelsLimiter); @@ -102,7 +108,7 @@ public class TransportContext { channel .pipeline() .addLast("encoder", ENCODER) - .addLast(FrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder(conf)) + .addLast(FrameDecoder.HANDLER_NAME, decoder) .addLast( "idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) .addLast("handler", channelHandler); @@ -126,4 +132,8 @@ public class TransportContext { public TransportConf getConf() { return conf; } + + public BaseMessageHandler getMsgHandler() { + return msgHandler; + } } diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java index 760e8bed3..d3be0a6d9 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java +++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java @@ -35,6 +35,10 @@ public class NettyManagedBuffer extends ManagedBuffer { this.buf = buf; } + public NettyManagedBuffer(ByteBuffer buffer) { + this.buf = Unpooled.wrappedBuffer(buffer); + } + public ByteBuf getBuf() { return buf.duplicate(); } diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java index 8e1a00cc0..de8e61de8 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java +++ b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java @@ -110,6 +110,12 @@ public class TransportClientFactory implements Closeable { */ public TransportClient createClient(String remoteHost, int remotePort, int partitionId) throws IOException, InterruptedException { + return createClient(remoteHost, remotePort, partitionId, new TransportFrameDecoder()); + } + + public TransportClient createClient( + String remoteHost, int remotePort, int partitionId, ChannelInboundHandlerAdapter decoder) + throws IOException, InterruptedException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. // Use unresolved address here to avoid DNS resolution each time we creates a client. @@ -166,7 +172,7 @@ public class TransportClientFactory implements Closeable { logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress); } } - clientPool.clients[clientIndex] = internalCreateClient(resolvedAddress); + clientPool.clients[clientIndex] = internalCreateClient(resolvedAddress, decoder); return clientPool.clients[clientIndex]; } } @@ -182,7 +188,8 @@ public class TransportClientFactory implements Closeable { * *

As with {@link #createClient(String, int)}, this method is blocking. */ - private TransportClient internalCreateClient(InetSocketAddress address) + private TransportClient internalCreateClient( + InetSocketAddress address, ChannelInboundHandlerAdapter decoder) throws IOException, InterruptedException { Bootstrap bootstrap = new Bootstrap(); bootstrap @@ -209,7 +216,7 @@ public class TransportClientFactory implements Closeable { new ChannelInitializer() { @Override public void initChannel(SocketChannel ch) { - TransportChannelHandler clientHandler = context.initializePipeline(ch); + TransportChannelHandler clientHandler = context.initializePipeline(ch, decoder); clientRef.set(clientHandler.getClient()); channelRef.set(ch); } @@ -252,4 +259,8 @@ public class TransportClientFactory implements Closeable { workerGroup.shutdownGracefully(); } } + + public TransportContext getContext() { + return context; + } } diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java index a361c9739..b15453a15 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/Message.java @@ -50,6 +50,10 @@ public abstract class Message implements Encodable { this.body = new NettyManagedBuffer(buf); } + public void setBody(ByteBuffer buf) { + this.body = new NettyManagedBuffer(buf); + } + /** Whether the body should be copied out in frame decoder. */ public boolean needCopyOut() { return false; @@ -89,7 +93,6 @@ public abstract class Message implements Encodable { READ_DATA(17), OPEN_STREAM_WITH_CREDIT(18), BACKLOG_ANNOUNCEMENT(19); - private final byte id; Type(int id) { @@ -111,6 +114,11 @@ public abstract class Message implements Encodable { buf.writeByte(id); } + public static Type decode(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + return decode(buf); + } + public static Type decode(ByteBuf buf) { byte id = buf.readByte(); switch (id) { diff --git a/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java b/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java index 5fdb90c21..9a1223d2f 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java +++ b/common/src/main/java/org/apache/celeborn/common/network/util/NettyUtils.java @@ -21,7 +21,6 @@ import java.util.concurrent.ThreadFactory; import io.netty.buffer.PooledByteBufAllocator; import io.netty.channel.Channel; -import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.EventLoopGroup; import io.netty.channel.ServerChannel; import io.netty.channel.epoll.EpollEventLoopGroup; @@ -78,20 +77,6 @@ public class NettyUtils { } } - /** - * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. - * This is used before all decoders. - */ - public static ChannelInboundHandlerAdapter createFrameDecoder(TransportConf conf) { - if (conf.decoderMode().equals("default")) { - return new TransportFrameDecoder(); - } else if (conf.decoderMode().equals("supplier")) { - return new TransportFrameDecoderWithBufferSupplier(); - } else { - return new TransportFrameDecoder(); - } - } - /** Returns the remote address on the channel or "<unknown remote>" if none exists. */ public static String getRemoteAddress(Channel channel) { if (channel != null && channel.remoteAddress() != null) { diff --git a/common/src/main/java/org/apache/celeborn/common/network/util/TransportConf.java b/common/src/main/java/org/apache/celeborn/common/network/util/TransportConf.java index 29d40f2ae..2971096d7 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/util/TransportConf.java +++ b/common/src/main/java/org/apache/celeborn/common/network/util/TransportConf.java @@ -141,10 +141,6 @@ public class TransportConf { return conf.networkIoMaxChunksBeingTransferred(module); } - public String decoderMode() { - return conf.networkIoDecoderMode(module); - } - public long pushDataTimeoutCheckIntervalMs() { return conf.pushTimeoutCheckInterval(); } diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 64c4d1023..59ae4bcd2 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -461,11 +461,6 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se getSizeAsBytes(key, MAX_CHUNKS_BEING_TRANSFERRED.defaultValueString) } - def networkIoDecoderMode(module: String): String = { - val key = NETWORK_IO_DECODER_MODE.key.replace("", module) - get(key, NETWORK_IO_DECODER_MODE.defaultValue.get) - } - // ////////////////////////////////////////////////////// // Master // // ////////////////////////////////////////////////////// @@ -1089,15 +1084,6 @@ object CelebornConf extends Logging { .checkValues(Set("NIO", "EPOLL")) .createWithDefault("NIO") - val NETWORK_IO_DECODER_MODE: ConfigEntry[String] = - buildConf("celeborn..decoder.mode") - .categories("network") - .doc("Netty TransportFrameDecoder implementation, available options: default, supplier.") - .stringConf - .transform(_.toLowerCase) - .checkValues(Set("default", "supplier")) - .createWithDefault("default") - val NETWORK_IO_PREFER_DIRECT_BUFS: ConfigEntry[Boolean] = buildConf("celeborn..io.preferDirectBufs") .categories("network") diff --git a/docs/configuration/network.md b/docs/configuration/network.md index d6b01c62d..8fb4fef81 100644 --- a/docs/configuration/network.md +++ b/docs/configuration/network.md @@ -19,7 +19,6 @@ license: | | Key | Default | Description | Since | | --- | ------- | ----------- | ----- | -| celeborn.<module>.decoder.mode | default | Netty TransportFrameDecoder implementation, available options: default, supplier. | | | celeborn.<module>.io.backLog | 0 | Requested maximum length of the queue of incoming connections. Default 0 for no backlog. | | | celeborn.<module>.io.clientThreads | 0 | Number of threads used in the client thread pool. Default to 0, which is 2x#cores. | | | celeborn.<module>.io.connectTimeout | <value of celeborn.network.connect.timeout> | Socket connect timeout. | |