[CELEBORN-2097] Support Zstd Compression in CppClient

### What changes were proposed in this pull request?
This PR adds support for zstd compression in CppClient.

### Why are the changes needed?
To support writing to Celeborn with CppClient.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By compilation and UTs.

Closes #3454 from Jraaay/feat/cpp_client_zstd_compression.

Authored-by: Jray <1075860716@qq.com>
Signed-off-by: SteNicholas <programgeek@163.com>
This commit is contained in:
Jray 2025-08-29 18:58:22 +08:00 committed by SteNicholas
parent 185890381b
commit ffdaef98c3
8 changed files with 220 additions and 4 deletions

View File

@ -21,7 +21,8 @@ add_library(
compress/Lz4Decompressor.cpp
compress/ZstdDecompressor.cpp
compress/Compressor.cpp
compress/Lz4Compressor.cpp)
compress/Lz4Compressor.cpp
compress/ZstdCompressor.cpp)
target_include_directories(client PUBLIC ${CMAKE_BINARY_DIR})

View File

@ -18,6 +18,7 @@
#include <stdexcept>
#include "celeborn/client/compress/Lz4Compressor.h"
#include "celeborn/client/compress/ZstdCompressor.h"
#include "celeborn/utils/Exceptions.h"
namespace celeborn {
@ -31,8 +32,8 @@ std::unique_ptr<Compressor> Compressor::createCompressor(
case protocol::CompressionCodec::LZ4:
return std::make_unique<Lz4Compressor>();
case protocol::CompressionCodec::ZSTD:
// TODO: impl zstd
CELEBORN_FAIL("Compression codec ZSTD is not supported.");
return std::make_unique<ZstdCompressor>(
conf.shuffleCompressionZstdCompressLevel());
default:
CELEBORN_FAIL("Unknown compression codec.");
}

View File

@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <zlib.h>
#include <zstd.h>
#include "celeborn/client/compress/ZstdCompressor.h"
namespace celeborn {
namespace client {
namespace compress {
ZstdCompressor::ZstdCompressor(const int compressionLevel)
: compressionLevel_(compressionLevel) {}
size_t ZstdCompressor::compress(
const uint8_t* src,
const int srcOffset,
const int srcLength,
uint8_t* dst,
const int dstOffset) {
const auto srcPtr = src + srcOffset;
const auto dstPtr = dst + dstOffset;
const auto dstDataPtr = dstPtr + kHeaderLength;
uLong check = crc32(0L, Z_NULL, 0);
check = crc32(check, srcPtr, srcLength);
std::copy_n(kMagic, kMagicLength, dstPtr);
size_t compressedLength = ZSTD_compress(
dstDataPtr,
ZSTD_compressBound(srcLength),
srcPtr,
srcLength,
compressionLevel_);
int compressionMethod;
if (ZSTD_isError(compressedLength) ||
compressedLength >= static_cast<size_t>(srcLength)) {
compressionMethod = kCompressionMethodRaw;
compressedLength = srcLength;
std::copy_n(srcPtr, srcLength, dstDataPtr);
} else {
compressionMethod = kCompressionMethodZstd;
}
dstPtr[kMagicLength] = static_cast<uint8_t>(compressionMethod);
writeIntLE(compressedLength, dstPtr, kMagicLength + 1);
writeIntLE(srcLength, dstPtr, kMagicLength + 5);
writeIntLE(static_cast<int>(check), dstPtr, kMagicLength + 9);
return kHeaderLength + compressedLength;
}
size_t ZstdCompressor::getDstCapacity(const int length) {
return ZSTD_compressBound(length) + kHeaderLength;
}
} // namespace compress
} // namespace client
} // namespace celeborn

View File

@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "celeborn/client/compress/Compressor.h"
#include "celeborn/client/compress/ZstdTrait.h"
namespace celeborn {
namespace client {
namespace compress {
class ZstdCompressor final : public Compressor, ZstdTrait {
public:
explicit ZstdCompressor(int compressionLevel);
~ZstdCompressor() override = default;
size_t compress(
const uint8_t* src,
int srcOffset,
int srcLength,
uint8_t* dst,
int dstOffset) override;
size_t getDstCapacity(int length) override;
ZstdCompressor(const ZstdCompressor&) = delete;
ZstdCompressor& operator=(const ZstdCompressor&) = delete;
private:
const int compressionLevel_;
};
} // namespace compress
} // namespace client
} // namespace celeborn

View File

@ -18,7 +18,8 @@ add_executable(
WorkerPartitionReaderTest.cpp
Lz4DecompressorTest.cpp
ZstdDecompressorTest.cpp
Lz4CompressorTest.cpp)
Lz4CompressorTest.cpp
ZstdCompressorTest.cpp)
add_test(NAME celeborn_client_test COMMAND celeborn_client_test)

View File

@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include "celeborn/client/compress/ZstdCompressor.h"
#include "client/compress/ZstdDecompressor.h"
using namespace celeborn;
using namespace celeborn::client;
using namespace celeborn::protocol;
TEST(ZstdCompressorTest, CompressWithZstd) {
for (int compressionLevel = -5; compressionLevel <= 22; compressionLevel++) {
compress::ZstdCompressor compressor(compressionLevel);
const std::string toCompressData =
"Helloooooooooooo Celeborn!!!!!!!!!!!!!!";
const auto maxLength = compressor.getDstCapacity(toCompressData.size());
std::vector<uint8_t> compressedData(maxLength);
compressor.compress(
reinterpret_cast<const uint8_t*>(toCompressData.data()),
0,
toCompressData.size(),
compressedData.data(),
0);
compress::ZstdDecompressor decompressor;
const auto oriLength = decompressor.getOriginalLen(compressedData.data());
std::vector<uint8_t> decompressedData(oriLength + 1);
decompressedData[oriLength] = '\0';
const bool success = decompressor.decompress(
compressedData.data(), decompressedData.data(), 0);
EXPECT_TRUE(success);
EXPECT_EQ(reinterpret_cast<char*>(decompressedData.data()), toCompressData);
}
}
TEST(ZstdCompressorTest, CompressWithRaw) {
for (int compressionLevel = -5; compressionLevel <= 22; compressionLevel++) {
compress::ZstdCompressor compressor(compressionLevel);
const std::string toCompressData = "Hello Celeborn!";
const auto maxLength = compressor.getDstCapacity(toCompressData.size());
std::vector<uint8_t> compressedData(maxLength);
compressor.compress(
reinterpret_cast<const uint8_t*>(toCompressData.data()),
0,
toCompressData.size(),
compressedData.data(),
0);
compress::ZstdDecompressor decompressor;
const auto oriLength = decompressor.getOriginalLen(compressedData.data());
std::vector<uint8_t> decompressedData(oriLength + 1);
decompressedData[oriLength] = '\0';
const bool success = decompressor.decompress(
compressedData.data(), decompressedData.data(), 0);
EXPECT_TRUE(success);
EXPECT_EQ(reinterpret_cast<char*>(decompressedData.data()), toCompressData);
}
}

View File

@ -143,6 +143,7 @@ const std::unordered_map<std::string, folly::Optional<std::string>>
STR_PROP(
kShuffleCompressionCodec,
protocol::toString(protocol::CompressionCodec::NONE)),
NUM_PROP(kShuffleCompressionZstdCompressLevel, 1),
// NUM_PROP(kNumExample, 50'000),
// BOOL_PROP(kBoolExample, false),
};
@ -210,5 +211,10 @@ protocol::CompressionCodec CelebornConf::shuffleCompressionCodec() const {
return protocol::toCompressionCodec(
optionalProperty(kShuffleCompressionCodec).value());
}
int CelebornConf::shuffleCompressionZstdCompressLevel() const {
return std::stoi(
optionalProperty(kShuffleCompressionZstdCompressLevel).value());
}
} // namespace conf
} // namespace celeborn

View File

@ -64,6 +64,9 @@ class CelebornConf : public BaseConf {
static constexpr std::string_view kShuffleCompressionCodec{
"celeborn.client.shuffle.compression.codec"};
static constexpr std::string_view kShuffleCompressionZstdCompressLevel{
"celeborn.client.shuffle.compression.zstd.level"};
CelebornConf();
CelebornConf(const std::string& filename);
@ -89,6 +92,8 @@ class CelebornConf : public BaseConf {
int clientFetchMaxReqsInFlight() const;
protocol::CompressionCodec shuffleCompressionCodec() const;
int shuffleCompressionZstdCompressLevel() const;
};
} // namespace conf
} // namespace celeborn