diff --git a/LICENSE b/LICENSE
index 76555a026..d5f68e4d7 100644
--- a/LICENSE
+++ b/LICENSE
@@ -212,11 +212,14 @@ Apache License 2.0
Apache Spark
./client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
./client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+./common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
+./common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
./common/src/main/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManager.java
./common/src/main/java/org/apache/celeborn/common/network/util/NettyLogger.java
./common/src/main/java/org/apache/celeborn/common/unsafe/Platform.java
./common/src/main/java/org/apache/celeborn/common/util/JavaUtils.java
./common/src/main/scala/org/apache/celeborn/common/util/SignalUtils.scala
+./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java
./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java
diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java
index aeb736a39..e3add9117 100644
--- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java
+++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java
@@ -17,6 +17,7 @@
package org.apache.celeborn.plugin.flink.buffer;
+import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
@@ -64,4 +65,9 @@ public class FlinkNettyManagedBuffer extends ManagedBuffer {
public Object convertToNetty() {
return buf.duplicate().retain();
}
+
+ @Override
+ public Object convertToNettyForSsl() throws IOException {
+ return buf.duplicate().retain();
+ }
}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java
index 6af9e4305..5d11e8780 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java
@@ -24,6 +24,7 @@ import java.nio.file.StandardOpenOption;
import com.google.common.io.ByteStreams;
import io.netty.channel.DefaultFileRegion;
+import io.netty.handler.stream.ChunkedStream;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
@@ -132,6 +133,12 @@ public final class FileSegmentManagedBuffer extends ManagedBuffer {
}
}
+ @Override
+ public Object convertToNettyForSsl() throws IOException {
+ // Cannot use zero-copy with SSL
+ return new ChunkedStream(createInputStream(), conf.maxSslEncryptedBlockSize());
+ }
+
public File getFile() {
return file;
}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java
index ce320d9d7..9ab05781b 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java
@@ -71,4 +71,15 @@ public abstract class ManagedBuffer {
* the caller will be responsible for releasing this new reference.
*/
public abstract Object convertToNetty() throws IOException;
+
+ /**
+ * Convert the buffer into a Netty object, used to write the data out with SSL encryption, which
+ * cannot use {@link io.netty.channel.FileRegion}. The return value is either a {@link
+ * io.netty.buffer.ByteBuf}, a {@link io.netty.handler.stream.ChunkedStream}, or a {@link
+ * java.io.InputStream}.
+ *
+ *
If this method returns a ByteBuf, then that buffer's reference count will be incremented and
+ * the caller will be responsible for releasing this new reference.
+ */
+ public abstract Object convertToNettyForSsl() throws IOException;
}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java
index 60cf8625b..0528c8c74 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java
@@ -76,6 +76,11 @@ public class NettyManagedBuffer extends ManagedBuffer {
return buf.duplicate().retain();
}
+ @Override
+ public Object convertToNettyForSsl() throws IOException {
+ return buf.duplicate().retain();
+ }
+
@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java
index b14cb1f8e..97c31ef2c 100644
--- a/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java
+++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java
@@ -64,6 +64,11 @@ public class NioManagedBuffer extends ManagedBuffer {
return Unpooled.wrappedBuffer(buf);
}
+ @Override
+ public Object convertToNettyForSsl() throws IOException {
+ return Unpooled.wrappedBuffer(buf);
+ }
+
@Override
public String toString() {
return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
new file mode 100644
index 000000000..df2ab1a92
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import java.io.EOFException;
+import java.io.InputStream;
+
+import javax.annotation.Nullable;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufAllocator;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.stream.ChunkedInput;
+import io.netty.handler.stream.ChunkedStream;
+
+import org.apache.celeborn.common.network.buffer.ManagedBuffer;
+
+/**
+ * A wrapper message that holds two separate pieces (a header and a body).
+ *
+ *
The header must be a ByteBuf, while the body can be any InputStream or ChunkedStream Based on
+ * common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeader
+ */
+public class EncryptedMessageWithHeader implements ChunkedInput {
+
+ @Nullable private final ManagedBuffer managedBuffer;
+ private final ByteBuf header;
+ private final int headerLength;
+ private final Object body;
+ private final long bodyLength;
+ private long totalBytesTransferred;
+
+ /**
+ * Construct a new EncryptedMessageWithHeader.
+ *
+ * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to
+ * be passed in so that the buffer can be freed when this message is deallocated. Ownership of
+ * the caller's reference to this buffer is transferred to this class, so if the caller wants
+ * to continue to use the ManagedBuffer in other messages then they will need to call retain()
+ * on it before passing it to this constructor.
+ * @param header the message header.
+ * @param body the message body.
+ * @param bodyLength the length of the message body, in bytes.
+ */
+ public EncryptedMessageWithHeader(
+ @Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long bodyLength) {
+ Preconditions.checkArgument(
+ body instanceof InputStream || body instanceof ChunkedStream,
+ "Body must be an InputStream or a ChunkedStream.");
+ this.managedBuffer = managedBuffer;
+ this.header = header;
+ this.headerLength = header.readableBytes();
+ this.body = body;
+ this.bodyLength = bodyLength;
+ this.totalBytesTransferred = 0;
+ }
+
+ @Override
+ public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception {
+ return readChunk(ctx.alloc());
+ }
+
+ @Override
+ public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
+ if (isEndOfInput()) {
+ return null;
+ }
+
+ if (totalBytesTransferred < headerLength) {
+ totalBytesTransferred += headerLength;
+ return header.retain();
+ } else if (body instanceof InputStream) {
+ InputStream stream = (InputStream) body;
+ int available = stream.available();
+ if (available <= 0) {
+ available = (int) (length() - totalBytesTransferred);
+ } else {
+ available = (int) Math.min(available, length() - totalBytesTransferred);
+ }
+ ByteBuf buffer = allocator.buffer(available);
+ int toRead = Math.min(available, buffer.writableBytes());
+ int read = buffer.writeBytes(stream, toRead);
+ if (read >= 0) {
+ totalBytesTransferred += read;
+ return buffer;
+ } else {
+ throw new EOFException("Unable to read bytes from InputStream");
+ }
+ } else if (body instanceof ChunkedStream) {
+ ChunkedStream stream = (ChunkedStream) body;
+ long old = stream.transferredBytes();
+ ByteBuf buffer = stream.readChunk(allocator);
+ long read = stream.transferredBytes() - old;
+ if (read >= 0) {
+ totalBytesTransferred += read;
+ assert (totalBytesTransferred <= length());
+ return buffer;
+ } else {
+ throw new EOFException("Unable to read bytes from ChunkedStream");
+ }
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public long length() {
+ return headerLength + bodyLength;
+ }
+
+ @Override
+ public long progress() {
+ return totalBytesTransferred;
+ }
+
+ @Override
+ public boolean isEndOfInput() throws Exception {
+ return (headerLength + bodyLength) == totalBytesTransferred;
+ }
+
+ @Override
+ public void close() throws Exception {
+ header.release();
+ if (managedBuffer != null) {
+ managedBuffer.release();
+ }
+ if (body instanceof InputStream) {
+ ((InputStream) body).close();
+ } else if (body instanceof ChunkedStream) {
+ ((ChunkedStream) body).close();
+ }
+ }
+}
diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
new file mode 100644
index 000000000..508b6a13d
--- /dev/null
+++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.common.network.protocol;
+
+import java.io.InputStream;
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageEncoder;
+import io.netty.handler.stream.ChunkedStream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Encoder used by the server side to encode secure (SSL) server-to-client responses. This encoder
+ * is stateless so it is safe to be shared by multiple threads. Based on
+ * common/network-common/org.apache.spark.network.protocol.SslMessageEncoder
+ */
+@ChannelHandler.Sharable
+public final class SslMessageEncoder extends MessageToMessageEncoder {
+
+ private static final Logger logger = LoggerFactory.getLogger(SslMessageEncoder.class);
+ public static final SslMessageEncoder INSTANCE = new SslMessageEncoder();
+
+ private SslMessageEncoder() {}
+
+ /**
+ * Encodes a Message by invoking its encode() method. For non-data messages, we will add one
+ * ByteBuf to 'out' containing the total frame length, the message type, and the message itself.
+ * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the
+ * data to 'out'.
+ */
+ @Override
+ public void encode(ChannelHandlerContext ctx, Message in, List