From cde33d953b36a9ab23b92bf2b379c71b6f3fe3ab Mon Sep 17 00:00:00 2001 From: Gaurav Mittal Date: Sat, 28 Jun 2025 09:19:57 +0800 Subject: [PATCH] [CELEBORN-894] End to End Integrity Checks ### What changes were proposed in this pull request? Design doc - https://docs.google.com/document/d/1YqK0kua-5rMufJw57kEIrHHGbLnAF9iXM5GdDweMzzg/edit?tab=t.0#heading=h.n5ldma432qnd - End to End integrity checks provide additional confidence that Celeborn is producing complete as well as correct data - The checks are hidden behind a client-side config that is false by default. This gives users the option to enable the checks if required on a per-app basis - Only compatible with Spark at the moment - No support for Flink (can be considered in future) - No support for Columnar Shuffle (can be considered in future) Writer - Whenever a mapper completes, it reports crc32 and bytes written on a per-partition basis to the driver Driver - Driver aggregates the mapper reports - and computes aggregated CRC32 and bytes written on a per-partitionID basis Reader - Each CelebornInputStream will report (int shuffleId, int partitionId, int startMapIndex, int endMapIndex, int crc32, long bytes) to the driver when it has finished reading all data on the stream - On every report - Driver will aggregate the CRC32 and bytesRead for the partitionID - Driver will aggregate the mapRange to determine when all sub-partitions for the partitionID have been read - It will then compare the aggregated CRC32 and bytes read with the expected CRC32 and bytes written for the partition - There is special handling for the skew-handling-without-map-range-split scenario as well - In this case, we report the number of sub-partitions and the index of the sub-partition instead of startMapIndex and endMapIndex There is separate handling for skew handling with and without map range split As a follow-up, I will do another PR that will harden the checks and add bookkeeping to ensure that every CelebornInputStream makes the 
required checks ### Why are the changes needed? https://issues.apache.org/jira/browse/CELEBORN-894 Note: I am putting up this PR even though some tests are failing, since I want to get some early feedback on the code changes. ### Does this PR introduce _any_ user-facing change? Not sure how to answer this. A new client side config is available to enable the checks if required ### How was this patch tested? Unit tests + Integration tests Closes #3261 from gauravkm/gaurav/e2e_checks_v3. Lead-authored-by: Gaurav Mittal Co-authored-by: Gaurav Mittal Co-authored-by: Fei Wang Signed-off-by: Shuang --- .../plugin/flink/RemoteShuffleOutputGate.java | 2 +- .../flink/RemoteShuffleOutputGateSuiteJ.java | 2 +- .../tiered/CelebornTierProducerAgent.java | 7 +- .../tiered/CelebornTierProducerAgent.java | 7 +- .../mapred/CelebornSortBasedPusher.java | 2 +- .../celeborn/HashBasedShuffleWriter.java | 2 +- .../celeborn/SortBasedShuffleWriter.java | 2 +- .../celeborn/HashBasedShuffleWriter.java | 2 +- .../celeborn/SortBasedShuffleWriter.java | 2 +- .../celeborn/client/CelebornTezWriter.java | 2 +- .../celeborn/client/DummyShuffleClient.java | 10 +- .../apache/celeborn/client/ShuffleClient.java | 14 +- .../celeborn/client/ShuffleClientImpl.java | 70 +++- .../client/read/CelebornInputStream.java | 67 +++- .../celeborn/client/CommitManager.scala | 33 +- .../celeborn/client/LifecycleManager.scala | 81 ++++- .../client/commit/CommitHandler.scala | 20 +- ...LegacySkewHandlingPartitionValidator.scala | 189 +++++++++++ .../commit/MapPartitionCommitHandler.scala | 25 +- .../PartitionCompletenessValidator.scala | 86 +++++ .../commit/ReducePartitionCommitHandler.scala | 99 +++++- ...SkewHandlingWithoutMapRangeValidator.scala | 109 +++++++ .../celeborn/client/ShuffleClientSuiteJ.java | 92 +++++- .../client/WithShuffleClientSuite.scala | 12 +- .../commit/PartitionValidatorTest.scala | 301 ++++++++++++++++++ ...HandlingWithoutMapRangeValidatorTest.scala | 251 +++++++++++++++ 
.../apache/celeborn/common/CelebornCRC32.java | 91 ++++++ .../celeborn/common/CommitMetadata.java | 82 +++++ .../common/protocol/message/StatusCode.java | 3 +- .../celeborn/common/write/PushState.java | 34 ++ common/src/main/proto/TransportMessages.proto | 19 ++ .../apache/celeborn/common/CelebornConf.scala | 10 + .../protocol/message/ControlMessages.scala | 66 +++- .../celeborn/common/CelebornCRC32Test.java | 71 +++++ .../celeborn/common/CommitMetadataTest.java | 72 +++++ .../celeborn/common/util/UtilsSuite.scala | 16 +- docs/configuration/client.md | 1 + .../LifecycleManagerCommitFilesSuite.scala | 20 +- .../LifecycleManagerReserveSlotsSuite.scala | 2 +- .../spark/CelebornIntegrityCheckSuite.scala | 158 +++++++++ ...ShuffleReaderGetHookForCorruptedData.scala | 252 +++++++++++++++ .../celeborn/tests/spark/SkewJoinSuite.scala | 122 +++---- .../cluster/JavaReadCppWriteTestBase.scala | 2 +- .../cluster/LocalReadByChunkOffsetsTest.scala | 2 +- .../cluster/PushMergedDataSplitSuite.scala | 2 +- .../deploy/cluster/ReadWriteTestBase.scala | 2 +- .../cluster/ReadWriteTestWithFailures.scala | 2 +- 47 files changed, 2375 insertions(+), 143 deletions(-) create mode 100644 client/src/main/scala/org/apache/celeborn/client/commit/LegacySkewHandlingPartitionValidator.scala create mode 100644 client/src/main/scala/org/apache/celeborn/client/commit/PartitionCompletenessValidator.scala create mode 100644 client/src/main/scala/org/apache/celeborn/client/commit/SkewHandlingWithoutMapRangeValidator.scala create mode 100644 client/src/test/scala/org/apache/celeborn/client/commit/PartitionValidatorTest.scala create mode 100644 client/src/test/scala/org/apache/celeborn/client/commit/SkewHandlingWithoutMapRangeValidatorTest.scala create mode 100644 common/src/main/java/org/apache/celeborn/common/CelebornCRC32.java create mode 100644 common/src/main/java/org/apache/celeborn/common/CommitMetadata.java create mode 100644 common/src/test/java/org/apache/celeborn/common/CelebornCRC32Test.java 
create mode 100644 common/src/test/java/org/apache/celeborn/common/CommitMetadataTest.java create mode 100644 tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornIntegrityCheckSuite.scala create mode 100644 tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleReaderGetHookForCorruptedData.scala diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java index 2630617a1..43a58ebc1 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java @@ -197,7 +197,7 @@ public class RemoteShuffleOutputGate { /** Indicates the writing/spilling is finished. */ public void finish() throws InterruptedException, IOException { flinkShuffleClient.mapPartitionMapperEnd( - shuffleId, mapId, attemptId, numMappers, partitionLocation.getId()); + shuffleId, mapId, attemptId, numMappers, numSubs, partitionLocation.getId()); } /** Close the transportation gate. 
*/ diff --git a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java index 3210125e1..56a2c5457 100644 --- a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java +++ b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java @@ -107,7 +107,7 @@ public class RemoteShuffleOutputGateSuiteJ { doNothing() .when(remoteShuffleOutputGate.flinkShuffleClient) - .mapperEnd(anyInt(), anyInt(), anyInt(), anyInt()); + .mapperEnd(anyInt(), anyInt(), anyInt(), anyInt(), anyInt()); remoteShuffleOutputGate.finish(); doNothing().when(remoteShuffleOutputGate.flinkShuffleClient).shutdown(); diff --git a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java index 8ddb40f18..3b10d93f5 100644 --- a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java +++ b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java @@ -234,7 +234,12 @@ public class CelebornTierProducerAgent implements TierProducerAgent { try { if (hasRegisteredShuffle && partitionLocation != null) { flinkShuffleClient.mapPartitionMapperEnd( - shuffleId, mapId, attemptId, numPartitions, partitionLocation.getId()); + shuffleId, + mapId, + attemptId, + numPartitions, + numSubPartitions, + partitionLocation.getId()); } } catch (Exception e) { Utils.rethrowAsRuntimeException(e); diff --git a/client-flink/flink-2.0/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java b/client-flink/flink-2.0/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java index 
8ddb40f18..3b10d93f5 100644 --- a/client-flink/flink-2.0/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java +++ b/client-flink/flink-2.0/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java @@ -234,7 +234,12 @@ public class CelebornTierProducerAgent implements TierProducerAgent { try { if (hasRegisteredShuffle && partitionLocation != null) { flinkShuffleClient.mapPartitionMapperEnd( - shuffleId, mapId, attemptId, numPartitions, partitionLocation.getId()); + shuffleId, + mapId, + attemptId, + numPartitions, + numSubPartitions, + partitionLocation.getId()); } } catch (Exception e) { Utils.rethrowAsRuntimeException(e); diff --git a/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java b/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java index 69868197e..25354bf62 100644 --- a/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java +++ b/client-mr/mr/src/main/java/org/apache/hadoop/mapred/CelebornSortBasedPusher.java @@ -316,7 +316,7 @@ public class CelebornSortBasedPusher extends OutputStream { mapId, attempt, numMappers); - shuffleClient.mapperEnd(0, mapId, attempt, numMappers); + shuffleClient.mapperEnd(0, mapId, attempt, numMappers, numReducers); } catch (IOException e) { exception.compareAndSet(null, e); } diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index 456da7e9a..7121ccdf9 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -369,7 +369,7 @@ public class HashBasedShuffleWriter extends ShuffleWriter { sendOffsets = null; long waitStartTime = System.nanoTime(); - 
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers); + shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions); writeMetrics.incWriteTime(System.nanoTime() - waitStartTime); BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId(); diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index 2d65b6859..9ba908ade 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -316,7 +316,7 @@ public class SortBasedShuffleWriter extends ShuffleWriter { updateMapStatus(); long waitStartTime = System.nanoTime(); - shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers); + shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions); writeMetrics.incWriteTime(System.nanoTime() - waitStartTime); } diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index c423a97ce..d5a0fdf22 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -378,7 +378,7 @@ public class HashBasedShuffleWriter extends ShuffleWriter { updateRecordsWrittenMetrics(); long waitStartTime = System.nanoTime(); - shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers); + shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions); writeMetrics.incWriteTime(System.nanoTime() - waitStartTime); BlockManagerId bmId = 
SparkEnv.get().blockManager().shuffleServerId(); diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index 3346deb2a..b4ff4d158 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -381,7 +381,7 @@ public class SortBasedShuffleWriter extends ShuffleWriter { writeMetrics.incRecordsWritten(tmpRecordsWritten); long waitStartTime = System.nanoTime(); - shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers); + shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions); writeMetrics.incWriteTime(System.nanoTime() - waitStartTime); } diff --git a/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezWriter.java b/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezWriter.java index e80de0504..5bafebc90 100644 --- a/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezWriter.java +++ b/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezWriter.java @@ -120,7 +120,7 @@ public class CelebornTezWriter { try { dataPusher.waitOnTermination(); shuffleClient.pushMergedData(shuffleId, mapId, attemptNumber); - shuffleClient.mapperEnd(shuffleId, mapId, attemptNumber, numMappers); + shuffleClient.mapperEnd(shuffleId, mapId, attemptNumber, numMappers, numPartitions); } catch (InterruptedException e) { throw new IOInterruptedException(e); } diff --git a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java index c79cb6eef..69cc3cd6f 100644 --- a/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java +++ 
b/client/src/main/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -115,11 +115,17 @@ public class DummyShuffleClient extends ShuffleClient { public void pushMergedData(int shuffleId, int mapId, int attemptId) {} @Override - public void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers) {} + public void mapperEnd( + int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions) {} + + @Override + public void readReducerPartitionEnd( + int shuffleId, int partitionId, int startMapIndex, int endMapIndex, int crc32, long bytes) + throws IOException {} @Override public void mapPartitionMapperEnd( - int shuffleId, int mapId, int attemptId, int numMappers, int partitionId) + int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions, int partitionId) throws IOException {} @Override diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 3131820c0..d37ec6442 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -203,13 +203,19 @@ public abstract class ShuffleClient { public abstract void pushMergedData(int shuffleId, int mapId, int attemptId) throws IOException; - // Report partition locations written by the completed map task of ReducePartition Shuffle Type - public abstract void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers) + // Report partition locations written by the completed map task of ReducePartition Shuffle Type. 
+ public abstract void mapperEnd( + int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions) throws IOException; - // Report partition locations written by the completed map task of MapPartition Shuffle Type + public abstract void readReducerPartitionEnd( + int shuffleId, int partitionId, int startMapIndex, int endMapIndex, int crc32, long bytes) + throws IOException; + + // Report partition locations written by the completed map task of MapPartition Shuffle Type. public abstract void mapPartitionMapperEnd( - int shuffleId, int mapId, int attemptId, int numMappers, int partitionId) throws IOException; + int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions, int partitionId) + throws IOException; // Cleanup states of the map task public abstract void cleanup(int shuffleId, int mapId, int attemptId); diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index d9a1e67d6..a8685e9b4 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -50,10 +50,7 @@ import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.metrics.source.Role; import org.apache.celeborn.common.network.TransportContext; import org.apache.celeborn.common.network.buffer.NettyManagedBuffer; -import org.apache.celeborn.common.network.client.RpcResponseCallback; -import org.apache.celeborn.common.network.client.TransportClient; -import org.apache.celeborn.common.network.client.TransportClientBootstrap; -import org.apache.celeborn.common.network.client.TransportClientFactory; +import org.apache.celeborn.common.network.client.*; import org.apache.celeborn.common.network.protocol.*; import org.apache.celeborn.common.network.protocol.SerdeVersion; import org.apache.celeborn.common.network.sasl.SaslClientBootstrap; @@ -122,6 
+119,8 @@ public class ShuffleClientImpl extends ShuffleClient { private final boolean pushExcludeWorkerOnFailureEnabled; private final boolean shuffleCompressionEnabled; + private final boolean shuffleIntegrityCheckEnabled; + private final Set pushExcludedWorkers = ConcurrentHashMap.newKeySet(); private final ConcurrentHashMap fetchExcludedWorkers = JavaUtils.newConcurrentHashMap(); @@ -203,6 +202,7 @@ public class ShuffleClientImpl extends ShuffleClient { shuffleCompressionEnabled = !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE); pushReplicateEnabled = conf.clientPushReplicateEnabled(); fetchExcludeWorkerOnFailureEnabled = conf.clientFetchExcludeWorkerOnFailureEnabled(); + shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled(); if (conf.clientPushReplicateEnabled()) { pushDataTimeout = conf.pushDataTimeoutMs() * 2; } else { @@ -648,6 +648,35 @@ public class ShuffleClientImpl extends ShuffleClient { }); } + @Override + public void readReducerPartitionEnd( + int shuffleId, + int partitionId, + int startMapIndex, + int endMapIndex, + int crc32, + long bytesWritten) + throws IOException { + PbReadReducerPartitionEnd pbReadReducerPartitionEnd = + PbReadReducerPartitionEnd.newBuilder() + .setShuffleId(shuffleId) + .setPartitionId(partitionId) + .setStartMaxIndex(startMapIndex) + .setEndMapIndex(endMapIndex) + .setCrc32(crc32) + .setBytesWritten(bytesWritten) + .build(); + + PbReadReducerPartitionEndResponse pbReducerPartitionEndResponse = + lifecycleManagerRef.askSync( + pbReadReducerPartitionEnd, + conf.clientRpcRegisterShuffleAskTimeout(), + ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class)); + if (pbReducerPartitionEndResponse.getStatus() != StatusCode.SUCCESS.getValue()) { + throw new CelebornIOException(pbReducerPartitionEndResponse.getErrorMsg()); + } + } + @Override public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) { PbReportShuffleFetchFailure pbReportShuffleFetchFailure = @@ 
-1012,6 +1041,12 @@ public class ShuffleClientImpl extends ShuffleClient { // increment batchId final int nextBatchId = pushState.nextBatchId(); + // Track commit metadata if shuffle compression and integrity check are enabled and this request + // is not for pushing metadata itself. + if (shuffleIntegrityCheckEnabled) { + pushState.addDataWithOffsetAndLength(partitionId, data, offset, length); + } + if (shuffleCompressionEnabled && !skipCompress) { // compress data final Compressor compressor = compressorThreadLocal.get(); @@ -1725,19 +1760,25 @@ public class ShuffleClientImpl extends ShuffleClient { } @Override - public void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers) + public void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions) throws IOException { - mapEndInternal(shuffleId, mapId, attemptId, numMappers, -1); + mapEndInternal(shuffleId, mapId, attemptId, numMappers, numPartitions, -1); } @Override public void mapPartitionMapperEnd( - int shuffleId, int mapId, int attemptId, int numMappers, int partitionId) throws IOException { - mapEndInternal(shuffleId, mapId, attemptId, numMappers, partitionId); + int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions, int partitionId) + throws IOException { + mapEndInternal(shuffleId, mapId, attemptId, numMappers, numPartitions, partitionId); } private void mapEndInternal( - int shuffleId, int mapId, int attemptId, int numMappers, Integer partitionId) + int shuffleId, + int mapId, + int attemptId, + int numMappers, + int numPartitions, + Integer partitionId) throws IOException { final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); PushState pushState = getPushState(mapKey); @@ -1745,6 +1786,12 @@ public class ShuffleClientImpl extends ShuffleClient { try { limitZeroInFlight(mapKey, pushState); + // send CRC32 and num bytes per partition if e2e checks are enabled + int[] crc32PerPartition = + 
pushState.getCRC32PerPartition(shuffleIntegrityCheckEnabled, numPartitions); + long[] bytesPerPartition = + pushState.getBytesWrittenPerPartition(shuffleIntegrityCheckEnabled, numPartitions); + MapperEndResponse response = lifecycleManagerRef.askSync( new MapperEnd( @@ -1753,7 +1800,10 @@ public class ShuffleClientImpl extends ShuffleClient { attemptId, numMappers, partitionId, - pushState.getFailedBatches()), + pushState.getFailedBatches(), + numPartitions, + crc32PerPartition, + bytesPerPartition), rpcMaxRetries, rpcRetryWait, ClassTag$.MODULE$.apply(MapperEndResponse.class)); diff --git a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java index ad91fc381..fb5b626f8 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java @@ -40,6 +40,7 @@ import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.client.compress.Decompressor; import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.CommitMetadata; import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.network.client.TransportClient; import org.apache.celeborn.common.network.client.TransportClientFactory; @@ -103,7 +104,9 @@ public abstract class CelebornInputStream extends InputStream { exceptionMaker, true, metricsCallback, - needDecompress); + needDecompress, + startMapIndex, + endMapIndex); } else { return new CelebornInputStreamImpl( conf, @@ -126,7 +129,9 @@ public abstract class CelebornInputStream extends InputStream { exceptionMaker, false, metricsCallback, - needDecompress); + needDecompress, + -1, + -1); } } } @@ -168,6 +173,8 @@ public abstract class CelebornInputStream extends InputStream { private final 
CelebornConf conf; private final TransportClientFactory clientFactory; private final String shuffleKey; + private final int numberOfSubPartitions; + private final int currentIndexOfSubPartition; private ArrayList locations; private ArrayList streamHandlers; private int[] attempts; @@ -206,6 +213,7 @@ public abstract class CelebornInputStream extends InputStream { private final String localHostAddress; private boolean shouldDecompress; + private boolean shuffleIntegrityCheckEnabled; private long fetchExcludedWorkerExpireTimeout; private ConcurrentHashMap fetchExcludedWorkers; @@ -216,6 +224,8 @@ public abstract class CelebornInputStream extends InputStream { private int partitionId; private ExceptionMaker exceptionMaker; private boolean closed = false; + private boolean integrityChecked = false; + private final CommitMetadata aggregatedActualCommitMetadata = new CommitMetadata(); private final boolean readSkewPartitionWithoutMapRange; @@ -238,7 +248,9 @@ public abstract class CelebornInputStream extends InputStream { ExceptionMaker exceptionMaker, boolean splitSkewPartitionWithoutMapRange, MetricsCallback metricsCallback, - boolean needDecompress) + boolean needDecompress, + int numberOfSubPartitions, + int currentIndexOfSubPartition) throws IOException { this( conf, @@ -261,7 +273,9 @@ public abstract class CelebornInputStream extends InputStream { exceptionMaker, splitSkewPartitionWithoutMapRange, metricsCallback, - needDecompress); + needDecompress, + numberOfSubPartitions, + currentIndexOfSubPartition); } CelebornInputStreamImpl( @@ -285,7 +299,9 @@ public abstract class CelebornInputStream extends InputStream { ExceptionMaker exceptionMaker, boolean readSkewPartitionWithoutMapRange, MetricsCallback metricsCallback, - boolean needDecompress) + boolean needDecompress, + int numberOfSubPartitions, + int currentIndexOfSubPartition) throws IOException { this.conf = conf; this.clientFactory = clientFactory; @@ -305,9 +321,12 @@ public abstract class 
CelebornInputStream extends InputStream { this.localHostAddress = Utils.localHostName(conf); this.shouldDecompress = !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE) && needDecompress; + this.shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled(); this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout(); this.failedBatches = failedBatchSet; this.readSkewPartitionWithoutMapRange = readSkewPartitionWithoutMapRange; + this.numberOfSubPartitions = numberOfSubPartitions; + this.currentIndexOfSubPartition = currentIndexOfSubPartition; this.fetchExcludedWorkers = fetchExcludedWorkers; if (conf.clientPushReplicateEnabled()) { @@ -721,6 +740,36 @@ public abstract class CelebornInputStream extends InputStream { } } + void validateIntegrity() throws IOException { + if (integrityChecked || !shuffleIntegrityCheckEnabled) { + return; + } + + if (readSkewPartitionWithoutMapRange) { + shuffleClient.readReducerPartitionEnd( + shuffleId, + partitionId, + numberOfSubPartitions, + currentIndexOfSubPartition, + aggregatedActualCommitMetadata.getChecksum(), + aggregatedActualCommitMetadata.getBytes()); + } else { + shuffleClient.readReducerPartitionEnd( + shuffleId, + partitionId, + startMapIndex, + endMapIndex, + aggregatedActualCommitMetadata.getChecksum(), + aggregatedActualCommitMetadata.getBytes()); + } + logger.info( + "reducerPartitionEnd successful for shuffleId{}, partitionId{}. 
actual CommitMetadata: {}", + shuffleId, + partitionId, + aggregatedActualCommitMetadata); + integrityChecked = true; + } + private boolean moveToNextChunk() throws IOException { if (currentChunk != null) { currentChunk.release(); @@ -760,6 +809,7 @@ public abstract class CelebornInputStream extends InputStream { firstChunk = false; } if (currentChunk == null) { + validateIntegrity(); close(); return false; } @@ -821,6 +871,9 @@ public abstract class CelebornInputStream extends InputStream { } else { limit = size; } + if (shuffleIntegrityCheckEnabled) { + aggregatedActualCommitMetadata.addDataWithOffsetAndLength(rawDataBuf, 0, limit); + } position = 0; hasData = true; break; @@ -835,6 +888,10 @@ public abstract class CelebornInputStream extends InputStream { } } + if (!hasData) { + validateIntegrity(); + // TODO(gaurav): consider closing the stream + } return hasData; } catch (LZ4Exception | ZstdException | IOException e) { logger.error( diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala index 8e2e43c89..1a71d3ccd 100644 --- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala @@ -31,7 +31,7 @@ import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers import org.apache.celeborn.client.commit.{CommitFilesParam, CommitHandler, MapPartitionCommitHandler, ReducePartitionCommitHandler} import org.apache.celeborn.client.listener.{WorkersStatus, WorkerStatusListener} -import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.{CelebornConf, CommitMetadata} import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.WorkerInfo import org.apache.celeborn.common.network.protocol.SerdeVersion @@ -178,7 +178,8 @@ class CommitManager(appUniqueId: 
String, val conf: CelebornConf, lifecycleManage def registerShuffle( shuffleId: Int, numMappers: Int, - isSegmentGranularityVisible: Boolean): Unit = { + isSegmentGranularityVisible: Boolean, + numPartitions: Int): Unit = { committedPartitionInfo.put( shuffleId, ShuffleCommittedInfo( @@ -198,7 +199,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage getCommitHandler(shuffleId).registerShuffle( shuffleId, numMappers, - isSegmentGranularityVisible); + isSegmentGranularityVisible, + numPartitions) } def isSegmentGranularityVisible(shuffleId: Int): Boolean = { @@ -219,8 +221,10 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage attemptId: Int, numMappers: Int, partitionId: Int = -1, - pushFailedBatches: util.Map[String, LocationPushFailedBatches] = Collections.emptyMap()) - : (Boolean, Boolean) = { + pushFailedBatches: util.Map[String, LocationPushFailedBatches] = Collections.emptyMap(), + numPartitions: Int = -1, + crc32PerPartition: Array[Int] = new Array[Int](0), + bytesWrittenPerPartition: Array[Long] = new Array[Long](0)): (Boolean, Boolean) = { getCommitHandler(shuffleId).finishMapperAttempt( shuffleId, mapId, @@ -228,7 +232,10 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage numMappers, partitionId, pushFailedBatches, - r => lifecycleManager.workerStatusTracker.recordWorkerFailure(r)) + r => lifecycleManager.workerStatusTracker.recordWorkerFailure(r), + numPartitions, + crc32PerPartition, + bytesWrittenPerPartition) } def releasePartitionResource(shuffleId: Int, partitionId: Int): Unit = { @@ -350,4 +357,18 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage } } } + + def finishPartition( + shuffleId: Int, + partitionId: Int, + startMapIndex: Int, + endMapIndex: Int, + actualCommitMetadata: CommitMetadata): (Boolean, String) = { + getCommitHandler(shuffleId).finishPartition( + shuffleId, + partitionId, + startMapIndex, + 
endMapIndex, + actualCommitMetadata) + } } diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 9063ce7b8..f51831661 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -21,7 +21,7 @@ import java.lang.{Byte => JByte} import java.nio.ByteBuffer import java.security.SecureRandom import java.util -import java.util.{function, List => JList} +import java.util.{function, Collections, List => JList} import java.util.concurrent._ import java.util.concurrent.atomic.{AtomicInteger, LongAdder} import java.util.function.{BiConsumer, BiFunction, Consumer} @@ -40,7 +40,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers} import org.apache.celeborn.client.listener.WorkerStatusListener -import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.{CelebornConf, CommitMetadata} import org.apache.celeborn.common.CelebornConf.ACTIVE_STORAGE_TYPES import org.apache.celeborn.common.client.MasterClient import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier} @@ -423,13 +423,31 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends oldPartition, isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId)) - case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) => + case MapperEnd( + shuffleId, + mapId, + attemptId, + numMappers, + partitionId, + pushFailedBatch, + numPartitions, + crc32PerPartition, + bytesWrittenPerPartition) => logTrace(s"Received MapperEnd TaskEnd request, " + s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}") val partitionType = getPartitionType(shuffleId) partitionType match { case PartitionType.REDUCE => - handleMapperEnd(context, 
shuffleId, mapId, attemptId, numMappers, pushFailedBatch) + handleMapperEnd( + context, + shuffleId, + mapId, + attemptId, + numMappers, + pushFailedBatch, + numPartitions, + crc32PerPartition, + bytesWrittenPerPartition) case PartitionType.MAP => handleMapPartitionEnd( context, @@ -442,6 +460,22 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends throw new UnsupportedOperationException(s"Not support $partitionType yet") } + case pb: ReadReducerPartitionEnd => + val partitionType = getPartitionType(pb.shuffleId) + partitionType match { + case PartitionType.REDUCE => + handleReducerPartitionEnd( + context, + pb.shuffleId, + pb.partitionId, + pb.startMapIndex, + pb.endMapIndex, + pb.crc32, + pb.bytesWritten) + case _ => + throw new UnsupportedOperationException(s"Not support $partitionType yet") + } + case GetReducerFileGroup( shuffleId: Int, isSegmentGranularityVisible: Boolean, @@ -493,6 +527,32 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } } + private def handleReducerPartitionEnd( + context: RpcCallContext, + shuffleId: Int, + partitionId: Int, + startMapIndex: Int, + endMapIndex: Int, + crc32: Int, + bytesWritten: Long): Unit = { + val (isValid, errorMessage) = commitManager.finishPartition( + shuffleId, + partitionId, + startMapIndex, + endMapIndex, + new CommitMetadata(crc32, bytesWritten)) + var response: PbReadReducerPartitionEndResponse = null + if (isValid) { + response = + PbReadReducerPartitionEndResponse.newBuilder().setStatus( + StatusCode.SUCCESS.getValue).build() + } else { + response = PbReadReducerPartitionEndResponse.newBuilder().setStatus( + +StatusCode.READ_REDUCER_PARTITION_END_FAILED.getValue).setErrorMsg(errorMessage).build() + } + context.reply(response) + } + def setupEndpoints( workers: util.Set[WorkerInfo], shuffleId: Int, @@ -771,7 +831,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends commitManager.registerShuffle( shuffleId, 
numMappers, - isSegmentGranularityVisible) + isSegmentGranularityVisible, + numPartitions) // Fifth, reply the allocated partition location to ShuffleClient. logInfo(s"Handle RegisterShuffle Success for $shuffleId.") @@ -840,7 +901,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends mapId: Int, attemptId: Int, numMappers: Int, - pushFailedBatches: util.Map[String, LocationPushFailedBatches]): Unit = { + pushFailedBatches: util.Map[String, LocationPushFailedBatches], + numPartitions: Int, + crc32PerPartition: Array[Int], + bytesWrittenPerPartition: Array[Long]): Unit = { val (mapperAttemptFinishedSuccess, allMapperFinished) = commitManager.finishMapperAttempt( @@ -848,7 +912,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends mapId, attemptId, numMappers, - pushFailedBatches = pushFailedBatches) + pushFailedBatches = pushFailedBatches, + numPartitions = numPartitions, + crc32PerPartition = crc32PerPartition, + bytesWrittenPerPartition = bytesWrittenPerPartition) if (mapperAttemptFinishedSuccess && allMapperFinished) { // last mapper finished. 
call mapper end logInfo(s"Last MapperEnd, call StageEnd with shuffleKey:" + diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala index 4a8b542f8..794a4b2dd 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala @@ -31,7 +31,7 @@ import scala.concurrent.duration.Duration import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker} import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo import org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, ShuffleFileGroups, ShufflePushFailedBatches} -import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.{CelebornConf, CommitMetadata} import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo} import org.apache.celeborn.common.network.protocol.SerdeVersion @@ -216,12 +216,16 @@ abstract class CommitHandler( numMappers: Int, partitionId: Int, pushFailedBatches: util.Map[String, LocationPushFailedBatches], - recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) + recordWorkerFailure: ShuffleFailedWorkers => Unit, + numPartitions: Int, + crc32PerPartition: Array[Int], + bytesWrittenPerPartition: Array[Long]): (Boolean, Boolean) def registerShuffle( shuffleId: Int, numMappers: Int, - isSegmentGranularityVisible: Boolean): Unit = { + isSegmentGranularityVisible: Boolean, + numPartitions: Int): Unit = { // TODO: if isSegmentGranularityVisible is set to true, it is necessary to handle the pending // get partition request of downstream reduce task here, in scenarios which support // downstream task start early before the upstream task, e.g. flink hybrid shuffle. 
@@ -422,6 +426,16 @@ abstract class CommitHandler( } } + /** + * Invoked when a reduce partition finishes reading data to perform end to end integrity check validation + */ + def finishPartition( + shuffleId: Int, + partitionId: Int, + startMapIndex: Int, + endMapIndex: Int, + actualCommitMetadata: CommitMetadata): (Boolean, String) + def parallelCommitFiles( shuffleId: Int, allocatedWorkers: util.Map[String, ShufflePartitionLocationInfo], diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/LegacySkewHandlingPartitionValidator.scala b/client/src/main/scala/org/apache/celeborn/client/commit/LegacySkewHandlingPartitionValidator.scala new file mode 100644 index 000000000..12028a477 --- /dev/null +++ b/client/src/main/scala/org/apache/celeborn/client/commit/LegacySkewHandlingPartitionValidator.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.celeborn.client.commit + +import java.util +import java.util.Comparator +import java.util.function.BiFunction + +import com.google.common.base.Preconditions.{checkArgument, checkState} + +import org.apache.celeborn.common.CommitMetadata +import org.apache.celeborn.common.internal.Logging +import org.apache.celeborn.common.util.JavaUtils + +class LegacySkewHandlingPartitionValidator extends AbstractPartitionCompletenessValidator + with Logging { + private val subRangeToCommitMetadataPerReducer = { + JavaUtils.newConcurrentHashMap[Int, java.util.TreeMap[(Int, Int), CommitMetadata]]() + } + private val partitionToSubPartitionCount = JavaUtils.newConcurrentHashMap[Int, Int]() + private val currentCommitMetadataForReducer = + JavaUtils.newConcurrentHashMap[Int, CommitMetadata]() + private val currentTotalMapIdCountForReducer = JavaUtils.newConcurrentHashMap[Int, Int]() + private val comparator: java.util.Comparator[(Int, Int)] = new Comparator[(Int, Int)] { + override def compare(o1: (Int, Int), o2: (Int, Int)): Int = { + val comparator = Integer.compare(o1._1, o2._1) + if (comparator != 0) + comparator + else + Integer.compare(o1._2, o2._2) + } + } + private val mapCountMergeBiFunction = new BiFunction[Int, Int, Int] { + override def apply(t: Int, u: Int): Int = + Integer.sum(t, u) + } + private val metadataMergeBiFunction: BiFunction[CommitMetadata, CommitMetadata, CommitMetadata] = + new BiFunction[CommitMetadata, CommitMetadata, CommitMetadata] { + override def apply( + existing: CommitMetadata, + incoming: CommitMetadata): CommitMetadata = { + if (existing == null) { + if (incoming != null) { + return new CommitMetadata(incoming.getChecksum, incoming.getBytes) + } else { + return incoming + } + } + if (incoming == null) { + return existing + } + existing.addCommitData(incoming) + existing + } + } + + private def checkOverlappingRange( + treeMap: java.util.TreeMap[(Int, Int), CommitMetadata], + rangeKey: (Int, Int)): (Boolean, ((Int, Int), 
CommitMetadata)) = { + val floorEntry: util.Map.Entry[(Int, Int), CommitMetadata] = + treeMap.floorEntry(rangeKey) + val ceilingEntry: util.Map.Entry[(Int, Int), CommitMetadata] = + treeMap.ceilingEntry(rangeKey) + + if (floorEntry != null) { + if (rangeKey._1 < floorEntry.getKey._2) { + return (true, ((floorEntry.getKey._1, floorEntry.getKey._2), floorEntry.getValue)) + } + } + if (ceilingEntry != null) { + if (rangeKey._2 > ceilingEntry.getKey._1) { + return (true, ((ceilingEntry.getKey._1, ceilingEntry.getKey._2), ceilingEntry.getValue)) + } + } + (false, ((0, 0), new CommitMetadata())) + } + + override def processSubPartition( + partitionId: Int, + startMapIndex: Int, + endMapIndex: Int, + actualCommitMetadata: CommitMetadata, + expectedTotalMapperCount: Int): (Boolean, String) = { + checkArgument( + startMapIndex < endMapIndex, + "startMapIndex %s must be less than endMapIndex %s", + startMapIndex, + endMapIndex) + logDebug( + s"Validate partition invoked for partitionId $partitionId startMapIndex $startMapIndex endMapIndex $endMapIndex") + partitionToSubPartitionCount.put(partitionId, expectedTotalMapperCount) + val subRangeToCommitMetadataMap = subRangeToCommitMetadataPerReducer.computeIfAbsent( + partitionId, + new java.util.function.Function[Int, util.TreeMap[(Int, Int), CommitMetadata]] { + override def apply(key: Int): util.TreeMap[(Int, Int), CommitMetadata] = + new util.TreeMap[(Int, Int), CommitMetadata](comparator) + }) + val rangeKey = (startMapIndex, endMapIndex) + subRangeToCommitMetadataMap.synchronized { + val existingMetadata = subRangeToCommitMetadataMap.get(rangeKey) + if (existingMetadata == null) { + val (isOverlapping, overlappingEntry) = + checkOverlappingRange(subRangeToCommitMetadataMap, rangeKey) + if (isOverlapping) { + val errorMessage = s"Encountered overlapping map range for partitionId: $partitionId " + + s" while processing range with startMapIndex: $startMapIndex and endMapIndex: $endMapIndex " + + s"existing range map: 
$subRangeToCommitMetadataMap " + + s"overlapped with Entry((startMapIndex, endMapIndex), count): $overlappingEntry" + logError(errorMessage) + return (false, errorMessage) + } + + // Process new range + subRangeToCommitMetadataMap.put(rangeKey, actualCommitMetadata) + currentCommitMetadataForReducer.merge( + partitionId, + actualCommitMetadata, + metadataMergeBiFunction) + currentTotalMapIdCountForReducer.merge( + partitionId, + endMapIndex - startMapIndex, + mapCountMergeBiFunction) + } else if (existingMetadata != actualCommitMetadata) { + val errorMessage = s"Commit Metadata for partition: $partitionId " + + s"not matching for sub-partition with startMapIndex: $startMapIndex endMapIndex: $endMapIndex " + + s"previous count: $existingMetadata new count: $actualCommitMetadata" + logError(errorMessage) + return (false, errorMessage) + } + + validateState(partitionId, startMapIndex, endMapIndex) + val sumOfMapRanges: Int = currentTotalMapIdCountForReducer.get(partitionId) + val currentCommitMetadata: CommitMetadata = + currentCommitMetadataForReducer.get(partitionId) + + if (sumOfMapRanges > expectedTotalMapperCount) { + val errorMsg = s"AQE Partition $partitionId failed validation check " + + s"while processing startMapIndex: $startMapIndex endMapIndex: $endMapIndex " + + s"ActualCommitMetadata $currentCommitMetadata, sumOfMapRanges $sumOfMapRanges > ExpectedTotalMapperCount $expectedTotalMapperCount" + logError(errorMsg) + return (false, errorMsg) + } + } + (true, "") + } + + private def validateState(partitionId: Int, startMapIndex: Int, endMapIndex: Int): Unit = { + if (!currentTotalMapIdCountForReducer.containsKey(partitionId)) { + checkState( + !currentCommitMetadataForReducer.containsKey(partitionId), + "mapper total count missing while processing partitionId %s startMapIndex %s endMapIndex %s", + partitionId, + startMapIndex, + endMapIndex) + currentTotalMapIdCountForReducer.put(partitionId, 0) + currentCommitMetadataForReducer.put(partitionId, new CommitMetadata()) + } + checkState( + 
currentCommitMetadataForReducer.containsKey(partitionId), + "mapper written count missing while processing partitionId %s startMapIndex %s endMapIndex %s", + partitionId, + startMapIndex, + endMapIndex) + } + + override def currentCommitMetadata(partitionId: Int): CommitMetadata = { + currentCommitMetadataForReducer.get(partitionId) + } + + override def isPartitionComplete(partitionId: Int): Boolean = { + val sumOfMapRanges: Int = currentTotalMapIdCountForReducer.get(partitionId) + sumOfMapRanges == partitionToSubPartitionCount.get(partitionId) + } +} diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala index 715950531..490b9450a 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala @@ -20,15 +20,17 @@ package org.apache.celeborn.client.commit import java.util import java.util.Collections import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor} -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicInteger, AtomicIntegerArray} import scala.collection.JavaConverters._ import scala.collection.mutable +import org.roaringbitmap.RoaringBitmap + import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker} import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers} -import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.{CelebornConf, CommitMetadata} import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo} import org.apache.celeborn.common.network.protocol.SerdeVersion @@ -187,7 +189,10 @@ class MapPartitionCommitHandler( 
numMappers: Int, partitionId: Int, pushFailedBatches: util.Map[String, LocationPushFailedBatches], - recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = { + recordWorkerFailure: ShuffleFailedWorkers => Unit, + numPartitions: Int, + crc32PerPartition: Array[Int], + bytesWrittenPerPartition: Array[Long]): (Boolean, Boolean) = { val inProcessingPartitionIds = inProcessMapPartitionEndIds.computeIfAbsent( shuffleId, @@ -222,8 +227,9 @@ class MapPartitionCommitHandler( override def registerShuffle( shuffleId: Int, numMappers: Int, - isSegmentGranularityVisible: Boolean): Unit = { - super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible) + isSegmentGranularityVisible: Boolean, + numPartitions: Int): Unit = { + super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible, numPartitions) shuffleIsSegmentGranularityVisible.put(shuffleId, isSegmentGranularityVisible) } @@ -231,6 +237,15 @@ class MapPartitionCommitHandler( shuffleIsSegmentGranularityVisible.get(shuffleId) } + override def finishPartition( + shuffleId: Int, + partitionId: Int, + startMapIndex: Int, + endMapIndex: Int, + actualCommitMetadata: CommitMetadata): (Boolean, String) = { + throw new UnsupportedOperationException() + } + override def handleGetReducerFileGroup( context: RpcCallContext, shuffleId: Int, diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/PartitionCompletenessValidator.scala b/client/src/main/scala/org/apache/celeborn/client/commit/PartitionCompletenessValidator.scala new file mode 100644 index 000000000..93c3fd471 --- /dev/null +++ b/client/src/main/scala/org/apache/celeborn/client/commit/PartitionCompletenessValidator.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.client.commit + +import org.apache.celeborn.common.CommitMetadata +import org.apache.celeborn.common.internal.Logging + +abstract class AbstractPartitionCompletenessValidator { + def processSubPartition( + partitionId: Int, + startMapIndex: Int, + endMapIndex: Int, + actualCommitMetadata: CommitMetadata, + expectedTotalMapperCount: Int): (Boolean, String) + + def currentCommitMetadata(partitionId: Int): CommitMetadata + + def isPartitionComplete(partitionId: Int): Boolean +} + +class PartitionCompletenessValidator extends Logging { + + private val skewHandlingValidator: AbstractPartitionCompletenessValidator = + new SkewHandlingWithoutMapRangeValidator + private val legacySkewHandlingValidator: AbstractPartitionCompletenessValidator = + new LegacySkewHandlingPartitionValidator + + def validateSubPartition( + partitionId: Int, + startMapIndex: Int, + endMapIndex: Int, + actualCommitMetadata: CommitMetadata, + expectedCommitMetadata: CommitMetadata, + expectedTotalMapperCountForParent: Int, + skewPartitionHandlingWithoutMapRange: Boolean): (Boolean, String) = { + val validator = + if (skewPartitionHandlingWithoutMapRange) { + skewHandlingValidator + } else { + legacySkewHandlingValidator + } + + val (successfullyProcessed, error) = validator.processSubPartition( + partitionId, + startMapIndex, + endMapIndex, + actualCommitMetadata, + 
expectedTotalMapperCountForParent) + if (!successfullyProcessed) return (false, error) + + if (!validator.isPartitionComplete(partitionId)) { + return (true, "Partition is valid but still waiting for more data") + } + + val currentCommitMetadata = validator.currentCommitMetadata(partitionId) + if (!CommitMetadata.checkCommitMetadata(expectedCommitMetadata, currentCommitMetadata)) { + val errorMsg = + s"AQE Partition $partitionId failed validation check " + + s"while processing range startMapIndex: $startMapIndex endMapIndex: $endMapIndex " + + s"ExpectedCommitMetadata $expectedCommitMetadata, " + + s"ActualCommitMetadata $currentCommitMetadata" + logError(errorMsg) + return (false, errorMsg) + } + logInfo( + s"AQE Partition $partitionId completed validation check, " + + s"expectedCommitMetadata $expectedCommitMetadata, " + + s"actualCommitMetadata $currentCommitMetadata") + (true, "Partition is complete") + } +} diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index 5fa4394ad..6d8ee3c2a 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -25,13 +25,14 @@ import java.util.function import scala.collection.JavaConverters._ import scala.collection.mutable +import com.google.common.base.Preconditions.checkState import com.google.common.cache.{Cache, CacheBuilder} import com.google.common.collect.Sets import org.apache.celeborn.client.{ClientUtils, LifecycleManager, ShuffleCommittedInfo, WorkerStatusTracker} import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers} -import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.{CelebornConf, 
CommitMetadata} import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.ShufflePartitionLocationInfo import org.apache.celeborn.common.network.protocol.SerdeVersion @@ -82,6 +83,13 @@ class ReducePartitionCommitHandler( private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime + private val shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled + // partitionId-shuffleId -> number of mappers that have written to this reducer (partition + shuffle) + private val commitMetadataForReducer = + JavaUtils.newConcurrentHashMap[Integer, Array[CommitMetadata]] + private val skewPartitionCompletenessValidator = + JavaUtils.newConcurrentHashMap[Int, PartitionCompletenessValidator]() + private val getReducerFileGroupResponseBroadcastEnabled = conf.getReducerFileGroupBroadcastEnabled private val getReducerFileGroupResponseBroadcastMiniSize = conf.getReducerFileGroupBroadcastMiniSize @@ -153,6 +161,8 @@ class ReducePartitionCommitHandler( stageEndShuffleSet.remove(shuffleId) inProcessStageEndShuffleSet.remove(shuffleId) shuffleMapperAttempts.remove(shuffleId) + commitMetadataForReducer.remove(shuffleId) + skewPartitionCompletenessValidator.remove(shuffleId) super.removeExpiredShuffle(shuffleId) } @@ -268,16 +278,20 @@ class ReducePartitionCommitHandler( numMappers: Int, partitionId: Int, pushFailedBatches: util.Map[String, LocationPushFailedBatches], - recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = { - shuffleMapperAttempts.synchronized { + recordWorkerFailure: ShuffleFailedWorkers => Unit, + numPartitions: Int, + crc32PerPartition: Array[Int], + bytesWrittenPerPartition: Array[Long]): (Boolean, Boolean) = { + val (mapperAttemptFinishedSuccess, allMapperFinished) = shuffleMapperAttempts.synchronized { if (getMapperAttempts(shuffleId) == null) { logDebug(s"[handleMapperEnd] $shuffleId not registered, create one.") - 
initMapperAttempts(shuffleId, numMappers) + initMapperAttempts(shuffleId, numMappers, numPartitions) } val attempts = shuffleMapperAttempts.get(shuffleId) if (attempts(mapId) < 0) { attempts(mapId) = attemptId + if (null != pushFailedBatches && !pushFailedBatches.isEmpty) { val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent( shuffleId, @@ -296,24 +310,93 @@ class ReducePartitionCommitHandler( (false, false) } } + if (shuffleIntegrityCheckEnabled && mapperAttemptFinishedSuccess) { + val commitMetadataArray = commitMetadataForReducer.get(shuffleId) + checkState( + commitMetadataArray != null, + "commitMetadataArray can only be null if shuffleId %s is not registered!", + shuffleId) + for (i <- 0 until numPartitions) { + if (bytesWrittenPerPartition(i) != 0) { + commitMetadataArray(i).addCommitData( + crc32PerPartition(i), + bytesWrittenPerPartition(i)) + } + } + } + (mapperAttemptFinishedSuccess, allMapperFinished) } override def registerShuffle( shuffleId: Int, numMappers: Int, - isSegmentGranularityVisible: Boolean): Unit = { - super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible) + isSegmentGranularityVisible: Boolean, + numPartitions: Int): Unit = { + super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible, numPartitions) getReducerFileGroupRequest.put(shuffleId, new util.HashSet[MultiSerdeVersionRpcContext]()) - initMapperAttempts(shuffleId, numMappers) + initMapperAttempts(shuffleId, numMappers, numPartitions) } - private def initMapperAttempts(shuffleId: Int, numMappers: Int): Unit = { + override def finishPartition( + shuffleId: Int, + partitionId: Int, + startMapIndex: Int, + endMapIndex: Int, + actualCommitMetadata: CommitMetadata): (Boolean, String) = { + logDebug(s"finish Partition call: shuffleId: $shuffleId, " + + s"partitionId: $partitionId, " + + s"startMapIndex: $startMapIndex " + + s"endMapIndex: $endMapIndex, " + + s"actualCommitMetadata: $actualCommitMetadata") + val map = 
commitMetadataForReducer.get(shuffleId) + checkState( + map != null, + "CommitMetadata map cannot be null for a registered shuffleId: %s", + shuffleId) + val expectedCommitMetadata = map(partitionId) + if (endMapIndex == Integer.MAX_VALUE) { + // complete partition available + val bool = CommitMetadata.checkCommitMetadata(actualCommitMetadata, expectedCommitMetadata) + var message = "" + if (!bool) { + message = + s"CommitMetadata mismatch for shuffleId: $shuffleId partitionId: $partitionId expected: $expectedCommitMetadata actual: $actualCommitMetadata" + } else { + logInfo( + s"CommitMetadata matched for shuffleID: $shuffleId, partitionId: $partitionId expected: $expectedCommitMetadata actual: $actualCommitMetadata") + } + return (bool, message) + } + + val splitSkewPartitionWithoutMapRange = + ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex) + + val validator = skewPartitionCompletenessValidator.computeIfAbsent( + shuffleId, + new java.util.function.Function[Int, PartitionCompletenessValidator] { + override def apply(key: Int): PartitionCompletenessValidator = + new PartitionCompletenessValidator() + }) + validator.validateSubPartition( + partitionId, + startMapIndex, + endMapIndex, + actualCommitMetadata, + expectedCommitMetadata, + shuffleMapperAttempts.get(shuffleId).length, + splitSkewPartitionWithoutMapRange) + } + + private def initMapperAttempts(shuffleId: Int, numMappers: Int, numPartitions: Int): Unit = { shuffleMapperAttempts.synchronized { if (!shuffleMapperAttempts.containsKey(shuffleId)) { val attempts = new Array[Int](numMappers) 0 until numMappers foreach (idx => attempts(idx) = -1) shuffleMapperAttempts.put(shuffleId, attempts) } + if (shuffleIntegrityCheckEnabled) { + commitMetadataForReducer.put(shuffleId, Array.fill(numPartitions)(new CommitMetadata())) + } } } diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/SkewHandlingWithoutMapRangeValidator.scala 
b/client/src/main/scala/org/apache/celeborn/client/commit/SkewHandlingWithoutMapRangeValidator.scala new file mode 100644 index 000000000..e2d7d8463 --- /dev/null +++ b/client/src/main/scala/org/apache/celeborn/client/commit/SkewHandlingWithoutMapRangeValidator.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.celeborn.client.commit + +import java.util + +import com.google.common.base.Preconditions.{checkArgument, checkState} + +import org.apache.celeborn.common.CommitMetadata +import org.apache.celeborn.common.util.JavaUtils + +class SkewHandlingWithoutMapRangeValidator extends AbstractPartitionCompletenessValidator { + + private val totalSubPartitionsProcessed = + JavaUtils.newConcurrentHashMap[Int, util.HashMap[Int, CommitMetadata]]() + private val partitionToSubPartitionCount = JavaUtils.newConcurrentHashMap[Int, Int]() + private val currentCommitMetadataForReducer = + JavaUtils.newConcurrentHashMap[Int, CommitMetadata]() + + override def processSubPartition( + partitionId: Int, + startMapIndex: Int, + endMapIndex: Int, + actualCommitMetadata: CommitMetadata, + expectedTotalMapperCount: Int): (Boolean, String) = { + checkArgument( + endMapIndex >= 0, + "index of sub-partition %s must be greater than or equal to 0", + endMapIndex) + checkArgument( + startMapIndex > endMapIndex, + "startMapIndex %s must be greater than endMapIndex %s", + startMapIndex, + endMapIndex) + totalSubPartitionsProcessed.synchronized { + if (totalSubPartitionsProcessed.containsKey(partitionId)) { + val currentSubPartitionCount = partitionToSubPartitionCount.getOrDefault(partitionId, -1) + checkState( + currentSubPartitionCount == startMapIndex, + "total subpartition count mismatch for partitionId %s existing %s new %s", + partitionId, + currentSubPartitionCount, + startMapIndex) + } else { + totalSubPartitionsProcessed.put(partitionId, new util.HashMap[Int, CommitMetadata]()) + partitionToSubPartitionCount.put(partitionId, startMapIndex) + } + val subPartitionsProcessed = totalSubPartitionsProcessed.get(partitionId) + if (subPartitionsProcessed.containsKey(endMapIndex)) { + // check if previous entry matches + val existingCommitMetadata = subPartitionsProcessed.get(endMapIndex) + if (existingCommitMetadata != actualCommitMetadata) { + return ( + false, + s"Mismatch 
in metadata for the same chunk range on retry: $endMapIndex existing: $existingCommitMetadata new: $actualCommitMetadata") + } + } + subPartitionsProcessed.put(endMapIndex, actualCommitMetadata) + val partitionProcessed = getTotalNumberOfSubPartitionsProcessed(partitionId) + checkState( + partitionProcessed <= startMapIndex, + "Number of sub-partitions processed %s should be less than or equal to total number of sub-partitions %s", + partitionProcessed, + startMapIndex) + } + updateCommitMetadata(partitionId, actualCommitMetadata) + (true, "") + } + + private def updateCommitMetadata(partitionId: Int, actualCommitMetadata: CommitMetadata): Unit = { + val currentCommitMetadata = + currentCommitMetadataForReducer.computeIfAbsent( + partitionId, + new java.util.function.Function[Int, CommitMetadata] { + override def apply(partitionId: Int): CommitMetadata = { + new CommitMetadata() + } + }) + currentCommitMetadata.addCommitData(actualCommitMetadata) + } + + private def getTotalNumberOfSubPartitionsProcessed(partitionId: Int) = { + totalSubPartitionsProcessed.get(partitionId).size() + } + + override def currentCommitMetadata(partitionId: Int): CommitMetadata = { + currentCommitMetadataForReducer.get(partitionId) + } + + override def isPartitionComplete(partitionId: Int): Boolean = { + getTotalNumberOfSubPartitionsProcessed(partitionId) == partitionToSubPartitionCount.get( + partitionId) + } +} diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java index 85cf0ba10..acca92b4b 100644 --- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java +++ b/client/src/test/java/org/apache/celeborn/client/ShuffleClientSuiteJ.java @@ -18,8 +18,7 @@ package org.apache.celeborn.client; import static org.junit.Assert.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.*; import 
static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -34,6 +33,9 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; + import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.util.concurrent.Future; @@ -41,6 +43,7 @@ import io.netty.util.concurrent.GenericFutureListener; import org.apache.commons.lang3.RandomStringUtils; import org.junit.Assert; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.apache.celeborn.client.compress.Compressor; import org.apache.celeborn.common.CelebornConf; @@ -51,6 +54,8 @@ import org.apache.celeborn.common.network.client.TransportClientFactory; import org.apache.celeborn.common.network.protocol.SerdeVersion; import org.apache.celeborn.common.protocol.CompressionCodec; import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.common.protocol.PbReadReducerPartitionEnd; +import org.apache.celeborn.common.protocol.PbReadReducerPartitionEndResponse; import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse$; import org.apache.celeborn.common.protocol.message.ControlMessages.RegisterShuffleResponse$; import org.apache.celeborn.common.protocol.message.StatusCode; @@ -68,6 +73,7 @@ public class ShuffleClientSuiteJ { private static final int TEST_SHUFFLE_ID = 1; private static final int TEST_ATTEMPT_ID = 0; private static final int TEST_REDUCRE_ID = 0; + private static final int TEST_MAP_ID = 0; private static final int PRIMARY_RPC_PORT = 1234; private static final int PRIMARY_PUSH_PORT = 1235; @@ -621,4 +627,86 @@ public class ShuffleClientSuiteJ { Exception exception = exceptionRef.get(); Assert.assertTrue(exception.getCause() instanceof TimeoutException); } + + @Test + public void 
testSuccessfulReadReducePartitionEnd() throws IOException { + CelebornConf conf = new CelebornConf(); + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + ClassTag classTag = + ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class); + PbReadReducerPartitionEndResponse mockResponse = + PbReadReducerPartitionEndResponse.newBuilder() + .setStatus(StatusCode.SUCCESS.getValue()) + .build(); + when(endpointRef.askSync(any(PbReadReducerPartitionEnd.class), any(), eq(classTag))) + .thenReturn(mockResponse); + + shuffleClient.readReducerPartitionEnd(1, 2, 3, 4, 12345, 10000L); + } + + @Test + public void testFailedReadReducerPartitionEnd() throws IOException { + CelebornConf conf = new CelebornConf(); + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + ClassTag classTag = + ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class); + + String errorMsg = "Test error message"; + PbReadReducerPartitionEndResponse mockResponse = + PbReadReducerPartitionEndResponse.newBuilder() + .setStatus(StatusCode.READ_REDUCER_PARTITION_END_FAILED.getValue()) + .setErrorMsg(errorMsg) + .build(); + + when(endpointRef.askSync(any(PbReadReducerPartitionEnd.class), any(), eq(classTag))) + .thenReturn(mockResponse); + + try { + shuffleClient.readReducerPartitionEnd(1, 2, 3, 4, 12345, 10000L); + } catch (CelebornIOException e) { + Assert.assertEquals(errorMsg, e.getMessage()); + } + } + + @Test + public void testCorrectParametersPassedInRequest() throws IOException { + CelebornConf conf = new CelebornConf(); + shuffleClient = + new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock")); + shuffleClient.setupLifecycleManagerRef(endpointRef); + + int shuffleId = 123; + int partitionId = 456; + int startMapIndex = 0; + int endMapIndex = 
10; + int crc32 = 98765; + long bytesWritten = 54321L; + + PbReadReducerPartitionEndResponse mockResponse = + PbReadReducerPartitionEndResponse.newBuilder() + .setStatus(StatusCode.SUCCESS.getValue()) + .build(); + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(PbReadReducerPartitionEnd.class); + ClassTag classTag = + ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class); + + when(endpointRef.askSync(requestCaptor.capture(), any(), eq(classTag))) + .thenReturn(mockResponse); + + shuffleClient.readReducerPartitionEnd( + shuffleId, partitionId, startMapIndex, endMapIndex, crc32, bytesWritten); + PbReadReducerPartitionEnd capturedRequest = requestCaptor.getValue(); + assertEquals(shuffleId, capturedRequest.getShuffleId()); + assertEquals(partitionId, capturedRequest.getPartitionId()); + assertEquals(startMapIndex, capturedRequest.getStartMaxIndex()); + assertEquals(endMapIndex, capturedRequest.getEndMapIndex()); + assertEquals(crc32, capturedRequest.getCrc32()); + assertEquals(bytesWritten, capturedRequest.getBytesWritten()); + } } diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index 9f2199ced..263a2a9a6 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -197,21 +197,29 @@ trait WithShuffleClientSuite extends CelebornFunSuite { } private def registerAndFinishPartition(shuffleId: Int): Unit = { + val numPartitions = 9 shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId, attemptId, 1) shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId + 1, attemptId, 2) shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId + 2, attemptId, 3) // task number incr to numMappers + 1 shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId, attemptId + 1, 9) - 
shuffleClient.mapPartitionMapperEnd(shuffleId, mapId, attemptId, numMappers, 1) + shuffleClient.mapPartitionMapperEnd(shuffleId, mapId, attemptId, numMappers, numPartitions, 1) // another attempt shuffleClient.mapPartitionMapperEnd( shuffleId, mapId, attemptId + 1, numMappers, + numPartitions, 9) // another mapper - shuffleClient.mapPartitionMapperEnd(shuffleId, mapId + 1, attemptId, numMappers, mapId + 1) + shuffleClient.mapPartitionMapperEnd( + shuffleId, + mapId + 1, + attemptId, + numMappers, + numPartitions, + mapId + 1) } } diff --git a/client/src/test/scala/org/apache/celeborn/client/commit/PartitionValidatorTest.scala b/client/src/test/scala/org/apache/celeborn/client/commit/PartitionValidatorTest.scala new file mode 100644 index 000000000..d2a5aded7 --- /dev/null +++ b/client/src/test/scala/org/apache/celeborn/client/commit/PartitionValidatorTest.scala @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.celeborn.client.commit

import org.scalatest.matchers.must.Matchers.{be, include}
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.celeborn.CelebornFunSuite
import org.apache.celeborn.common.CommitMetadata

/**
 * Tests for [[PartitionCompletenessValidator]] (map-range based, AQE skew handling
 * with map-range split). Map ranges are treated as half-open `[start, end)`.
 */
class PartitionValidatorTest extends CelebornFunSuite {

  private val emptyMetadata: CommitMetadata = new CommitMetadata()

  // Shorthand around validateSubPartition; most tests share the trailing arguments.
  private def validate(
      validator: PartitionCompletenessValidator,
      partitionId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      actual: CommitMetadata = emptyMetadata,
      expected: CommitMetadata = emptyMetadata,
      totalMapperCount: Int = 20): (Boolean, String) =
    validator.validateSubPartition(
      partitionId,
      startMapIndex,
      endMapIndex,
      actual,
      expected,
      totalMapperCount,
      false)

  test("AQEPartitionCompletenessValidator should validate a new sub-partition correctly when there are no overlapping ranges") {
    val validator = new PartitionCompletenessValidator
    val (isValid, message) = validate(validator, 1, 0, 10)

    isValid shouldBe true
    message shouldBe "Partition is valid but still waiting for more data"
  }

  test("AQEPartitionCompletenessValidator should fail validation for overlapping map ranges") {
    val validator = new PartitionCompletenessValidator
    validate(validator, 1, 0, 10) // establishes range [0, 10)
    val (isValid, message) = validate(validator, 1, 5, 15) // overlaps [0, 10)

    isValid shouldBe false
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test("AQEPartitionCompletenessValidator should pass validation for adjacent non-overlapping map ranges") {
    // [0,1), [2,3) and [1,2) are adjacent half-open ranges — no overlap.
    val validator = new PartitionCompletenessValidator
    validate(validator, 1, 0, 1)
    validate(validator, 1, 2, 3)
    val (isValid, _) = validate(validator, 1, 1, 2)

    isValid shouldBe true
  }

  test("AQEPartitionCompletenessValidator should fail validation for overlapping map ranges - edge cases") {
    val validator = new PartitionCompletenessValidator
    validate(validator, 1, 0, 10) // establishes range [0, 10)
    val (isValid, message) = validate(validator, 1, 0, 5) // sub-range of [0, 10)

    isValid shouldBe false
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test("AQEPartitionCompletenessValidator should fail validation for one map range subsuming another map range") {
    val validator = new PartitionCompletenessValidator
    validate(validator, 1, 5, 10)
    val (isValid, message) = validate(validator, 1, 0, 15) // subsumes [5, 10)

    isValid shouldBe false
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test("AQEPartitionCompletenessValidator should fail validation for one map range subsuming another map range - 2") {
    val validator = new PartitionCompletenessValidator
    validate(validator, 1, 0, 15)
    val (isValid, message) = validate(validator, 1, 5, 10) // subsumed by [0, 15)

    isValid shouldBe false
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test("AQEPartitionCompletenessValidator should fail validation if commit metadata does not match") {
    val expectedCommitMetadata = new CommitMetadata()
    val actualCommitMetadata = new CommitMetadata()
    val validator = new PartitionCompletenessValidator
    validate(validator, 1, 0, 10, actualCommitMetadata, expectedCommitMetadata)

    // Retry of the same range with different metadata must be rejected.
    val (isValid, message) =
      validate(validator, 1, 0, 10, new CommitMetadata(3, 3), expectedCommitMetadata)

    isValid should be(false)
    message should include("Commit Metadata for partition: 1 not matching for sub-partition")
  }

  test("pass validation if written counts are correct after all updates") {
    val validator = new PartitionCompletenessValidator
    val expectedCommitMetadata = new CommitMetadata(0, 10)
    validate(validator, 1, 0, 10, new CommitMetadata(0, 5), expectedCommitMetadata)
    val (isValid, message) =
      validate(validator, 1, 10, 20, new CommitMetadata(0, 5), expectedCommitMetadata)

    isValid should be(true)
    message should be("Partition is complete")
  }

  test("handle multiple partitions correctly") {
    // Two partitions progress independently; state must not leak between them.
    val validator = new PartitionCompletenessValidator
    val expectedForPartition1 = new CommitMetadata(0, 10)
    val expectedForPartition2 = new CommitMetadata(0, 2)
    validate(validator, 1, 0, 10, new CommitMetadata(0, 5), expectedForPartition1)
    validate(validator, 2, 0, 5, new CommitMetadata(0, 2), expectedForPartition2, 10)

    val (isValid1, message1) =
      validate(validator, 1, 0, 10, new CommitMetadata(0, 5), expectedForPartition1)
    val (isValid2, message2) =
      validate(validator, 2, 0, 5, new CommitMetadata(0, 2), expectedForPartition2, 10)

    isValid1 should be(true)
    isValid2 should be(true)
    message1 should be("Partition is valid but still waiting for more data")
    message2 should be("Partition is valid but still waiting for more data")

    val (isValid3, message3) =
      validate(validator, 1, 10, 20, new CommitMetadata(0, 5), expectedForPartition1)
    val (isValid4, message4) =
      validate(validator, 2, 5, 10, new CommitMetadata(), expectedForPartition2, 10)
    isValid3 should be(true)
    isValid4 should be(true)
    message3 should be("Partition is complete")
    message4 should be("Partition is complete")
  }
}
package org.apache.celeborn.client.commit

import org.scalatest.matchers.must.Matchers.{be, include}
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.celeborn.CelebornFunSuite
import org.apache.celeborn.common.CommitMetadata

/**
 * Tests for [[SkewHandlingWithoutMapRangeValidator]]. In this mode `startMapIndex`
 * carries the total sub-partition count and `endMapIndex` carries the index of the
 * sub-partition being reported.
 */
class SkewHandlingWithoutMapRangeValidatorTest extends CelebornFunSuite {

  test("testProcessSubPartitionFirstTime") {
    val validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val subPartitionCount = 10
    val subPartitionIndex = 0
    val metadata = new CommitMetadata

    val (success, message) = validator.processSubPartition(
      partitionId,
      subPartitionCount,
      subPartitionIndex,
      metadata,
      20)

    success shouldBe true
    message shouldBe ""
    // Only 1 of 10 sub-partitions has been reported.
    validator.isPartitionComplete(partitionId) shouldBe false

    val currentMetadata = validator.currentCommitMetadata(partitionId)
    currentMetadata shouldNot be(null)
    currentMetadata shouldBe metadata
  }

  test("testProcessMultipleSubPartitions") {
    val validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val subPartitionCount = 10
    val metadata1 = new CommitMetadata(5, 1000)
    val metadata2 = new CommitMetadata(3, 500)

    validator.processSubPartition(partitionId, subPartitionCount, 0, metadata1, subPartitionCount)
    val (success, message) = validator.processSubPartition(
      partitionId,
      subPartitionCount,
      1,
      metadata2,
      subPartitionCount)

    success shouldBe true
    message shouldBe ""
    validator.isPartitionComplete(partitionId) shouldBe false

    // The running aggregate must equal the combination of both reports.
    val currentMetadata = validator.currentCommitMetadata(partitionId)
    currentMetadata shouldNot be(null)
    metadata1.addCommitData(metadata2)
    currentMetadata shouldBe metadata1
  }

  test("testPartitionComplete") {
    // Reporting all sub-partition indices 0..9 out of a total of 10 completes the partition.
    val validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val subPartitionCount = 10

    for (i <- 0 until subPartitionCount) {
      val metadata = new CommitMetadata()
      validator.processSubPartition(partitionId, subPartitionCount, i, metadata, subPartitionCount)
    }

    validator.isPartitionComplete(partitionId) shouldBe true
  }

  test("testInvalidStartEndMapIndex") {
    // startMapIndex (the sub-partition count) must be strictly greater than
    // endMapIndex (the sub-partition index); equality is invalid.
    val validator = new SkewHandlingWithoutMapRangeValidator
    val metadata = new CommitMetadata()

    intercept[IllegalArgumentException] {
      validator.processSubPartition(1, 5, 5, metadata, 20)
    }
  }

  test("testMismatchSubPartitionCount") {
    // The reported sub-partition count must stay stable for a given partition.
    val validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val metadata = new CommitMetadata()

    validator.processSubPartition(partitionId, 10, 0, metadata, 10)

    intercept[IllegalStateException] {
      validator.processSubPartition(partitionId, 12, 1, metadata, 12)
    }
  }

  test("testDuplicateSubPartitionWithSameMetadata") {
    val validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val subPartitionCount = 10
    val subPartitionIndex = 5
    val metadata = new CommitMetadata(3, 100)

    validator.processSubPartition(
      partitionId, subPartitionCount, subPartitionIndex, metadata, subPartitionCount)

    // An identical retry is accepted without failing the partition.
    val (success, message) = validator.processSubPartition(
      partitionId,
      subPartitionCount,
      subPartitionIndex,
      metadata,
      subPartitionCount)
    success shouldBe true
    message shouldBe ""
    validator.isPartitionComplete(partitionId) shouldBe false
  }

  test("testDuplicateSubPartitionWithDifferentMetadata") {
    val validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val subPartitionCount = 10
    val subPartitionIndex = 5
    val metadata1 = new CommitMetadata(3, 100)
    val metadata2 = new CommitMetadata(4, 200)

    validator.processSubPartition(
      partitionId, subPartitionCount, subPartitionIndex, metadata1, subPartitionCount)

    // A retry whose metadata disagrees with the recorded one must be rejected.
    val (success, message) = validator.processSubPartition(
      partitionId,
      subPartitionCount,
      subPartitionIndex,
      metadata2,
      subPartitionCount)
    success shouldBe false
    message should include("Mismatch in metadata")
    validator.isPartitionComplete(partitionId) shouldBe false
  }

  test("testProcessTooManySubPartitions") {
    val validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val subPartitionCount = 10

    for (i <- 0 until subPartitionCount) {
      val metadata = new CommitMetadata()
      validator.processSubPartition(partitionId, subPartitionCount, i, metadata, subPartitionCount)
    }

    // Index == count trips the startMapIndex > endMapIndex precondition, which
    // throws IllegalArgumentException (checkArgument), not IllegalStateException.
    val extraMetadata = new CommitMetadata()
    intercept[IllegalArgumentException] {
      validator.processSubPartition(
        partitionId,
        subPartitionCount,
        subPartitionCount,
        extraMetadata,
        subPartitionCount)
    }
  }

  test("multiple complete partitions") {
    // Three partitions of different sizes tracked independently.
    val validator = new SkewHandlingWithoutMapRangeValidator
    val partition1 = 1
    val partition2 = 2
    val partition3 = 3

    val subPartitions1 = 5
    val subPartitions2 = 8
    val subPartitions3 = 3

    val expectedMetadata1 = new CommitMetadata()
    for (i <- 0 until subPartitions1) {
      val metadata = new CommitMetadata(i, 100 + i)
      validator.processSubPartition(partition1, subPartitions1, i, metadata, subPartitions1)
      expectedMetadata1.addCommitData(metadata)
    }

    val expectedMetadata2 = new CommitMetadata()
    for (i <- 0 until subPartitions2) {
      val metadata = new CommitMetadata(i + 1, 200 + i)
      validator.processSubPartition(partition2, subPartitions2, i, metadata, subPartitions2)
      expectedMetadata2.addCommitData(metadata)
    }

    // Deliberately leave the last sub-partition of partition3 unreported.
    val expectedMetadata3 = new CommitMetadata()
    for (i <- 0 until subPartitions3 - 1) {
      val metadata = new CommitMetadata(i + 2, 300 + i)
      validator.processSubPartition(partition3, subPartitions3, i, metadata, subPartitions3)
      expectedMetadata3.addCommitData(metadata)
    }

    validator.isPartitionComplete(partition1) shouldBe true
    validator.isPartitionComplete(partition2) shouldBe true
    validator.isPartitionComplete(partition3) shouldBe false

    validator.currentCommitMetadata(partition1) shouldBe expectedMetadata1
    validator.currentCommitMetadata(partition2) shouldBe expectedMetadata2
    validator.currentCommitMetadata(partition3) shouldBe expectedMetadata3
  }
}
package org.apache.celeborn.common;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.CRC32;

/**
 * A thread-safe CRC32 accumulator whose {@link #combine(int, int)} operation is
 * byte-wise addition of the two checksum words. Because per-byte addition is
 * commutative and associative, checksums of independently processed chunks can
 * be merged in any order and still yield the same aggregate value.
 */
public class CelebornCRC32 {

  private final AtomicInteger current;

  public CelebornCRC32(int i) {
    this.current = new AtomicInteger(i);
  }

  public CelebornCRC32() {
    this(0);
  }

  /** Returns the standard CRC32 of the whole array, narrowed to an int. */
  public static int compute(byte[] bytes) {
    CRC32 hashFunction = new CRC32();
    hashFunction.update(bytes);
    return (int) hashFunction.getValue();
  }

  /** Returns the standard CRC32 of {@code bytes[offset, offset + length)}. */
  public static int compute(byte[] bytes, int offset, int length) {
    CRC32 hashFunction = new CRC32();
    hashFunction.update(bytes, offset, length);
    return (int) hashFunction.getValue();
  }

  /**
   * Combines two checksum words by adding them byte-by-byte (ignoring carries
   * across byte boundaries). The operation is order-independent.
   */
  public static int combine(int first, int second) {
    return (((byte) second + (byte) first) & 0xFF)
        | ((((byte) (second >> 8) + (byte) (first >> 8)) & 0xFF) << 8)
        | ((((byte) (second >> 16) + (byte) (first >> 16)) & 0xFF) << 16)
        | (((byte) (second >> 24) + (byte) (first >> 24)) << 24);
  }

  /** Atomically folds {@code checksum} into the accumulator. */
  public void addChecksum(int checksum) {
    // combine() is pure, so the retry loop inside updateAndGet is safe.
    current.updateAndGet(v -> combine(checksum, v));
  }

  /** Computes the CRC32 of the given range and folds it into the accumulator. */
  public void addData(byte[] bytes, int offset, int length) {
    addChecksum(compute(bytes, offset, length));
  }

  /** Returns the current aggregated checksum. */
  public int get() {
    return current.get();
  }

  @Override
  public String toString() {
    return "CelebornCRC32{" + "current=" + current + '}';
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    CelebornCRC32 that = (CelebornCRC32) o;
    return current.get() == that.current.get();
  }

  @Override
  public int hashCode() {
    // Hash the checksum VALUE. AtomicInteger uses identity hashing, so hashing
    // the container (Objects.hashCode(current)) would give equal instances
    // different hash codes, violating the equals/hashCode contract.
    return Integer.hashCode(current.get());
  }
}
package org.apache.celeborn.common;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Aggregated shuffle-integrity metadata: a running byte count plus a combined
 * CRC32 checksum. Safe for concurrent accumulation from multiple threads.
 */
public class CommitMetadata {

  private final AtomicLong bytes;
  private final CelebornCRC32 crc;

  public CommitMetadata() {
    this.bytes = new AtomicLong();
    this.crc = new CelebornCRC32();
  }

  public CommitMetadata(int checksum, long numBytes) {
    this.bytes = new AtomicLong(numBytes);
    this.crc = new CelebornCRC32(checksum);
  }

  /** Accounts for {@code length} bytes of {@code rawDataBuf} starting at {@code offset}. */
  public void addDataWithOffsetAndLength(byte[] rawDataBuf, int offset, int length) {
    this.bytes.addAndGet(length);
    this.crc.addData(rawDataBuf, offset, length);
  }

  /** Folds another metadata instance into this one. */
  public void addCommitData(CommitMetadata commitMetadata) {
    addCommitData(commitMetadata.getChecksum(), commitMetadata.getBytes());
  }

  /** Adds a byte count and combines a checksum into the running aggregate. */
  public void addCommitData(int checksum, long numBytes) {
    this.bytes.addAndGet(numBytes);
    this.crc.addChecksum(checksum);
  }

  public int getChecksum() {
    return crc.get();
  }

  public long getBytes() {
    return bytes.get();
  }

  /** Returns true only when both byte counts and checksums are equal. */
  public static boolean checkCommitMetadata(CommitMetadata expected, CommitMetadata actual) {
    boolean bytesMatch = expected.getBytes() == actual.getBytes();
    boolean checksumsMatch = expected.getChecksum() == actual.getChecksum();
    return bytesMatch && checksumsMatch;
  }

  @Override
  public String toString() {
    return "CommitMetadata{" + "bytes=" + bytes.get() + ", crc=" + crc + '}';
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    CommitMetadata that = (CommitMetadata) o;
    return bytes.get() == that.bytes.get() && Objects.equals(crc, that.crc);
  }

  @Override
  public int hashCode() {
    // Hash the VALUES, not the containers: AtomicLong uses identity hashing, so
    // Objects.hash(bytes, crc) would give equal instances different hash codes,
    // violating the equals/hashCode contract.
    return Objects.hash(getBytes(), getChecksum());
  }
}
a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java @@ -92,7 +92,8 @@ public enum StatusCode { SEGMENT_START_FAIL_REPLICA(52), SEGMENT_START_FAIL_PRIMARY(53), NO_SPLIT(54), - WORKER_UNRESPONSIVE(55); + WORKER_UNRESPONSIVE(55), + READ_REDUCER_PARTITION_END_FAILED(56); private final byte value; diff --git a/common/src/main/java/org/apache/celeborn/common/write/PushState.java b/common/src/main/java/org/apache/celeborn/common/write/PushState.java index afa22bb8a..be40ec572 100644 --- a/common/src/main/java/org/apache/celeborn/common/write/PushState.java +++ b/common/src/main/java/org/apache/celeborn/common/write/PushState.java @@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicReference; import org.apache.commons.lang3.tuple.Pair; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.CommitMetadata; import org.apache.celeborn.common.protocol.PartitionLocation; import org.apache.celeborn.common.util.JavaUtils; @@ -33,6 +34,9 @@ public class PushState { private final int pushBufferMaxSize; public AtomicReference exception = new AtomicReference<>(); private final InFlightRequestTracker inFlightRequestTracker; + // partition id -> CommitMetadata + private final ConcurrentHashMap commitMetadataMap = + new ConcurrentHashMap<>(); private final Map failedBatchMap; @@ -102,4 +106,34 @@ public class PushState { public Map getFailedBatches() { return this.failedBatchMap; } + + public int[] getCRC32PerPartition(boolean shuffleIntegrityCheckEnabled, int numPartitions) { + if (!shuffleIntegrityCheckEnabled) { + return new int[0]; + } + + int[] crc32PerPartition = new int[numPartitions]; + for (Map.Entry entry : commitMetadataMap.entrySet()) { + crc32PerPartition[entry.getKey()] = entry.getValue().getChecksum(); + } + return crc32PerPartition; + } + + public long[] getBytesWrittenPerPartition( + boolean 
shuffleIntegrityCheckEnabled, int numPartitions) { + if (!shuffleIntegrityCheckEnabled) { + return new long[0]; + } + long[] bytesWrittenPerPartition = new long[numPartitions]; + for (Map.Entry entry : commitMetadataMap.entrySet()) { + bytesWrittenPerPartition[entry.getKey()] = entry.getValue().getBytes(); + } + return bytesWrittenPerPartition; + } + + public void addDataWithOffsetAndLength(int partitionId, byte[] data, int offset, int length) { + CommitMetadata commitMetadata = + commitMetadataMap.computeIfAbsent(partitionId, id -> new CommitMetadata()); + commitMetadata.addDataWithOffsetAndLength(data, offset, length); + } } diff --git a/common/src/main/proto/TransportMessages.proto b/common/src/main/proto/TransportMessages.proto index 9ecb05051..6051b5901 100644 --- a/common/src/main/proto/TransportMessages.proto +++ b/common/src/main/proto/TransportMessages.proto @@ -114,6 +114,8 @@ enum MessageType { PUSH_MERGED_DATA_SPLIT_PARTITION_INFO = 91; GET_STAGE_END = 92; GET_STAGE_END_RESPONSE = 93; + READ_REDUCER_PARTITION_END = 94; + READ_REDUCER_PARTITION_END_RESPONSE = 95; } enum StreamType { @@ -360,6 +362,9 @@ message PbMapperEnd { int32 numMappers = 4; int32 partitionId = 5; map pushFailureBatches= 6; + int32 numPartitions = 7; + repeated int32 crc32PerPartition = 8; + repeated int64 bytesWrittenPerPartition = 9; } message PbLocationPushFailedBatches { @@ -920,3 +925,17 @@ message PbGetStageEndResponse { message PbChunkOffsets { repeated int64 chunkOffset = 1; } + +message PbReadReducerPartitionEnd { + int32 shuffleId = 1; + int32 partitionId = 2; + int32 startMaxIndex = 3; + int32 endMapIndex = 4; + int32 crc32 = 5; + int64 bytesWritten = 6; +} + +message PbReadReducerPartitionEndResponse { + int32 status = 1; + string errorMsg = 2; +} diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 20f6d09cd..c73832259 100644 --- 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -944,6 +944,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se get(CLIENT_EXCLUDE_PEER_WORKER_ON_FAILURE_ENABLED) def clientMrMaxPushData: Long = get(CLIENT_MR_PUSH_DATA_MAX) def clientApplicationUUIDSuffixEnabled: Boolean = get(CLIENT_APPLICATION_UUID_SUFFIX_ENABLED) + def clientShuffleIntegrityCheckEnabled: Boolean = + get(CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED) def appUniqueIdWithUUIDSuffix(appId: String): String = { if (clientApplicationUUIDSuffixEnabled) { @@ -5358,6 +5360,14 @@ object CelebornConf extends Logging { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("512k") + val CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED: ConfigEntry[Boolean] = + buildConf("celeborn.client.shuffle.integrityCheck.enabled") + .categories("client") + .version("0.6.1") + .doc("When `true`, enables end-to-end integrity checks for Spark workloads.") + .booleanConf + .createWithDefault(false) + val SPARK_SHUFFLE_WRITER_MODE: ConfigEntry[String] = buildConf("celeborn.client.spark.shuffle.writer") .withAlternative("celeborn.shuffle.writer") diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala index 04c91e03b..7fcb0fd2f 100644 --- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala +++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala @@ -19,9 +19,11 @@ package org.apache.celeborn.common.protocol.message import java.util import java.util.{Collections, UUID} +import java.util.concurrent.atomic.AtomicIntegerArray import scala.collection.JavaConverters._ +import com.google.common.base.Preconditions.checkState import com.google.protobuf.ByteString import org.roaringbitmap.RoaringBitmap @@ -274,11 
+276,25 @@ object ControlMessages extends Logging { attemptId: Int, numMappers: Int, partitionId: Int, - failedBatchSet: util.Map[String, LocationPushFailedBatches]) + failedBatchSet: util.Map[String, LocationPushFailedBatches], + numPartitions: Int, + crc32PerPartition: Array[Int], + bytesWrittenPerPartition: Array[Long]) + extends MasterMessage + + case class ReadReducerPartitionEnd( + shuffleId: Int, + partitionId: Int, + startMapIndex: Int, + endMapIndex: Int, + crc32: Int, + bytesWritten: Long) extends MasterMessage case class MapperEndResponse(status: StatusCode) extends MasterMessage + case class ReadReducerPartitionEndResponse(status: StatusCode) extends MasterMessage + case class GetReducerFileGroup( shuffleId: Int, isSegmentGranularityVisible: Boolean, @@ -606,6 +622,12 @@ object ControlMessages extends Logging { case pb: PbReviseLostShufflesResponse => new TransportMessage(MessageType.REVISE_LOST_SHUFFLES_RESPONSE, pb.toByteArray) + case pb: PbReadReducerPartitionEnd => + new TransportMessage(MessageType.READ_REDUCER_PARTITION_END, pb.toByteArray) + + case pb: PbReadReducerPartitionEndResponse => + new TransportMessage(MessageType.READ_REDUCER_PARTITION_END_RESPONSE, pb.toByteArray) + case pb: PbReportBarrierStageAttemptFailure => new TransportMessage(MessageType.REPORT_BARRIER_STAGE_ATTEMPT_FAILURE, pb.toByteArray) @@ -739,7 +761,16 @@ object ControlMessages extends Logging { case pb: PbChangeLocationResponse => new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray) - case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) => + case MapperEnd( + shuffleId, + mapId, + attemptId, + numMappers, + partitionId, + pushFailedBatch, + numPartitions, + crc32PerPartition, + bytesWrittenPerPartition) => val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) => val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v) (k, resultValue) @@ -751,6 +782,10 @@ object ControlMessages extends Logging { 
.setNumMappers(numMappers) .setPartitionId(partitionId) .putAllPushFailureBatches(pushFailedMap) + .setNumPartitions(numPartitions) + .addAllCrc32PerPartition(crc32PerPartition.map(Integer.valueOf).toSeq.asJava) + .addAllBytesWrittenPerPartition(bytesWrittenPerPartition.map( + java.lang.Long.valueOf).toSeq.asJava) .build().toByteArray new TransportMessage(MessageType.MAPPER_END, payload) @@ -1176,6 +1211,15 @@ object ControlMessages extends Logging { case MAPPER_END_VALUE => val pbMapperEnd = PbMapperEnd.parseFrom(message.getPayload) + val partitionCount = pbMapperEnd.getCrc32PerPartitionCount + checkState(partitionCount == pbMapperEnd.getBytesWrittenPerPartitionCount) + val crc32Array = new Array[Int](partitionCount) + val bytesWrittenPerPartitionArray = + new Array[Long](pbMapperEnd.getBytesWrittenPerPartitionCount) + for (i <- 0 until partitionCount) { + crc32Array(i) = pbMapperEnd.getCrc32PerPartition(i) + bytesWrittenPerPartitionArray(i) = pbMapperEnd.getBytesWrittenPerPartition(i) + } MapperEnd( pbMapperEnd.getShuffleId, pbMapperEnd.getMapId, @@ -1185,7 +1229,23 @@ pbMapperEnd.getPushFailureBatchesMap.asScala.map { case (partitionId, pushFailedBatchSet) => (partitionId, PbSerDeUtils.fromPbLocationPushFailedBatches(pushFailedBatchSet)) - }.toMap.asJava) + }.toMap.asJava, + pbMapperEnd.getNumPartitions, + crc32Array, + bytesWrittenPerPartitionArray) + + case READ_REDUCER_PARTITION_END_VALUE => + val pbReadReducerPartitionEnd = PbReadReducerPartitionEnd.parseFrom(message.getPayload) + ReadReducerPartitionEnd( + pbReadReducerPartitionEnd.getShuffleId, + pbReadReducerPartitionEnd.getPartitionId, + pbReadReducerPartitionEnd.getStartMapIndex, + pbReadReducerPartitionEnd.getEndMapIndex, + pbReadReducerPartitionEnd.getCrc32, + pbReadReducerPartitionEnd.getBytesWritten) + + case READ_REDUCER_PARTITION_END_RESPONSE_VALUE => + PbReadReducerPartitionEndResponse.parseFrom(message.getPayload) case MAPPER_END_RESPONSE_VALUE => val 
pbMapperEndResponse = PbMapperEndResponse.parseFrom(message.getPayload) diff --git a/common/src/test/java/org/apache/celeborn/common/CelebornCRC32Test.java b/common/src/test/java/org/apache/celeborn/common/CelebornCRC32Test.java new file mode 100644 index 000000000..296e37d9c --- /dev/null +++ b/common/src/test/java/org/apache/celeborn/common/CelebornCRC32Test.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; + +import org.junit.Test; + +// Test data is generated from https://crccalc.com/ using "CRC-32/ISO-HDLC". 
+public class CelebornCRC32Test { + + @Test + public void testCompute() { + byte[] data = "test".getBytes(); + int checksum = CelebornCRC32.compute(data); + assertEquals(3632233996L, checksum & 0xFFFFFFFFL); + } + + @Test + public void testComputeWithOffset() { + byte[] data = "testdata".getBytes(); + int checksum = CelebornCRC32.compute(data, 4, 4); + assertEquals(2918445923L, checksum & 0xFFFFFFFFL); + } + + @Test + public void testCombine() { + int first = 123456789; + int second = 987654321; + int combined = CelebornCRC32.combine(first, second); + assertNotEquals(first, combined); + assertNotEquals(second, combined); + } + + @Test + public void testAddChecksum() { + CelebornCRC32 crc = new CelebornCRC32(); + crc.addChecksum(123456789); + assertEquals(123456789, crc.get()); + } + + @Test + public void testAddDataWithOffset() { + CelebornCRC32 crc = new CelebornCRC32(); + byte[] data = "testdata".getBytes(); + crc.addData(data, 4, 4); + assertEquals(2918445923L, crc.get() & 0xFFFFFFFFL); + } + + @Test + public void testToString() { + CelebornCRC32 crc = new CelebornCRC32(123456789); + assertEquals("CelebornCRC32{current=123456789}", crc.toString()); + } +} diff --git a/common/src/test/java/org/apache/celeborn/common/CommitMetadataTest.java b/common/src/test/java/org/apache/celeborn/common/CommitMetadataTest.java new file mode 100644 index 000000000..3451d386f --- /dev/null +++ b/common/src/test/java/org/apache/celeborn/common/CommitMetadataTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.common; + +import static org.junit.Assert.assertEquals; + +import org.junit.Assert; +import org.junit.Test; + +public class CommitMetadataTest { + + @Test + public void testAddDataWithOffsetAndLength() { + CommitMetadata metadata = new CommitMetadata(); + byte[] data = "testdata".getBytes(); + metadata.addDataWithOffsetAndLength(data, 4, 4); + assertEquals(4, metadata.getBytes()); + assertEquals(CelebornCRC32.compute(data, 4, 4), metadata.getChecksum()); + } + + @Test + public void testAddCommitData() { + CommitMetadata metadata1 = new CommitMetadata(); + byte[] data1 = "test".getBytes(); + metadata1.addDataWithOffsetAndLength(data1, 0, data1.length); + + CommitMetadata metadata2 = new CommitMetadata(); + byte[] data2 = "data".getBytes(); + metadata2.addDataWithOffsetAndLength(data2, 0, data2.length); + + metadata1.addCommitData(metadata2); + + // Verify that the metadata1 now contains the combined data and checksum of data1 and data2 + assertEquals(data1.length + data2.length, metadata1.getBytes()); + assertEquals( + CelebornCRC32.combine(CelebornCRC32.compute(data1), CelebornCRC32.compute(data2)), + metadata1.getChecksum()); + } + + @Test + public void testCheckCommitMetadata() { + CommitMetadata expected = new CommitMetadata(CelebornCRC32.compute("testdata".getBytes()), 8); + CommitMetadata actualMatching = + new CommitMetadata(CelebornCRC32.compute("testdata".getBytes()), 8); + CommitMetadata actualNonMatchingBytesOnly = + new CommitMetadata(CelebornCRC32.compute("testdata".getBytes()), 12); + CommitMetadata 
actualNonMatchingChecksumOnly = + new CommitMetadata(CelebornCRC32.compute("foo".getBytes()), 8); + CommitMetadata actualNonMatching = + new CommitMetadata(CelebornCRC32.compute("bar".getBytes()), 16); + + Assert.assertTrue(CommitMetadata.checkCommitMetadata(expected, actualMatching)); + Assert.assertFalse(CommitMetadata.checkCommitMetadata(expected, actualNonMatchingBytesOnly)); + Assert.assertFalse(CommitMetadata.checkCommitMetadata(expected, actualNonMatchingChecksumOnly)); + Assert.assertFalse(CommitMetadata.checkCommitMetadata(expected, actualNonMatching)); + } +} diff --git a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala index a9097e197..7cadaf07e 100644 --- a/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/util/UtilsSuite.scala @@ -20,6 +20,9 @@ package org.apache.celeborn.common.util import java.util import java.util.Collections +import org.scalatest.matchers.must.Matchers.contain +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper + import org.apache.celeborn.CelebornFunSuite import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.client.{MasterEndpointResolver, StaticMasterEndpointResolver} @@ -145,10 +148,19 @@ class UtilsSuite extends CelebornFunSuite { } test("MapperEnd class convert with pb") { - val mapperEnd = MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap()) + val mapperEnd = + MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap(), 1, Array.emptyIntArray, Array.emptyLongArray) val mapperEndTrans = Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd] - assert(mapperEnd == mapperEndTrans) + assert(mapperEnd.shuffleId == mapperEndTrans.shuffleId) + assert(mapperEnd.mapId == mapperEndTrans.mapId) + assert(mapperEnd.attemptId == mapperEndTrans.attemptId) + assert(mapperEnd.numMappers == 
mapperEndTrans.numMappers) + assert(mapperEnd.partitionId == mapperEndTrans.partitionId) + assert(mapperEnd.failedBatchSet == mapperEndTrans.failedBatchSet) + assert(mapperEnd.numPartitions == mapperEndTrans.numPartitions) + mapperEnd.crc32PerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.crc32PerPartition + mapperEnd.bytesWrittenPerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.bytesWrittenPerPartition } test("validate HDFS compatible fs path") { diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 5c4cbf006..96dc7945d 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -105,6 +105,7 @@ license: | | celeborn.client.shuffle.dynamicResourceEnabled | false | false | When enabled, the ChangePartitionManager will obtain candidate workers from the availableWorkers pool during heartbeats when worker resource change. | 0.6.0 | | | celeborn.client.shuffle.dynamicResourceFactor | 0.5 | false | The ChangePartitionManager will check whether (unavailable workers / shuffle allocated workers) is more than the factor before obtaining candidate workers from the requestSlots RPC response when `celeborn.client.shuffle.dynamicResourceEnabled` set true | 0.6.0 | | | celeborn.client.shuffle.expired.checkInterval | 60s | false | Interval for client to check expired shuffles. | 0.3.0 | celeborn.shuffle.expired.checkInterval | +| celeborn.client.shuffle.integrityCheck.enabled | false | false | When `true`, enables end-to-end integrity checks for Spark workloads. | 0.6.1 | | | celeborn.client.shuffle.manager.port | 0 | false | Port used by the LifecycleManager on the Driver. | 0.3.0 | celeborn.shuffle.manager.port | | celeborn.client.shuffle.partition.type | REDUCE | false | Type of shuffle's partition. | 0.3.0 | celeborn.shuffle.partition.type | | celeborn.client.shuffle.partitionSplit.mode | SOFT | false | soft: the shuffle file size might be larger than split threshold. 
hard: the shuffle file size will be limited to split threshold. | 0.3.0 | celeborn.shuffle.partitionSplit.mode | diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala index 39b54b5f6..fa505e29f 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala @@ -70,8 +70,9 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC res.workerResource, updateEpoch = false) - lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false) - 0 until 10 foreach { partitionId => + val numPartitions = 10 + lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions) + 0 until numPartitions foreach { partitionId => lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId) } @@ -126,8 +127,9 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC res.workerResource, updateEpoch = false) - lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false) - 0 until 10 foreach { partitionId => + val numPartitions = 10 + lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions) + 0 until numPartitions foreach { partitionId => lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId) } @@ -196,8 +198,9 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC res.workerResource, updateEpoch = false) - lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false) - 0 until 1000 foreach { partitionId => + val numPartitions = 1000 + lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions) + 0 until numPartitions foreach { partitionId => 
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId) } @@ -256,13 +259,14 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC res.workerResource, updateEpoch = false) - lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false) + val numPartitions = 3 + lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions) val buffer = "hello world".getBytes(StandardCharsets.UTF_8) var bufferLength = -1 - 0 until 3 foreach { partitionId => + 0 until numPartitions foreach { partitionId => bufferLength = shuffleClient.pushData(shuffleId, 0, 0, partitionId, buffer, 0, buffer.length, 1, 3) lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala index ee9317a5e..1626f2226 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala @@ -157,7 +157,7 @@ class LifecycleManagerReserveSlotsSuite extends AnyFunSuite // push merged data, we expect that partition(0) will be split, while partition(1) will not be split shuffleClient1.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID) - shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM) + shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM) // partition(1) will not be split assert(partitionLocationMap1.get(partitions(1)).getEpoch == 0) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornIntegrityCheckSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornIntegrityCheckSuite.scala new file mode 100644 index 000000000..ef382d75c --- /dev/null +++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/CelebornIntegrityCheckSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.tests.spark + +import org.apache.spark._ +import org.apache.spark.shuffle.celeborn.{SparkUtils, TestCelebornShuffleManager} +import org.apache.spark.sql.SparkSession +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite +import org.slf4j.{Logger, LoggerFactory} + +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.protocol.{CompressionCodec, ShuffleMode} + +class CelebornIntegrityCheckSuite extends AnyFunSuite + with SparkTestBase + with BeforeAndAfterEach { + + private val logger = LoggerFactory.getLogger(classOf[CelebornIntegrityCheckSuite]) + + override def beforeEach(): Unit = { + ShuffleClient.reset() + } + + override def afterEach(): Unit = { + System.gc() + } + + test("celeborn spark integration test - corrupted data, no integrity check - app succeeds but data is different") { + if (Spark3OrNewer) { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") + val sparkSession = 
SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .config( + s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", + CompressionCodec.NONE.toString) + .getOrCreate() + + // Introduce Data Corruption in single bit in 1 partition location file + val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) + val hook = new ShuffleReaderGetHookForCorruptedData(celebornConf, workerDirs) + TestCelebornShuffleManager.registerReaderGetHook(hook) + + val value = Range(1, 10000).mkString(",") + val tuples = sparkSession.sparkContext.parallelize(1 to 1000, 2) + .map { i => (i, value) }.groupByKey(16).collect() + + assert(tuples.length == 1000) + + try { + for (elem <- tuples) { + assert(elem._2.mkString(",").equals(value)) + } + } catch { + case e: Throwable => + e.getMessage.contains("elem._2.mkString(\",\").equals(value) was false") + } finally { + sparkSession.stop() + } + } + } + + test("celeborn spark integration test - corrupted data, integrity checks enabled, no stage rerun - app fails") { + if (Spark3OrNewer) { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") + val sparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config( + s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", + CompressionCodec.NONE.toString) + .config(s"spark.${CelebornConf.CLIENT_STAGE_RERUN_ENABLED.key}", "false") + .config("spark.celeborn.client.shuffle.integrityCheck.enabled", "true") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .getOrCreate() + + // Introduce Data Corruption in single bit in 1 partition location file + val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) + val hook = new 
ShuffleReaderGetHookForCorruptedData(celebornConf, workerDirs) + TestCelebornShuffleManager.registerReaderGetHook(hook) + + try { + val value = Range(1, 10000).mkString(",") + val tuples = sparkSession.sparkContext.parallelize(1 to 1000, 2) + .map { i => (i, value) }.groupByKey(16).collect() + fail("App should abort prior to this step and throw an exception") + } catch { + // verify that the app fails + case e: Throwable => { + logger.error("Expected exception, logging the full exception", e) + assert(e.getMessage.contains("Job aborted")) + assert(e.getMessage.contains("CommitMetadata mismatch")) + } + } finally { + sparkSession.stop() + } + } + } + + test("celeborn spark integration test - corrupted data, integrity checks enabled, stage rerun enabled - app succeeds") { + if (Spark3OrNewer) { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]") + val sparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, ShuffleMode.HASH)) + .config("spark.sql.shuffle.partitions", 2) + .config( + s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", + CompressionCodec.NONE.toString) + .config(s"spark.${CelebornConf.CLIENT_STAGE_RERUN_ENABLED.key}", "true") + .config("spark.celeborn.client.shuffle.integrityCheck.enabled", "true") + .config( + "spark.shuffle.manager", + "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager") + .getOrCreate() + + // Introduce Data Corruption in single bit in 1 partition location file + val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf) + val hook = new ShuffleReaderGetHookForCorruptedData(celebornConf, workerDirs) + TestCelebornShuffleManager.registerReaderGetHook(hook) + + try { + val value = Range(1, 10000).mkString(",") + val tuples = sparkSession.sparkContext.parallelize(1 to 1000, 2) + .map { i => (i, value) }.groupByKey(16).collect() + + // verify result + assert(tuples.length == 1000) + for (elem <- tuples) { + assert(elem._2.mkString(",").equals(value)) + } + 
} finally { + sparkSession.stop() + } + } + } +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleReaderGetHookForCorruptedData.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleReaderGetHookForCorruptedData.scala new file mode 100644 index 000000000..99ab8e3aa --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/ShuffleReaderGetHookForCorruptedData.scala @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.celeborn.tests.spark + +import java.io.{File, RandomAccessFile} +import java.util.Random +import java.util.concurrent.atomic.AtomicBoolean + +import scala.util.control.Breaks.{break, breakable} + +import org.apache.spark.TaskContext +import org.apache.spark.shuffle.ShuffleHandle +import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkCommonUtils, SparkUtils} +import org.slf4j.LoggerFactory + +import org.apache.celeborn.client.ShuffleClient +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.unsafe.Platform + +class ShuffleReaderGetHookForCorruptedData( + conf: CelebornConf, + workerDirs: Seq[String], + shuffleIdToBeModified: Seq[Int] = Seq(), + triggerStageId: Option[Int] = None) + extends ShuffleManagerHook { + + private val logger = LoggerFactory.getLogger(classOf[ShuffleReaderGetHookForCorruptedData]) + var executed: AtomicBoolean = new AtomicBoolean(false) + val lock = new Object + var corruptedCount = 0 + val random = new Random() + + private def modifyDataFileWithSingleBitFlip(appUniqueId: String, celebornShuffleId: Int): Unit = { + if (corruptedCount > 0) { + return + } + + val minFileSize = 16 // Minimum file size to be considered for corruption + + // Find all potential data files across all worker directories + val dataFiles = + workerDirs.flatMap(dir => { + val shuffleDir = + new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId") + if (shuffleDir.exists()) { + shuffleDir.listFiles().filter(_.isFile).filter(_.length() >= minFileSize) + } else { + Array.empty[File] + } + }) + + if (dataFiles.isEmpty) { + logger.error( + s"No suitable data files found for appUniqueId=$appUniqueId, shuffleId=$celebornShuffleId") + return + } + + // Sort files by size (descending) to prioritize larger files that are more likely to have data + val sortedFiles = dataFiles.sortBy(-_.length()) + logger.info(s"Found ${sortedFiles.length} data files for corruption 
testing.") + + // Try to corrupt files one by one until successful + breakable { + for (file <- sortedFiles) { + if (tryCorruptFile(file)) { + corruptedCount = 1 + break + } + } + } + // If we couldn't corrupt any file through the normal process, use the fallback method + if (corruptedCount == 0 && sortedFiles.nonEmpty) { + logger.warn( + "Could not find a valid data section in any file. Using safer fallback corruption method.") + val fileToCorrupt = sortedFiles.head // Take the largest file for fallback + if (fallbackCorruption(fileToCorrupt)) { + corruptedCount = 1 + } + } + } + + /** + * Try to corrupt a specific file by finding and corrupting a valid data section. + * @return true if corruption was successful, false otherwise + */ + private def tryCorruptFile(file: File): Boolean = { + + val random = new Random() + logger.info(s"Attempting to corrupt file: ${file.getPath()}, size: ${file.length()} bytes") + val raf = new RandomAccessFile(file, "rw") + try { + val fileSize = raf.length() + + // Find a valid data section in the file + var position: Long = 0 + val headerSize = 16 + + breakable { + while (position + headerSize <= fileSize) { + raf.seek(position) + val sizeBuf = new Array[Byte](headerSize) + raf.readFully(sizeBuf) + + // Parse header + val mapId: Int = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET) + val attemptId: Int = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 4) + val batchId: Int = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8) + val dataSize: Int = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12) + + logger.info(s"Found header at position $position: mapId=$mapId, attemptId=$attemptId, " + + s"batchId=$batchId, dataSize=$dataSize") + + // Validate dataSize - should be positive and within file bounds + if (dataSize > 0 && position + headerSize + dataSize <= fileSize) { + val dataStartPosition = position + headerSize + + // Choose a random position within the data section + val offsetInData = 
random.nextInt(dataSize) + val corruptPosition = dataStartPosition + offsetInData + + // Read the byte at the position we want to corrupt + raf.seek(corruptPosition) + val originalByte = raf.readByte() + + // Flip one bit in the byte (avoid the sign bit for numeric values) + val bitToFlip = 1 << random.nextInt(7) // bits 0-6 only, avoiding the sign bit + val corruptedByte = (originalByte ^ bitToFlip).toByte + + // Write back the corrupted byte + raf.seek(corruptPosition) + raf.writeByte(corruptedByte) + + logger.info(s"Successfully corrupted byte in file ${file.getName()} at position " + + s"$corruptPosition: ${originalByte & 0xFF} -> ${corruptedByte & 0xFF} (flipped bit $bitToFlip)") + + return true // Corruption successful + } + + // Skip to next record: header size + data size (even if data size is 0) + position += headerSize + Math.max(0, dataSize) + } + } + + false // No valid data section found + } catch { + case e: Exception => + logger.error(s"Error while attempting to corrupt file ${file.getPath()}", e) + false + } finally { + raf.close() + } + } + + /** + * Fallback method for corruption when we can't find a valid data section. + * This method simply corrupts a byte in the middle of the file. 
+ */ + private def fallbackCorruption(file: File): Boolean = { + val raf = new RandomAccessFile(file, "rw") + try { + val fileSize = raf.length() + + // Skip first 64 bytes to avoid headers, and corrupt somewhere in the middle of the file + val safeStartPosition = Math.min(64, fileSize / 4) + val safeEndPosition = Math.max(safeStartPosition + 1, fileSize - 16) + + // Ensure we have a valid range + if (safeEndPosition <= safeStartPosition) { + logger.error(s"File ${file.getName()} too small for safe corruption: $fileSize bytes") + return false + } + + val corruptPosition = safeStartPosition + + random.nextInt((safeEndPosition - safeStartPosition).toInt) + + raf.seek(corruptPosition) + val originalByte = raf.readByte() + + val bitToFlip = 1 << random.nextInt(7) + val corruptedByte = (originalByte ^ bitToFlip).toByte + + raf.seek(corruptPosition) + raf.writeByte(corruptedByte) + + logger.info(s"Used fallback corruption approach on file ${file.getName()}. " + + s"Corrupted byte at position $corruptPosition: ${originalByte & 0xFF} -> ${corruptedByte & 0xFF}") + + true + } catch { + case e: Exception => + logger.error(s"Error during fallback corruption of file ${file.getPath()}", e) + false + } finally { + raf.close() + } + } + + override def exec( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): Unit = { + if (executed.get()) { + return + } + lock.synchronized { + handle match { + case h: CelebornShuffleHandle[_, _, _] => { + val appUniqueId = h.appUniqueId + val shuffleClient = ShuffleClient.get( + h.appUniqueId, + h.lifecycleManagerHost, + h.lifecycleManagerPort, + conf, + h.userIdentifier, + h.extension) + val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false) + val appShuffleIdentifier = + SparkCommonUtils.encodeAppShuffleIdentifier(handle.shuffleId, context) + val Array(_, stageId, _) = appShuffleIdentifier.split('-') + if (triggerStageId.isEmpty || triggerStageId.get == stageId.toInt) { + 
if (shuffleIdToBeModified.isEmpty) { + modifyDataFileWithSingleBitFlip(appUniqueId, celebornShuffleId) + } else { + shuffleIdToBeModified.foreach { shuffleId => + modifyDataFileWithSingleBitFlip(appUniqueId, shuffleId) + } + } + executed.set(true) + } + } + case x => throw new RuntimeException(s"unexpected, only support RssShuffleHandle here," + + s" but get ${x.getClass.getCanonicalName}") + } + } + } +} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala index f1a5d5e77..854bd0e27 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SkewJoinSuite.scala @@ -49,70 +49,74 @@ class SkewJoinSuite extends AnyFunSuite CompressionCodec.values.foreach { codec => Seq(false, true).foreach { enabled => - test( - s"celeborn spark integration test - skew join - with $codec - with client skew $enabled") { - val sparkConf = new SparkConf().setAppName("celeborn-demo") - .setMaster("local[2]") - .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") - .set("spark.sql.adaptive.skewJoin.enabled", "true") - .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB") - .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB") - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1") - .set(SQLConf.PARQUET_COMPRESSION.key, "gzip") - .set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name) - .set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true") - .set( - s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}", - s"$enabled") + Seq(false, true).foreach { integrityChecksEnabled => + test( + s"celeborn spark integration test - skew join - with $codec - with client skew $enabled - with integrity 
checks $integrityChecksEnabled") { + val sparkConf = new SparkConf().setAppName("celeborn-demo") + .setMaster("local[2]") + .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set("spark.sql.adaptive.skewJoin.enabled", "true") + .set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB") + .set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1") + .set(SQLConf.PARQUET_COMPRESSION.key, "gzip") + .set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name) + .set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true") + .set( + s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}", + s"$enabled") + .set( + s"spark.${CelebornConf.CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED.key}", + s"$integrityChecksEnabled") - enableCeleborn(sparkConf) + enableCeleborn(sparkConf) - val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() - if (sparkSession.version.startsWith("3")) { - import sparkSession.implicits._ - val df = sparkSession.sparkContext.parallelize(1 to 120000, 8) - .map(i => { - val random = new Random() - val oriKey = random.nextInt(64) - val key = if (oriKey < 32) 1 else oriKey - val fas = random.nextInt(1200000) - val fa = Range(fas, fas + 100).mkString(",") - val fbs = random.nextInt(1200000) - val fb = Range(fbs, fbs + 100).mkString(",") - val fcs = random.nextInt(1200000) - val fc = Range(fcs, fcs + 100).mkString(",") - val fds = random.nextInt(1200000) - val fd = Range(fds, fds + 100).mkString(",") + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + if (sparkSession.version.startsWith("3")) { + import sparkSession.implicits._ + val df = sparkSession.sparkContext.parallelize(1 to 120000, 8) + .map(i => { + val random = new Random() + val oriKey = random.nextInt(64) + val key = if (oriKey < 32) 1 else oriKey 
+ val fas = random.nextInt(1200000) + val fa = Range(fas, fas + 100).mkString(",") + val fbs = random.nextInt(1200000) + val fb = Range(fbs, fbs + 100).mkString(",") + val fcs = random.nextInt(1200000) + val fc = Range(fcs, fcs + 100).mkString(",") + val fds = random.nextInt(1200000) + val fd = Range(fds, fds + 100).mkString(",") - (key, fa, fb, fc, fd) - }) - .toDF("fa", "f1", "f2", "f3", "f4") - df.createOrReplaceTempView("view1") - val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8) - .map(i => { - val random = new Random() - val oriKey = random.nextInt(64) - val key = if (oriKey < 32) 1 else oriKey - val fas = random.nextInt(1200000) - val fa = Range(fas, fas + 100).mkString(",") - val fbs = random.nextInt(1200000) - val fb = Range(fbs, fbs + 100).mkString(",") - val fcs = random.nextInt(1200000) - val fc = Range(fcs, fcs + 100).mkString(",") - val fds = random.nextInt(1200000) - val fd = Range(fds, fds + 100).mkString(",") - (key, fa, fb, fc, fd) - }) - .toDF("fb", "f6", "f7", "f8", "f9") - df2.createOrReplaceTempView("view2") - sparkSession.sql("drop table if exists fres") - sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ") - sparkSession.sql("drop table fres") - sparkSession.stop() + (key, fa, fb, fc, fd) + }) + .toDF("fa", "f1", "f2", "f3", "f4") + df.createOrReplaceTempView("view1") + val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8) + .map(i => { + val random = new Random() + val oriKey = random.nextInt(64) + val key = if (oriKey < 32) 1 else oriKey + val fas = random.nextInt(1200000) + val fa = Range(fas, fas + 100).mkString(",") + val fbs = random.nextInt(1200000) + val fb = Range(fbs, fbs + 100).mkString(",") + val fcs = random.nextInt(1200000) + val fc = Range(fcs, fcs + 100).mkString(",") + val fds = random.nextInt(1200000) + val fd = Range(fds, fds + 100).mkString(",") + (key, fa, fb, fc, fd) + }) + .toDF("fb", "f6", "f7", "f8", "f9") + 
df2.createOrReplaceTempView("view2") + sparkSession.sql("drop table if exists fres") + sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ") + sparkSession.sql("drop table fres") + sparkSession.stop() + } } } } } - } diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaReadCppWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaReadCppWriteTestBase.scala index a57ffebe8..997aa1f12 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaReadCppWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/JavaReadCppWriteTestBase.scala @@ -109,7 +109,7 @@ trait JavaReadCppWriteTestBase extends AnyFunSuite } } shuffleClient.pushMergedData(shuffleId, mapId, attemptId) - shuffleClient.mapperEnd(shuffleId, mapId, attemptId, numMappers) + shuffleClient.mapperEnd(shuffleId, mapId, attemptId, numMappers, numPartitions) } // Launch cpp reader to read data, calculate result and write to specific result file. 
diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala index 37f78efc4..24f745857 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/LocalReadByChunkOffsetsTest.scala @@ -116,7 +116,7 @@ class LocalReadByChunkOffsetsTest extends AnyFunSuite shuffleClient.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient.mapperEnd(1, 0, 0, 1) + shuffleClient.mapperEnd(1, 0, 0, 1, 0) val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala index cdd0e758e..f095bb575 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/PushMergedDataSplitSuite.scala @@ -153,7 +153,7 @@ class PushMergedDataSplitSuite extends AnyFunSuite // push merged data, we expect that partition(0) will be split, while partition(1) will not be split shuffleClient.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID) - shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM) + shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM) assert( partitionLocationMap.get(partitions(1)).getEpoch == 0 ) // means partition(1) will not be split diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala index 1cb0a9fdf..dec30f8c6 100644 --- 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala @@ -98,7 +98,7 @@ trait ReadWriteTestBase extends AnyFunSuite shuffleClient.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient.mapperEnd(1, 0, 0, 1) + shuffleClient.mapperEnd(1, 0, 0, 1, 1) val metricsCallback = new MetricsCallback { override def incBytesRead(bytesWritten: Long): Unit = {} diff --git a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala index afd7b589a..063838067 100644 --- a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala +++ b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestWithFailures.scala @@ -92,7 +92,7 @@ class ReadWriteTestWithFailures extends AnyFunSuite shuffleClient.pushMergedData(1, 0, 0) Thread.sleep(1000) - shuffleClient.mapperEnd(1, 0, 0, 1) + shuffleClient.mapperEnd(1, 0, 0, 1, 1) var duplicateBytesRead = new AtomicLong(0) val metricsCallback = new MetricsCallback {