[CELEBORN-283] Derive network layer for flink plugin. (#1222)
This commit is contained in:
parent
5236df68af
commit
3aacede5f8
@ -0,0 +1,67 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.plugin.flink.buffer;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
|
||||
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBufInputStream;
|
||||
|
||||
import org.apache.celeborn.common.network.buffer.ManagedBuffer;
|
||||
|
||||
public class FlinkNettyManagedBuffer extends ManagedBuffer {
|
||||
private final ByteBuf buf;
|
||||
|
||||
public FlinkNettyManagedBuffer(ByteBuf buf) {
|
||||
super();
|
||||
this.buf = buf;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long size() {
|
||||
return buf.readableBytes();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteBuffer nioByteBuffer() {
|
||||
return buf.nioBuffer();
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputStream createInputStream() {
|
||||
return new ByteBufInputStream(buf);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ManagedBuffer retain() {
|
||||
buf.retain();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ManagedBuffer release() {
|
||||
buf.release();
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object convertToNetty() {
|
||||
return buf.duplicate().retain();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,40 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.plugin.flink.network;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
|
||||
|
||||
import org.apache.celeborn.common.network.TransportContext;
|
||||
import org.apache.celeborn.common.network.client.TransportClient;
|
||||
import org.apache.celeborn.common.network.client.TransportClientFactory;
|
||||
|
||||
public class FlinkTransportClientFactory extends TransportClientFactory {
|
||||
public FlinkTransportClientFactory(TransportContext context) {
|
||||
super(context);
|
||||
}
|
||||
|
||||
public TransportClient createClient(
|
||||
String remoteHost, int remotePort, int partitionId, Supplier<ByteBuf> supplier)
|
||||
throws IOException, InterruptedException {
|
||||
return createClient(
|
||||
remoteHost, remotePort, partitionId, new TransportFrameDecoderWithBufferSupplier(supplier));
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,88 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.plugin.flink.network;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
|
||||
|
||||
import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
|
||||
import org.apache.celeborn.common.network.protocol.*;
|
||||
import org.apache.celeborn.plugin.flink.buffer.FlinkNettyManagedBuffer;
|
||||
import org.apache.celeborn.plugin.flink.protocol.ReadData;
|
||||
|
||||
public class MessageDecoderExt {
|
||||
public static Message decode(Message.Type type, ByteBuf in, boolean decodeBody) {
|
||||
long requestId;
|
||||
// cannot use actual class decode method because common module cannot refer flink shaded netty.
|
||||
switch (type) {
|
||||
case RPC_REQUEST:
|
||||
requestId = in.readLong();
|
||||
in.readInt();
|
||||
if (decodeBody) {
|
||||
return new RpcRequest(requestId, new FlinkNettyManagedBuffer(in));
|
||||
} else {
|
||||
return new RpcRequest(requestId, NettyManagedBuffer.EmptyBuffer);
|
||||
}
|
||||
|
||||
case RPC_RESPONSE:
|
||||
requestId = in.readLong();
|
||||
in.readInt();
|
||||
if (decodeBody) {
|
||||
return new RpcResponse(requestId, new FlinkNettyManagedBuffer(in));
|
||||
} else {
|
||||
return new RpcResponse(requestId, NettyManagedBuffer.EmptyBuffer);
|
||||
}
|
||||
|
||||
case RPC_FAILURE:
|
||||
requestId = in.readLong();
|
||||
int length = in.readInt();
|
||||
byte[] bytes = new byte[length];
|
||||
in.readBytes(bytes);
|
||||
String errorString = new String(bytes, StandardCharsets.UTF_8);
|
||||
return new RpcFailure(requestId, errorString);
|
||||
|
||||
case ONE_WAY_MESSAGE:
|
||||
in.readInt();
|
||||
if (decodeBody) {
|
||||
return new OneWayMessage(new FlinkNettyManagedBuffer(in));
|
||||
} else {
|
||||
return new OneWayMessage(NettyManagedBuffer.EmptyBuffer);
|
||||
}
|
||||
|
||||
case READ_ADD_CREDIT:
|
||||
long streamId = in.readLong();
|
||||
int credit = in.readInt();
|
||||
return new ReadAddCredit(streamId, credit);
|
||||
|
||||
case READ_DATA:
|
||||
streamId = in.readLong();
|
||||
int backlog = in.readInt();
|
||||
long offset = in.readLong();
|
||||
return new ReadData(streamId, backlog, offset);
|
||||
|
||||
case BACKLOG_ANNOUNCEMENT:
|
||||
streamId = in.readLong();
|
||||
backlog = in.readInt();
|
||||
return new BacklogAnnouncement(streamId, backlog);
|
||||
|
||||
default:
|
||||
throw new IllegalArgumentException("Unexpected message type: " + type);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,104 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.plugin.flink.network;
|
||||
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.apache.celeborn.common.network.client.TransportClient;
|
||||
import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
|
||||
import org.apache.celeborn.common.network.protocol.RequestMessage;
|
||||
import org.apache.celeborn.common.network.server.BaseMessageHandler;
|
||||
import org.apache.celeborn.plugin.flink.protocol.ReadData;
|
||||
|
||||
public class ReadClientHandler extends BaseMessageHandler {
|
||||
private static Logger logger = LoggerFactory.getLogger(ReadClientHandler.class);
|
||||
private ConcurrentHashMap<Long, Consumer<RequestMessage>> streamHandlers =
|
||||
new ConcurrentHashMap<>();
|
||||
private ConcurrentHashMap<Long, TransportClient> streamClients = new ConcurrentHashMap<>();
|
||||
|
||||
public void registerHandler(
|
||||
long streamId, Consumer<RequestMessage> handle, TransportClient client) {
|
||||
streamHandlers.put(streamId, handle);
|
||||
streamClients.put(streamId, client);
|
||||
}
|
||||
|
||||
public void removeHandler(long streamId) {
|
||||
streamHandlers.remove(streamId);
|
||||
streamClients.remove(streamId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void receive(TransportClient client, RequestMessage msg) {
|
||||
long streamId = 0;
|
||||
switch (msg.type()) {
|
||||
case READ_DATA:
|
||||
ReadData readData = (ReadData) msg;
|
||||
streamId = readData.getStreamId();
|
||||
if (streamHandlers.containsKey(streamId)) {
|
||||
logger.debug(
|
||||
"received streamId: {}, readData size:{}",
|
||||
streamId,
|
||||
readData.getFlinkBuffer().readableBytes());
|
||||
streamHandlers.get(streamId).accept(msg);
|
||||
} else {
|
||||
logger.warn("Unexpected streamId received: {}", streamId);
|
||||
}
|
||||
break;
|
||||
case BACKLOG_ANNOUNCEMENT:
|
||||
BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement) msg;
|
||||
streamId = backlogAnnouncement.getStreamId();
|
||||
if (streamHandlers.containsKey(streamId)) {
|
||||
logger.debug(
|
||||
"received streamId: {}, backlog: {}", streamId, backlogAnnouncement.getBacklog());
|
||||
streamHandlers.get(streamId).accept(msg);
|
||||
} else {
|
||||
logger.warn("Unexpected streamId received: {}", streamId);
|
||||
}
|
||||
break;
|
||||
case ONE_WAY_MESSAGE:
|
||||
// ignore it.
|
||||
break;
|
||||
default:
|
||||
logger.error("Unexpected msg type {} content {}", msg.type(), msg);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean checkRegistered() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelInactive(TransportClient client) {
|
||||
streamClients.forEach(
|
||||
(streamId, savedClient) -> {
|
||||
if (savedClient == client) {
|
||||
logger.warn("Client {} is lost, remove related stream {}", savedClient, streamId);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(Throwable cause, TransportClient client) {
|
||||
logger.warn("exception caught {}", client.getSocketAddress(), cause);
|
||||
}
|
||||
}
|
||||
@ -15,57 +15,43 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.common.network.util;
|
||||
package org.apache.celeborn.plugin.flink.network;
|
||||
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.CompositeByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
|
||||
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
|
||||
|
||||
import org.apache.celeborn.common.network.protocol.Message;
|
||||
import org.apache.celeborn.common.network.util.FrameDecoder;
|
||||
import org.apache.celeborn.plugin.flink.protocol.ReadData;
|
||||
|
||||
public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandlerAdapter
|
||||
implements FrameDecoder {
|
||||
private final Function<Integer, Supplier<ByteBuf>> bufferSuppliers;
|
||||
private final Supplier<ByteBuf> bufferSupplier;
|
||||
private int msgSize = -1;
|
||||
private int bodySize = -1;
|
||||
private Message.Type curType = Message.Type.UNKNOWN_TYPE;
|
||||
private ByteBuf headerBuf = Unpooled.buffer(HEADER_SIZE, HEADER_SIZE);
|
||||
private CompositeByteBuf bodyBuf = null;
|
||||
private io.netty.buffer.CompositeByteBuf bodyBuf = null;
|
||||
private ByteBuf externalBuf = null;
|
||||
private final ByteBuf msgBuf = Unpooled.buffer(8);
|
||||
private Message curMsg = null;
|
||||
|
||||
public TransportFrameDecoderWithBufferSupplier() {
|
||||
this.bufferSuppliers =
|
||||
new Function<Integer, Supplier<ByteBuf>>() {
|
||||
@Override
|
||||
public Supplier<ByteBuf> apply(Integer size) {
|
||||
return new Supplier<ByteBuf>() {
|
||||
@Override
|
||||
public ByteBuf get() {
|
||||
return Unpooled.buffer(size);
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
public TransportFrameDecoderWithBufferSupplier(Supplier<ByteBuf> bufferSupplier) {
|
||||
this.bufferSupplier = bufferSupplier;
|
||||
}
|
||||
|
||||
public TransportFrameDecoderWithBufferSupplier(
|
||||
Function<Integer, Supplier<ByteBuf>> bufferSuppliers) {
|
||||
this.bufferSuppliers = bufferSuppliers;
|
||||
}
|
||||
|
||||
private void copyByteBuf(ByteBuf source, ByteBuf target, int targetSize) {
|
||||
private void copyByteBuf(io.netty.buffer.ByteBuf source, ByteBuf target, int targetSize) {
|
||||
int bytes = Math.min(source.readableBytes(), targetSize - target.readableBytes());
|
||||
target.writeBytes(source, bytes);
|
||||
for (int i = 0; i < bytes; i++) {
|
||||
target.writeByte(source.readByte());
|
||||
}
|
||||
}
|
||||
|
||||
private void decodeHeader(ByteBuf buf, ChannelHandlerContext ctx) {
|
||||
private void decodeHeader(io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
|
||||
copyByteBuf(buf, headerBuf, HEADER_SIZE);
|
||||
if (!headerBuf.isWritable()) {
|
||||
msgSize = headerBuf.readInt();
|
||||
@ -73,18 +59,20 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
|
||||
msgBuf.capacity(msgSize);
|
||||
}
|
||||
msgBuf.clear();
|
||||
curType = Message.Type.decode(headerBuf);
|
||||
curType = Message.Type.decode(headerBuf.nioBuffer());
|
||||
// type byte is read
|
||||
headerBuf.readByte();
|
||||
bodySize = headerBuf.readInt();
|
||||
decodeMsg(buf, ctx);
|
||||
}
|
||||
}
|
||||
|
||||
private void decodeMsg(ByteBuf buf, ChannelHandlerContext ctx) {
|
||||
private void decodeMsg(io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
|
||||
if (msgBuf.readableBytes() < msgSize) {
|
||||
copyByteBuf(buf, msgBuf, msgSize);
|
||||
}
|
||||
if (msgBuf.readableBytes() == msgSize) {
|
||||
curMsg = Message.decode(curType, msgBuf, false);
|
||||
curMsg = MessageDecoderExt.decode(curType, msgBuf, false);
|
||||
if (bodySize <= 0) {
|
||||
ctx.fireChannelRead(curMsg);
|
||||
clear();
|
||||
@ -92,11 +80,12 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
|
||||
}
|
||||
}
|
||||
|
||||
private ByteBuf decodeBody(ByteBuf buf, ChannelHandlerContext ctx) {
|
||||
private io.netty.buffer.ByteBuf decodeBody(
|
||||
io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
|
||||
if (bodyBuf == null) {
|
||||
if (buf.readableBytes() >= bodySize) {
|
||||
ByteBuf body = buf.retain().readSlice(bodySize);
|
||||
curMsg.setBody(body);
|
||||
io.netty.buffer.ByteBuf body = buf.retain().readSlice(bodySize);
|
||||
curMsg.setBody(body.nioBuffer());
|
||||
ctx.fireChannelRead(curMsg);
|
||||
clear();
|
||||
return buf;
|
||||
@ -105,7 +94,7 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
|
||||
}
|
||||
}
|
||||
int remaining = bodySize - bodyBuf.readableBytes();
|
||||
ByteBuf next;
|
||||
io.netty.buffer.ByteBuf next;
|
||||
if (remaining >= buf.readableBytes()) {
|
||||
next = buf;
|
||||
buf = null;
|
||||
@ -114,20 +103,25 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
|
||||
}
|
||||
bodyBuf.addComponent(next).writerIndex(bodyBuf.writerIndex() + next.readableBytes());
|
||||
if (bodyBuf.readableBytes() == bodySize) {
|
||||
curMsg.setBody(bodyBuf);
|
||||
curMsg.setBody(bodyBuf.nioBuffer());
|
||||
ctx.fireChannelRead(curMsg);
|
||||
clear();
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
|
||||
private ByteBuf decodeBodyCopyOut(ByteBuf buf, ChannelHandlerContext ctx) {
|
||||
private io.netty.buffer.ByteBuf decodeBodyCopyOut(
|
||||
io.netty.buffer.ByteBuf buf, ChannelHandlerContext ctx) {
|
||||
if (externalBuf == null) {
|
||||
externalBuf = bufferSuppliers.apply(bodySize).get();
|
||||
externalBuf = bufferSupplier.get();
|
||||
}
|
||||
copyByteBuf(buf, externalBuf, bodySize);
|
||||
if (externalBuf.readableBytes() == bodySize) {
|
||||
curMsg.setBody(externalBuf);
|
||||
if (curMsg instanceof ReadData) {
|
||||
((ReadData) curMsg).setFlinkBuffer(externalBuf);
|
||||
} else {
|
||||
curMsg.setBody(externalBuf.nioBuffer());
|
||||
}
|
||||
ctx.fireChannelRead(curMsg);
|
||||
clear();
|
||||
}
|
||||
@ -135,24 +129,24 @@ public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandl
|
||||
}
|
||||
|
||||
public void channelRead(ChannelHandlerContext ctx, Object data) {
|
||||
ByteBuf buf = (ByteBuf) data;
|
||||
io.netty.buffer.ByteBuf nettyBuf = (io.netty.buffer.ByteBuf) data;
|
||||
try {
|
||||
while (buf != null && buf.isReadable()) {
|
||||
while (nettyBuf != null && nettyBuf.isReadable()) {
|
||||
if (headerBuf.isWritable()) {
|
||||
decodeHeader(buf, ctx);
|
||||
decodeHeader(nettyBuf, ctx);
|
||||
} else if (curMsg == null) {
|
||||
decodeMsg(buf, ctx);
|
||||
decodeMsg(nettyBuf, ctx);
|
||||
} else if (bodySize > 0) {
|
||||
if (curMsg.needCopyOut()) {
|
||||
buf = decodeBodyCopyOut(buf, ctx);
|
||||
nettyBuf = decodeBodyCopyOut(nettyBuf, ctx);
|
||||
} else {
|
||||
buf = decodeBody(buf, ctx);
|
||||
nettyBuf = decodeBody(nettyBuf, ctx);
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (buf != null) {
|
||||
buf.release();
|
||||
if (nettyBuf != null) {
|
||||
nettyBuf.release();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,102 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.plugin.flink.protocol;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
|
||||
|
||||
import org.apache.celeborn.common.network.protocol.RequestMessage;
|
||||
|
||||
public final class ReadData extends RequestMessage {
|
||||
private final long streamId;
|
||||
private final int backlog;
|
||||
private final long offset;
|
||||
private ByteBuf flinkBuffer;
|
||||
|
||||
public ReadData(long streamId, int backlog, long offset) {
|
||||
this.streamId = streamId;
|
||||
this.backlog = backlog;
|
||||
this.offset = offset;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int encodedLength() {
|
||||
return 8 + 4 + 8;
|
||||
}
|
||||
|
||||
// This method will not be called because ReadData won't be created at flink client.
|
||||
@Override
|
||||
public void encode(io.netty.buffer.ByteBuf buf) {
|
||||
buf.writeLong(streamId);
|
||||
buf.writeInt(backlog);
|
||||
buf.writeLong(offset);
|
||||
}
|
||||
|
||||
public long getStreamId() {
|
||||
return streamId;
|
||||
}
|
||||
|
||||
public int getBacklog() {
|
||||
return backlog;
|
||||
}
|
||||
|
||||
public long getOffset() {
|
||||
return offset;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Type type() {
|
||||
return Type.READ_DATA;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
ReadData readData = (ReadData) o;
|
||||
return streamId == readData.streamId
|
||||
&& backlog == readData.backlog
|
||||
&& offset == readData.offset;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(streamId, backlog, offset);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ReadData{"
|
||||
+ "streamId="
|
||||
+ streamId
|
||||
+ ", backlog="
|
||||
+ backlog
|
||||
+ ", offset="
|
||||
+ offset
|
||||
+ '}';
|
||||
}
|
||||
|
||||
public ByteBuf getFlinkBuffer() {
|
||||
return flinkBuffer;
|
||||
}
|
||||
|
||||
public void setFlinkBuffer(ByteBuf flinkBuffer) {
|
||||
this.flinkBuffer = flinkBuffer;
|
||||
}
|
||||
}
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.celeborn.common.network;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.handler.timeout.IdleStateHandler;
|
||||
import org.slf4j.Logger;
|
||||
@ -29,8 +30,8 @@ import org.apache.celeborn.common.network.client.TransportResponseHandler;
|
||||
import org.apache.celeborn.common.network.protocol.MessageEncoder;
|
||||
import org.apache.celeborn.common.network.server.*;
|
||||
import org.apache.celeborn.common.network.util.FrameDecoder;
|
||||
import org.apache.celeborn.common.network.util.NettyUtils;
|
||||
import org.apache.celeborn.common.network.util.TransportConf;
|
||||
import org.apache.celeborn.common.network.util.TransportFrameDecoder;
|
||||
|
||||
/**
|
||||
* Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
|
||||
@ -94,6 +95,11 @@ public class TransportContext {
|
||||
}
|
||||
|
||||
public TransportChannelHandler initializePipeline(SocketChannel channel) {
|
||||
return initializePipeline(channel, new TransportFrameDecoder());
|
||||
}
|
||||
|
||||
public TransportChannelHandler initializePipeline(
|
||||
SocketChannel channel, ChannelInboundHandlerAdapter decoder) {
|
||||
try {
|
||||
if (channelsLimiter != null) {
|
||||
channel.pipeline().addLast("limiter", channelsLimiter);
|
||||
@ -102,7 +108,7 @@ public class TransportContext {
|
||||
channel
|
||||
.pipeline()
|
||||
.addLast("encoder", ENCODER)
|
||||
.addLast(FrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder(conf))
|
||||
.addLast(FrameDecoder.HANDLER_NAME, decoder)
|
||||
.addLast(
|
||||
"idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
|
||||
.addLast("handler", channelHandler);
|
||||
@ -126,4 +132,8 @@ public class TransportContext {
|
||||
public TransportConf getConf() {
|
||||
return conf;
|
||||
}
|
||||
|
||||
public BaseMessageHandler getMsgHandler() {
|
||||
return msgHandler;
|
||||
}
|
||||
}
|
||||
|
||||
@ -35,6 +35,10 @@ public class NettyManagedBuffer extends ManagedBuffer {
|
||||
this.buf = buf;
|
||||
}
|
||||
|
||||
public NettyManagedBuffer(ByteBuffer buffer) {
|
||||
this.buf = Unpooled.wrappedBuffer(buffer);
|
||||
}
|
||||
|
||||
public ByteBuf getBuf() {
|
||||
return buf.duplicate();
|
||||
}
|
||||
|
||||
@ -110,6 +110,12 @@ public class TransportClientFactory implements Closeable {
|
||||
*/
|
||||
public TransportClient createClient(String remoteHost, int remotePort, int partitionId)
|
||||
throws IOException, InterruptedException {
|
||||
return createClient(remoteHost, remotePort, partitionId, new TransportFrameDecoder());
|
||||
}
|
||||
|
||||
public TransportClient createClient(
|
||||
String remoteHost, int remotePort, int partitionId, ChannelInboundHandlerAdapter decoder)
|
||||
throws IOException, InterruptedException {
|
||||
// Get connection from the connection pool first.
|
||||
// If it is not found or not active, create a new one.
|
||||
// Use unresolved address here to avoid DNS resolution each time we creates a client.
|
||||
@ -166,7 +172,7 @@ public class TransportClientFactory implements Closeable {
|
||||
logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress);
|
||||
}
|
||||
}
|
||||
clientPool.clients[clientIndex] = internalCreateClient(resolvedAddress);
|
||||
clientPool.clients[clientIndex] = internalCreateClient(resolvedAddress, decoder);
|
||||
return clientPool.clients[clientIndex];
|
||||
}
|
||||
}
|
||||
@ -182,7 +188,8 @@ public class TransportClientFactory implements Closeable {
|
||||
*
|
||||
* <p>As with {@link #createClient(String, int)}, this method is blocking.
|
||||
*/
|
||||
private TransportClient internalCreateClient(InetSocketAddress address)
|
||||
private TransportClient internalCreateClient(
|
||||
InetSocketAddress address, ChannelInboundHandlerAdapter decoder)
|
||||
throws IOException, InterruptedException {
|
||||
Bootstrap bootstrap = new Bootstrap();
|
||||
bootstrap
|
||||
@ -209,7 +216,7 @@ public class TransportClientFactory implements Closeable {
|
||||
new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
public void initChannel(SocketChannel ch) {
|
||||
TransportChannelHandler clientHandler = context.initializePipeline(ch);
|
||||
TransportChannelHandler clientHandler = context.initializePipeline(ch, decoder);
|
||||
clientRef.set(clientHandler.getClient());
|
||||
channelRef.set(ch);
|
||||
}
|
||||
@ -252,4 +259,8 @@ public class TransportClientFactory implements Closeable {
|
||||
workerGroup.shutdownGracefully();
|
||||
}
|
||||
}
|
||||
|
||||
public TransportContext getContext() {
|
||||
return context;
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,6 +50,10 @@ public abstract class Message implements Encodable {
|
||||
this.body = new NettyManagedBuffer(buf);
|
||||
}
|
||||
|
||||
public void setBody(ByteBuffer buf) {
|
||||
this.body = new NettyManagedBuffer(buf);
|
||||
}
|
||||
|
||||
/** Whether the body should be copied out in frame decoder. */
|
||||
public boolean needCopyOut() {
|
||||
return false;
|
||||
@ -89,7 +93,6 @@ public abstract class Message implements Encodable {
|
||||
READ_DATA(17),
|
||||
OPEN_STREAM_WITH_CREDIT(18),
|
||||
BACKLOG_ANNOUNCEMENT(19);
|
||||
|
||||
private final byte id;
|
||||
|
||||
Type(int id) {
|
||||
@ -111,6 +114,11 @@ public abstract class Message implements Encodable {
|
||||
buf.writeByte(id);
|
||||
}
|
||||
|
||||
public static Type decode(ByteBuffer buffer) {
|
||||
ByteBuf buf = Unpooled.wrappedBuffer(buffer);
|
||||
return decode(buf);
|
||||
}
|
||||
|
||||
public static Type decode(ByteBuf buf) {
|
||||
byte id = buf.readByte();
|
||||
switch (id) {
|
||||
|
||||
@ -21,7 +21,6 @@ import java.util.concurrent.ThreadFactory;
|
||||
|
||||
import io.netty.buffer.PooledByteBufAllocator;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelInboundHandlerAdapter;
|
||||
import io.netty.channel.EventLoopGroup;
|
||||
import io.netty.channel.ServerChannel;
|
||||
import io.netty.channel.epoll.EpollEventLoopGroup;
|
||||
@ -78,20 +77,6 @@ public class NettyUtils {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame.
|
||||
* This is used before all decoders.
|
||||
*/
|
||||
public static ChannelInboundHandlerAdapter createFrameDecoder(TransportConf conf) {
|
||||
if (conf.decoderMode().equals("default")) {
|
||||
return new TransportFrameDecoder();
|
||||
} else if (conf.decoderMode().equals("supplier")) {
|
||||
return new TransportFrameDecoderWithBufferSupplier();
|
||||
} else {
|
||||
return new TransportFrameDecoder();
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns the remote address on the channel or "<unknown remote>" if none exists. */
|
||||
public static String getRemoteAddress(Channel channel) {
|
||||
if (channel != null && channel.remoteAddress() != null) {
|
||||
|
||||
@ -141,10 +141,6 @@ public class TransportConf {
|
||||
return conf.networkIoMaxChunksBeingTransferred(module);
|
||||
}
|
||||
|
||||
public String decoderMode() {
|
||||
return conf.networkIoDecoderMode(module);
|
||||
}
|
||||
|
||||
public long pushDataTimeoutCheckIntervalMs() {
|
||||
return conf.pushTimeoutCheckInterval();
|
||||
}
|
||||
|
||||
@ -461,11 +461,6 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
|
||||
getSizeAsBytes(key, MAX_CHUNKS_BEING_TRANSFERRED.defaultValueString)
|
||||
}
|
||||
|
||||
def networkIoDecoderMode(module: String): String = {
|
||||
val key = NETWORK_IO_DECODER_MODE.key.replace("<module>", module)
|
||||
get(key, NETWORK_IO_DECODER_MODE.defaultValue.get)
|
||||
}
|
||||
|
||||
// //////////////////////////////////////////////////////
|
||||
// Master //
|
||||
// //////////////////////////////////////////////////////
|
||||
@ -1089,15 +1084,6 @@ object CelebornConf extends Logging {
|
||||
.checkValues(Set("NIO", "EPOLL"))
|
||||
.createWithDefault("NIO")
|
||||
|
||||
val NETWORK_IO_DECODER_MODE: ConfigEntry[String] =
|
||||
buildConf("celeborn.<module>.decoder.mode")
|
||||
.categories("network")
|
||||
.doc("Netty TransportFrameDecoder implementation, available options: default, supplier.")
|
||||
.stringConf
|
||||
.transform(_.toLowerCase)
|
||||
.checkValues(Set("default", "supplier"))
|
||||
.createWithDefault("default")
|
||||
|
||||
val NETWORK_IO_PREFER_DIRECT_BUFS: ConfigEntry[Boolean] =
|
||||
buildConf("celeborn.<module>.io.preferDirectBufs")
|
||||
.categories("network")
|
||||
|
||||
@ -19,7 +19,6 @@ license: |
|
||||
<!--begin-include-->
|
||||
| Key | Default | Description | Since |
|
||||
| --- | ------- | ----------- | ----- |
|
||||
| celeborn.<module>.decoder.mode | default | Netty TransportFrameDecoder implementation, available options: default, supplier. | |
|
||||
| celeborn.<module>.io.backLog | 0 | Requested maximum length of the queue of incoming connections. Default 0 for no backlog. | |
|
||||
| celeborn.<module>.io.clientThreads | 0 | Number of threads used in the client thread pool. Default to 0, which is 2x#cores. | |
|
||||
| celeborn.<module>.io.connectTimeout | <value of celeborn.network.connect.timeout> | Socket connect timeout. | |
|
||||
|
||||
Loading…
Reference in New Issue
Block a user