From 3ff8812cdd43d7c1295ab99de40e9cbbd2ffd8d2 Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Mon, 1 Apr 2024 19:59:44 +0800 Subject: [PATCH] [CELEBORN-1348] Update infrastructure for SSL communication ### What changes were proposed in this pull request? Update infrastructure for SSL support. Please see #2416 for the consolidated PR with all the changes for reference. ### Why are the changes needed? At a high level, the changes are: * `ManagedBuffer.convertToNettyForSsl`, to support SSL encryption. * Add `EncryptedMessageWithHeader`, which is used to wrap the message and body, for use with SSL. * `SslMessageEncoder` is an encoder for SSL ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? The overall PR #2416 (and this PR as well) passes all tests, and this PR includes relevant subset of tests. Closes #2427 from mridulm/update-infra-for-ssl. Authored-by: Mridul Muralidharan Signed-off-by: SteNicholas --- LICENSE | 3 + .../flink/buffer/FlinkNettyManagedBuffer.java | 6 + .../buffer/FileSegmentManagedBuffer.java | 7 + .../common/network/buffer/ManagedBuffer.java | 11 ++ .../network/buffer/NettyManagedBuffer.java | 5 + .../network/buffer/NioManagedBuffer.java | 5 + .../protocol/EncryptedMessageWithHeader.java | 149 ++++++++++++++++ .../network/protocol/SslMessageEncoder.java | 105 ++++++++++++ .../common/network/TestManagedBuffer.java | 5 + .../EncryptedMessageWithHeaderSuiteJ.java | 160 ++++++++++++++++++ 10 files changed, 456 insertions(+) create mode 100644 common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java create mode 100644 common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java create mode 100644 common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java diff --git a/LICENSE b/LICENSE index 76555a026..d5f68e4d7 100644 --- a/LICENSE +++ b/LICENSE @@ -212,11 +212,14 @@ Apache License 2.0 Apache Spark ./client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java ./client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java +./common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java +./common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java ./common/src/main/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManager.java ./common/src/main/java/org/apache/celeborn/common/network/util/NettyLogger.java ./common/src/main/java/org/apache/celeborn/common/unsafe/Platform.java ./common/src/main/java/org/apache/celeborn/common/util/JavaUtils.java ./common/src/main/scala/org/apache/celeborn/common/util/SignalUtils.scala +./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java ./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java ./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java ./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java index aeb736a39..e3add9117 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/FlinkNettyManagedBuffer.java @@ -17,6 +17,7 @@ package org.apache.celeborn.plugin.flink.buffer; +import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; @@ -64,4 +65,9 @@ public class FlinkNettyManagedBuffer extends ManagedBuffer { public Object convertToNetty() { return buf.duplicate().retain(); } + + @Override + public Object convertToNettyForSsl() throws IOException { + return buf.duplicate().retain(); + } } diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java index 6af9e4305..5d11e8780 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java +++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/FileSegmentManagedBuffer.java @@ -24,6 +24,7 @@ import java.nio.file.StandardOpenOption; import com.google.common.io.ByteStreams; import io.netty.channel.DefaultFileRegion; +import io.netty.handler.stream.ChunkedStream; import org.apache.commons.lang3.builder.ToStringBuilder; import org.apache.commons.lang3.builder.ToStringStyle; @@ -132,6 +133,12 @@ public final class FileSegmentManagedBuffer extends ManagedBuffer { } } + @Override + public Object convertToNettyForSsl() throws IOException { + // Cannot use zero-copy with SSL + return new ChunkedStream(createInputStream(), conf.maxSslEncryptedBlockSize()); + } + public File getFile() { return file; } diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java index ce320d9d7..9ab05781b 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java +++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/ManagedBuffer.java @@ -71,4 +71,15 @@ public abstract class ManagedBuffer { * the caller will be responsible for releasing this new reference. */ public abstract Object convertToNetty() throws IOException; + + /** + * Convert the buffer into a Netty object, used to write the data out with SSL encryption, which + * cannot use {@link io.netty.channel.FileRegion}. The return value is either a {@link + * io.netty.buffer.ByteBuf}, a {@link io.netty.handler.stream.ChunkedStream}, or a {@link + * java.io.InputStream}. + * + *

If this method returns a ByteBuf, then that buffer's reference count will be incremented and + * the caller will be responsible for releasing this new reference. + */ + public abstract Object convertToNettyForSsl() throws IOException; } diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java index 60cf8625b..0528c8c74 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java +++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/NettyManagedBuffer.java @@ -76,6 +76,11 @@ public class NettyManagedBuffer extends ManagedBuffer { return buf.duplicate().retain(); } + @Override + public Object convertToNettyForSsl() throws IOException { + return buf.duplicate().retain(); + } + @Override public String toString() { return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) diff --git a/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java b/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java index b14cb1f8e..97c31ef2c 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java +++ b/common/src/main/java/org/apache/celeborn/common/network/buffer/NioManagedBuffer.java @@ -64,6 +64,11 @@ public class NioManagedBuffer extends ManagedBuffer { return Unpooled.wrappedBuffer(buf); } + @Override + public Object convertToNettyForSsl() throws IOException { + return Unpooled.wrappedBuffer(buf); + } + @Override public String toString() { return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java new file mode 100644 index 000000000..df2ab1a92 --- /dev/null +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeader.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.network.protocol; + +import java.io.EOFException; +import java.io.InputStream; + +import javax.annotation.Nullable; + +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.stream.ChunkedInput; +import io.netty.handler.stream.ChunkedStream; + +import org.apache.celeborn.common.network.buffer.ManagedBuffer; + +/** + * A wrapper message that holds two separate pieces (a header and a body). + * + *

The header must be a ByteBuf, while the body can be any InputStream or ChunkedStream Based on + * common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeader + */ +public class EncryptedMessageWithHeader implements ChunkedInput { + + @Nullable private final ManagedBuffer managedBuffer; + private final ByteBuf header; + private final int headerLength; + private final Object body; + private final long bodyLength; + private long totalBytesTransferred; + + /** + * Construct a new EncryptedMessageWithHeader. + * + * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to + * be passed in so that the buffer can be freed when this message is deallocated. Ownership of + * the caller's reference to this buffer is transferred to this class, so if the caller wants + * to continue to use the ManagedBuffer in other messages then they will need to call retain() + * on it before passing it to this constructor. + * @param header the message header. + * @param body the message body. + * @param bodyLength the length of the message body, in bytes. + */ + public EncryptedMessageWithHeader( + @Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long bodyLength) { + Preconditions.checkArgument( + body instanceof InputStream || body instanceof ChunkedStream, + "Body must be an InputStream or a ChunkedStream."); + this.managedBuffer = managedBuffer; + this.header = header; + this.headerLength = header.readableBytes(); + this.body = body; + this.bodyLength = bodyLength; + this.totalBytesTransferred = 0; + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + if (isEndOfInput()) { + return null; + } + + if (totalBytesTransferred < headerLength) { + totalBytesTransferred += headerLength; + return header.retain(); + } else if (body instanceof InputStream) { + InputStream stream = (InputStream) body; + int available = stream.available(); + if (available <= 0) { + available = (int) (length() - totalBytesTransferred); + } else { + available = (int) Math.min(available, length() - totalBytesTransferred); + } + ByteBuf buffer = allocator.buffer(available); + int toRead = Math.min(available, buffer.writableBytes()); + int read = buffer.writeBytes(stream, toRead); + if (read >= 0) { + totalBytesTransferred += read; + return buffer; + } else { + throw new EOFException("Unable to read bytes from InputStream"); + } + } else if (body instanceof ChunkedStream) { + ChunkedStream stream = (ChunkedStream) body; + long old = stream.transferredBytes(); + ByteBuf buffer = stream.readChunk(allocator); + long read = stream.transferredBytes() - old; + if (read >= 0) { + totalBytesTransferred += read; + assert (totalBytesTransferred <= length()); + return buffer; + } else { + throw new EOFException("Unable to read bytes from ChunkedStream"); + } + } else { + return null; + } + } + + @Override + public long length() { + return headerLength + bodyLength; + } + + @Override + public long progress() { + return totalBytesTransferred; + } + + @Override + public boolean isEndOfInput() throws Exception { + return (headerLength + bodyLength) == totalBytesTransferred; + } + + @Override + public void close() throws Exception { + header.release(); + if (managedBuffer != null) { + managedBuffer.release(); + } + if (body instanceof InputStream) { + ((InputStream) body).close(); + } else if (body instanceof ChunkedStream) { + ((ChunkedStream) body).close(); + } + } +} diff --git a/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java b/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java new file mode 100644 index 000000000..508b6a13d --- /dev/null +++ b/common/src/main/java/org/apache/celeborn/common/network/protocol/SslMessageEncoder.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.network.protocol; + +import java.io.InputStream; +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; +import io.netty.handler.stream.ChunkedStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Encoder used by the server side to encode secure (SSL) server-to-client responses. This encoder + * is stateless so it is safe to be shared by multiple threads. Based on + * common/network-common/org.apache.spark.network.protocol.SslMessageEncoder + */ +@ChannelHandler.Sharable +public final class SslMessageEncoder extends MessageToMessageEncoder { + + private static final Logger logger = LoggerFactory.getLogger(SslMessageEncoder.class); + public static final SslMessageEncoder INSTANCE = new SslMessageEncoder(); + + private SslMessageEncoder() {} + + /** + * Encodes a Message by invoking its encode() method. For non-data messages, we will add one + * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. + * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the + * data to 'out'. + */ + @Override + public void encode(ChannelHandlerContext ctx, Message in, List out) throws Exception { + Object body = null; + int bodyLength = 0; + + // If the message has a body, take it out... + // For SSL, zero-copy transfer will not work, so we will check if + // the body is an InputStream, and if so, use an EncryptedMessageWithHeader + // to wrap the header+body appropriately (for thread safety). + if (in.body() != null) { + try { + bodyLength = (int) in.body().size(); + body = in.body().convertToNettyForSsl(); + } catch (Exception e) { + in.body().release(); + if (in instanceof ResponseMessage) { + ResponseMessage resp = (ResponseMessage) in; + // Re-encode this message as a failure response. + String error = e.getMessage() != null ? e.getMessage() : "null"; + logger.error( + String.format("Error processing %s for client %s", in, ctx.channel().remoteAddress()), + e); + encode(ctx, resp.createFailureResponse(error), out); + } else { + throw e; + } + return; + } + } + + Message.Type msgType = in.type(); + // message size, message type size, body size, message encoded length + int headerLength = 4 + msgType.encodedLength() + 4 + in.encodedLength(); + ByteBuf header = ctx.alloc().heapBuffer(headerLength); + header.writeInt(in.encodedLength()); + msgType.encode(header); + header.writeInt(bodyLength); + in.encode(header); + assert header.writableBytes() == 0; + + if (body != null && bodyLength > 0) { + if (body instanceof ByteBuf) { + out.add(Unpooled.wrappedBuffer(header, (ByteBuf) body)); + } else if (body instanceof InputStream || body instanceof ChunkedStream) { + // For now, assume the InputStream is doing proper chunking. + out.add(new EncryptedMessageWithHeader(in.body(), header, body, bodyLength)); + } else { + throw new IllegalArgumentException( + "Body must be a ByteBuf, ChunkedStream or an InputStream"); + } + } else { + out.add(header); + } + } +} diff --git a/common/src/test/java/org/apache/celeborn/common/network/TestManagedBuffer.java b/common/src/test/java/org/apache/celeborn/common/network/TestManagedBuffer.java index b5f196fe2..ad3cc4521 100644 --- a/common/src/test/java/org/apache/celeborn/common/network/TestManagedBuffer.java +++ b/common/src/test/java/org/apache/celeborn/common/network/TestManagedBuffer.java @@ -79,6 +79,11 @@ public class TestManagedBuffer extends ManagedBuffer { return underlying.convertToNetty(); } + @Override + public Object convertToNettyForSsl() throws IOException { + return underlying.convertToNettyForSsl(); + } + @Override public int hashCode() { return underlying.hashCode(); diff --git a/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java new file mode 100644 index 000000000..0fbf7e9c9 --- /dev/null +++ b/common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common.network.protocol; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.util.Random; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.handler.stream.ChunkedStream; +import org.junit.Test; + +import org.apache.celeborn.common.network.buffer.ManagedBuffer; +import org.apache.celeborn.common.network.buffer.NettyManagedBuffer; + +/* + * Based on common/network-common/org.apache.spark.network.protocol.EncryptedMessageWithHeaderSuite + */ +public class EncryptedMessageWithHeaderSuiteJ { + + // Tests the case where the body is an input stream and that we manage the refcounts of the + // buffer properly + @Test + public void testInputStreamBodyFromManagedBuffer() throws Exception { + byte[] randomData = new byte[128]; + new Random().nextBytes(randomData); + ByteBuf sourceBuffer = Unpooled.copiedBuffer(randomData); + InputStream body = new ByteArrayInputStream(sourceBuffer.array()); + ByteBuf header = Unpooled.copyLong(42); + + long expectedHeaderValue = header.getLong(header.readerIndex()); + assertEquals(1, header.refCnt()); + assertEquals(1, sourceBuffer.refCnt()); + ManagedBuffer managedBuf = new NettyManagedBuffer(sourceBuffer); + + EncryptedMessageWithHeader msg = + new EncryptedMessageWithHeader(managedBuf, header, body, managedBuf.size()); + ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + + // First read should just read the header + ByteBuf headerResult = msg.readChunk(allocator); + assertEquals(header.capacity(), headerResult.readableBytes()); + assertEquals(expectedHeaderValue, headerResult.readLong()); + assertEquals(header.capacity(), msg.progress()); + assertFalse(msg.isEndOfInput()); + + // Second read should read the body + ByteBuf bodyResult = msg.readChunk(allocator); + assertEquals(randomData.length + header.capacity(), msg.progress()); + assertTrue(msg.isEndOfInput()); + + // Validate we read it all + assertEquals(bodyResult.readableBytes(), randomData.length); + for (int i = 0; i < randomData.length; i++) { + assertEquals(bodyResult.readByte(), randomData[i]); + } + + // Closing the message should release the source buffer + msg.close(); + assertEquals(0, sourceBuffer.refCnt()); + + // The header still has a reference we got + assertEquals(1, header.refCnt()); + headerResult.release(); + assertEquals(0, header.refCnt()); + } + + // Tests the case where the body is a chunked stream and that we are fine when there is no + // input managed buffer + @Test + public void testChunkedStream() throws Exception { + int bodyLength = 129; + int chunkSize = 64; + byte[] randomData = new byte[bodyLength]; + new Random().nextBytes(randomData); + InputStream inputStream = new ByteArrayInputStream(randomData); + ChunkedStream body = new ChunkedStream(inputStream, chunkSize); + ByteBuf header = Unpooled.copyLong(42); + + long expectedHeaderValue = header.getLong(header.readerIndex()); + assertEquals(1, header.refCnt()); + + EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(null, header, body, bodyLength); + ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + + // First read should just read the header + ByteBuf headerResult = msg.readChunk(allocator); + assertEquals(header.capacity(), headerResult.readableBytes()); + assertEquals(expectedHeaderValue, headerResult.readLong()); + assertEquals(header.capacity(), msg.progress()); + assertFalse(msg.isEndOfInput()); + + // Next 2 reads should read full buffers + int readIndex = 0; + for (int i = 1; i <= 2; i++) { + ByteBuf bodyResult = msg.readChunk(allocator); + assertEquals(header.capacity() + (i * chunkSize), msg.progress()); + assertFalse(msg.isEndOfInput()); + + // Validate we read data correctly + assertEquals(bodyResult.readableBytes(), chunkSize); + assert (bodyResult.readableBytes() < (randomData.length - readIndex)); + while (bodyResult.readableBytes() > 0) { + assertEquals(bodyResult.readByte(), randomData[readIndex++]); + } + } + + // Last read should be partial + ByteBuf bodyResult = msg.readChunk(allocator); + assertEquals(header.capacity() + bodyLength, msg.progress()); + assertTrue(msg.isEndOfInput()); + + // Validate we read the byte properly + assertEquals(bodyResult.readableBytes(), 1); + assertEquals(bodyResult.readByte(), randomData[readIndex]); + + // Closing the message should close the input stream + msg.close(); + assertTrue(body.isEndOfInput()); + + // The header still has a reference we got + assertEquals(1, header.refCnt()); + headerResult.release(); + assertEquals(0, header.refCnt()); + } + + @Test + public void testByteBufIsNotSupported() throws Exception { + // Validate that ByteBufs are not supported. This test can be updated + // when we add support for them + ByteBuf header = Unpooled.copyLong(42); + assertThrows( + IllegalArgumentException.class, + () -> { + EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(null, header, header, 4); + }); + } +}