From ffdaef98c38b2d1494945f98dc09d7b20d3b29ac Mon Sep 17 00:00:00 2001 From: Jray <1075860716@qq.com> Date: Fri, 29 Aug 2025 18:58:22 +0800 Subject: [PATCH] [CELEBORN-2097] Support Zstd Compression in CppClient ### What changes were proposed in this pull request? This PR adds support for zstd compression in CppClient. ### Why are the changes needed? To support writing to Celeborn with CppClient. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By compilation and UTs. Closes #3454 from Jraaay/feat/cpp_client_zstd_compression. Authored-by: Jray <1075860716@qq.com> Signed-off-by: SteNicholas --- cpp/celeborn/client/CMakeLists.txt | 3 +- cpp/celeborn/client/compress/Compressor.cpp | 5 +- .../client/compress/ZstdCompressor.cpp | 76 +++++++++++++++++++ cpp/celeborn/client/compress/ZstdCompressor.h | 50 ++++++++++++ cpp/celeborn/client/tests/CMakeLists.txt | 3 +- .../client/tests/ZstdCompressorTest.cpp | 76 +++++++++++++++++++ cpp/celeborn/conf/CelebornConf.cpp | 6 ++ cpp/celeborn/conf/CelebornConf.h | 5 ++ 8 files changed, 220 insertions(+), 4 deletions(-) create mode 100644 cpp/celeborn/client/compress/ZstdCompressor.cpp create mode 100644 cpp/celeborn/client/compress/ZstdCompressor.h create mode 100644 cpp/celeborn/client/tests/ZstdCompressorTest.cpp diff --git a/cpp/celeborn/client/CMakeLists.txt b/cpp/celeborn/client/CMakeLists.txt index c5534a3a8..2586f6855 100644 --- a/cpp/celeborn/client/CMakeLists.txt +++ b/cpp/celeborn/client/CMakeLists.txt @@ -21,7 +21,8 @@ add_library( compress/Lz4Decompressor.cpp compress/ZstdDecompressor.cpp compress/Compressor.cpp - compress/Lz4Compressor.cpp) + compress/Lz4Compressor.cpp + compress/ZstdCompressor.cpp) target_include_directories(client PUBLIC ${CMAKE_BINARY_DIR}) diff --git a/cpp/celeborn/client/compress/Compressor.cpp b/cpp/celeborn/client/compress/Compressor.cpp index 849b1ff0d..cde6cad5a 100644 --- a/cpp/celeborn/client/compress/Compressor.cpp +++ 
b/cpp/celeborn/client/compress/Compressor.cpp @@ -18,6 +18,7 @@ #include #include "celeborn/client/compress/Lz4Compressor.h" +#include "celeborn/client/compress/ZstdCompressor.h" #include "celeborn/utils/Exceptions.h" namespace celeborn { @@ -31,8 +32,8 @@ std::unique_ptr<Compressor> Compressor::createCompressor( case protocol::CompressionCodec::LZ4: return std::make_unique<Lz4Compressor>(); case protocol::CompressionCodec::ZSTD: - // TODO: impl zstd - CELEBORN_FAIL("Compression codec ZSTD is not supported."); + return std::make_unique<ZstdCompressor>( + conf.shuffleCompressionZstdCompressLevel()); default: CELEBORN_FAIL("Unknown compression codec."); } diff --git a/cpp/celeborn/client/compress/ZstdCompressor.cpp b/cpp/celeborn/client/compress/ZstdCompressor.cpp new file mode 100644 index 000000000..3ce248644 --- /dev/null +++ b/cpp/celeborn/client/compress/ZstdCompressor.cpp @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <zlib.h> +#include <zstd.h> + +#include "celeborn/client/compress/ZstdCompressor.h" + +namespace celeborn { +namespace client { +namespace compress { + +ZstdCompressor::ZstdCompressor(const int compressionLevel) + : compressionLevel_(compressionLevel) {} + +size_t ZstdCompressor::compress( + const uint8_t* src, + const int srcOffset, + const int srcLength, + uint8_t* dst, + const int dstOffset) { + const auto srcPtr = src + srcOffset; + const auto dstPtr = dst + dstOffset; + const auto dstDataPtr = dstPtr + kHeaderLength; + + uLong check = crc32(0L, Z_NULL, 0); + check = crc32(check, srcPtr, srcLength); + + std::copy_n(kMagic, kMagicLength, dstPtr); + + size_t compressedLength = ZSTD_compress( + dstDataPtr, + ZSTD_compressBound(srcLength), + srcPtr, + srcLength, + compressionLevel_); + + int compressionMethod; + if (ZSTD_isError(compressedLength) || + compressedLength >= static_cast<size_t>(srcLength)) { + compressionMethod = kCompressionMethodRaw; + compressedLength = srcLength; + std::copy_n(srcPtr, srcLength, dstDataPtr); + } else { + compressionMethod = kCompressionMethodZstd; + } + + dstPtr[kMagicLength] = static_cast<uint8_t>(compressionMethod); + writeIntLE(compressedLength, dstPtr, kMagicLength + 1); + writeIntLE(srcLength, dstPtr, kMagicLength + 5); + writeIntLE(static_cast<int>(check), dstPtr, kMagicLength + 9); + + return kHeaderLength + compressedLength; +} + +size_t ZstdCompressor::getDstCapacity(const int length) { + return ZSTD_compressBound(length) + kHeaderLength; +} + +} // namespace compress +} // namespace client +} // namespace celeborn diff --git a/cpp/celeborn/client/compress/ZstdCompressor.h b/cpp/celeborn/client/compress/ZstdCompressor.h new file mode 100644 index 000000000..2fd8aa02b --- /dev/null +++ b/cpp/celeborn/client/compress/ZstdCompressor.h @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "celeborn/client/compress/Compressor.h" +#include "celeborn/client/compress/ZstdTrait.h" + +namespace celeborn { +namespace client { +namespace compress { + +class ZstdCompressor final : public Compressor, ZstdTrait { + public: + explicit ZstdCompressor(int compressionLevel); + ~ZstdCompressor() override = default; + + size_t compress( + const uint8_t* src, + int srcOffset, + int srcLength, + uint8_t* dst, + int dstOffset) override; + + size_t getDstCapacity(int length) override; + + ZstdCompressor(const ZstdCompressor&) = delete; + ZstdCompressor& operator=(const ZstdCompressor&) = delete; + + private: + const int compressionLevel_; +}; + +} // namespace compress +} // namespace client +} // namespace celeborn diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt index 9ac740474..d8a98e2b6 100644 --- a/cpp/celeborn/client/tests/CMakeLists.txt +++ b/cpp/celeborn/client/tests/CMakeLists.txt @@ -18,7 +18,8 @@ add_executable( WorkerPartitionReaderTest.cpp Lz4DecompressorTest.cpp ZstdDecompressorTest.cpp - Lz4CompressorTest.cpp) + Lz4CompressorTest.cpp + ZstdCompressorTest.cpp) add_test(NAME celeborn_client_test COMMAND celeborn_client_test) diff --git a/cpp/celeborn/client/tests/ZstdCompressorTest.cpp 
b/cpp/celeborn/client/tests/ZstdCompressorTest.cpp new file mode 100644 index 000000000..c4cf9ce7c --- /dev/null +++ b/cpp/celeborn/client/tests/ZstdCompressorTest.cpp @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <gtest/gtest.h> + +#include "celeborn/client/compress/ZstdCompressor.h" +#include "client/compress/ZstdDecompressor.h" + +using namespace celeborn; +using namespace celeborn::client; +using namespace celeborn::protocol; + +TEST(ZstdCompressorTest, CompressWithZstd) { + for (int compressionLevel = -5; compressionLevel <= 22; compressionLevel++) { + compress::ZstdCompressor compressor(compressionLevel); + const std::string toCompressData = + "Helloooooooooooo Celeborn!!!!!!!!!!!!!!"; + + const auto maxLength = compressor.getDstCapacity(toCompressData.size()); + std::vector<uint8_t> compressedData(maxLength); + compressor.compress( + reinterpret_cast<const uint8_t*>(toCompressData.data()), + 0, + toCompressData.size(), + compressedData.data(), + 0); + + compress::ZstdDecompressor decompressor; + const auto oriLength = decompressor.getOriginalLen(compressedData.data()); + std::vector<uint8_t> decompressedData(oriLength + 1); + decompressedData[oriLength] = '\0'; + const bool success = decompressor.decompress( + 
compressedData.data(), decompressedData.data(), 0); + EXPECT_TRUE(success); + EXPECT_EQ(reinterpret_cast<char*>(decompressedData.data()), toCompressData); + } +} + +TEST(ZstdCompressorTest, CompressWithRaw) { + for (int compressionLevel = -5; compressionLevel <= 22; compressionLevel++) { + compress::ZstdCompressor compressor(compressionLevel); + const std::string toCompressData = "Hello Celeborn!"; + + const auto maxLength = compressor.getDstCapacity(toCompressData.size()); + std::vector<uint8_t> compressedData(maxLength); + compressor.compress( + reinterpret_cast<const uint8_t*>(toCompressData.data()), + 0, + toCompressData.size(), + compressedData.data(), + 0); + + compress::ZstdDecompressor decompressor; + const auto oriLength = decompressor.getOriginalLen(compressedData.data()); + std::vector<uint8_t> decompressedData(oriLength + 1); + decompressedData[oriLength] = '\0'; + const bool success = decompressor.decompress( + compressedData.data(), decompressedData.data(), 0); + EXPECT_TRUE(success); + EXPECT_EQ(reinterpret_cast<char*>(decompressedData.data()), toCompressData); + } +} diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp index 4f85c19a0..e21d39032 100644 --- a/cpp/celeborn/conf/CelebornConf.cpp +++ b/cpp/celeborn/conf/CelebornConf.cpp @@ -143,6 +143,7 @@ const std::unordered_map> STR_PROP( kShuffleCompressionCodec, protocol::toString(protocol::CompressionCodec::NONE)), + NUM_PROP(kShuffleCompressionZstdCompressLevel, 1), // NUM_PROP(kNumExample, 50'000), // BOOL_PROP(kBoolExample, false), }; @@ -210,5 +211,10 @@ protocol::CompressionCodec CelebornConf::shuffleCompressionCodec() const { return protocol::toCompressionCodec( optionalProperty(kShuffleCompressionCodec).value()); } + +int CelebornConf::shuffleCompressionZstdCompressLevel() const { + return std::stoi( + optionalProperty(kShuffleCompressionZstdCompressLevel).value()); +} } // namespace conf } // namespace celeborn diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h index 
783bc96e1..5aa3c6f9e 100644 --- a/cpp/celeborn/conf/CelebornConf.h +++ b/cpp/celeborn/conf/CelebornConf.h @@ -64,6 +64,9 @@ class CelebornConf : public BaseConf { static constexpr std::string_view kShuffleCompressionCodec{ "celeborn.client.shuffle.compression.codec"}; + static constexpr std::string_view kShuffleCompressionZstdCompressLevel{ + "celeborn.client.shuffle.compression.zstd.level"}; + CelebornConf(); CelebornConf(const std::string& filename); @@ -89,6 +92,8 @@ class CelebornConf : public BaseConf { int clientFetchMaxReqsInFlight() const; protocol::CompressionCodec shuffleCompressionCodec() const; + + int shuffleCompressionZstdCompressLevel() const; }; } // namespace conf } // namespace celeborn