[CELEBORN-894] End to End Integrity Checks

### What changes were proposed in this pull request?
Design doc - https://docs.google.com/document/d/1YqK0kua-5rMufJw57kEIrHHGbLnAF9iXM5GdDweMzzg/edit?tab=t.0#heading=h.n5ldma432qnd

- End to End integrity checks provide additional confidence that Celeborn is producing complete as well as correct data
- The checks are hidden behind a client side config that is false by default. Provides users optionality to enable these if required on a per app basis
- Only compatible with Spark at the moment
- No support for Flink (can be considered in future)
- No support for Columnar Shuffle (can be considered in future)

Writer
- Whenever a mapper completes, it reports crc32 and bytes written on a per partition basis to the driver

Driver
- Driver aggregates the mapper reports - and computes aggregated CRC32 and bytes written on per partitionID basis

Reader
- Each CelebornInputStream will report (int shuffleId, int partitionId, int startMapIndex, int endMapIndex, int crc32, long bytes) to the driver when it has finished reading all data on the stream
- On every report
  - Driver will aggregate the CRC32 and bytesRead for the partitionID
  - Driver will aggregate the map ranges to determine when all sub-partitions for the partitionID have been read
  - It will then compare the aggregated CRC32 and bytes read with the expected CRC32 and bytes written for the partition
  - There is special handling for the skew-handling-without-map-range-split scenario as well
  - In this case, we report the number of sub-partitions and index of the sub-partition instead of startMapIndex and endMapIndex

There is separate handling for skew handling with and without map range split

As a follow-up, I will open another PR that hardens the checks and adds bookkeeping to verify that every CelebornInputStream performs the required checks

### Why are the changes needed?
https://issues.apache.org/jira/browse/CELEBORN-894

Note: I am putting up this PR even though some tests are failing, since I want to get some early feedback on the code changes.

### Does this PR introduce _any_ user-facing change?
Not sure how to answer this. A new client side config is available to enable the checks if required

### How was this patch tested?
Unit tests + Integration tests

Closes #3261 from gauravkm/gaurav/e2e_checks_v3.

Lead-authored-by: Gaurav Mittal <gaurav@stripe.com>
Co-authored-by: Gaurav Mittal <gauravkm@gmail.com>
Co-authored-by: Fei Wang <cn.feiwang@gmail.com>
Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
This commit is contained in:
Gaurav Mittal 2025-06-28 09:19:57 +08:00 committed by Shuang
parent 7a0eee332a
commit cde33d953b
47 changed files with 2375 additions and 143 deletions

View File

@ -197,7 +197,7 @@ public class RemoteShuffleOutputGate {
/** Indicates the writing/spilling is finished. */
public void finish() throws InterruptedException, IOException {
flinkShuffleClient.mapPartitionMapperEnd(
shuffleId, mapId, attemptId, numMappers, partitionLocation.getId());
shuffleId, mapId, attemptId, numMappers, numSubs, partitionLocation.getId());
}
/** Close the transportation gate. */

View File

@ -107,7 +107,7 @@ public class RemoteShuffleOutputGateSuiteJ {
doNothing()
.when(remoteShuffleOutputGate.flinkShuffleClient)
.mapperEnd(anyInt(), anyInt(), anyInt(), anyInt());
.mapperEnd(anyInt(), anyInt(), anyInt(), anyInt(), anyInt());
remoteShuffleOutputGate.finish();
doNothing().when(remoteShuffleOutputGate.flinkShuffleClient).shutdown();

View File

@ -234,7 +234,12 @@ public class CelebornTierProducerAgent implements TierProducerAgent {
try {
if (hasRegisteredShuffle && partitionLocation != null) {
flinkShuffleClient.mapPartitionMapperEnd(
shuffleId, mapId, attemptId, numPartitions, partitionLocation.getId());
shuffleId,
mapId,
attemptId,
numPartitions,
numSubPartitions,
partitionLocation.getId());
}
} catch (Exception e) {
Utils.rethrowAsRuntimeException(e);

View File

@ -234,7 +234,12 @@ public class CelebornTierProducerAgent implements TierProducerAgent {
try {
if (hasRegisteredShuffle && partitionLocation != null) {
flinkShuffleClient.mapPartitionMapperEnd(
shuffleId, mapId, attemptId, numPartitions, partitionLocation.getId());
shuffleId,
mapId,
attemptId,
numPartitions,
numSubPartitions,
partitionLocation.getId());
}
} catch (Exception e) {
Utils.rethrowAsRuntimeException(e);

View File

@ -316,7 +316,7 @@ public class CelebornSortBasedPusher<K, V> extends OutputStream {
mapId,
attempt,
numMappers);
shuffleClient.mapperEnd(0, mapId, attempt, numMappers);
shuffleClient.mapperEnd(0, mapId, attempt, numMappers, numReducers);
} catch (IOException e) {
exception.compareAndSet(null, e);
}

View File

@ -369,7 +369,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
sendOffsets = null;
long waitStartTime = System.nanoTime();
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();

View File

@ -316,7 +316,7 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
updateMapStatus();
long waitStartTime = System.nanoTime();
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
}

View File

@ -378,7 +378,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
updateRecordsWrittenMetrics();
long waitStartTime = System.nanoTime();
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();

View File

@ -381,7 +381,7 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
writeMetrics.incRecordsWritten(tmpRecordsWritten);
long waitStartTime = System.nanoTime();
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
}

View File

@ -120,7 +120,7 @@ public class CelebornTezWriter {
try {
dataPusher.waitOnTermination();
shuffleClient.pushMergedData(shuffleId, mapId, attemptNumber);
shuffleClient.mapperEnd(shuffleId, mapId, attemptNumber, numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, attemptNumber, numMappers, numPartitions);
} catch (InterruptedException e) {
throw new IOInterruptedException(e);
}

View File

@ -115,11 +115,17 @@ public class DummyShuffleClient extends ShuffleClient {
public void pushMergedData(int shuffleId, int mapId, int attemptId) {}
@Override
public void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers) {}
public void mapperEnd(
int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions) {}
@Override
public void readReducerPartitionEnd(
int shuffleId, int partitionId, int startMapIndex, int endMapIndex, int crc32, long bytes)
throws IOException {}
@Override
public void mapPartitionMapperEnd(
int shuffleId, int mapId, int attemptId, int numMappers, int partitionId)
int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions, int partitionId)
throws IOException {}
@Override

View File

@ -203,13 +203,19 @@ public abstract class ShuffleClient {
public abstract void pushMergedData(int shuffleId, int mapId, int attemptId) throws IOException;
// Report partition locations written by the completed map task of ReducePartition Shuffle Type
public abstract void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers)
// Report partition locations written by the completed map task of ReducePartition Shuffle Type.
public abstract void mapperEnd(
int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions)
throws IOException;
// Report partition locations written by the completed map task of MapPartition Shuffle Type
public abstract void readReducerPartitionEnd(
int shuffleId, int partitionId, int startMapIndex, int endMapIndex, int crc32, long bytes)
throws IOException;
// Report partition locations written by the completed map task of MapPartition Shuffle Type.
public abstract void mapPartitionMapperEnd(
int shuffleId, int mapId, int attemptId, int numMappers, int partitionId) throws IOException;
int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions, int partitionId)
throws IOException;
// Cleanup states of the map task
public abstract void cleanup(int shuffleId, int mapId, int attemptId);

View File

@ -50,10 +50,7 @@ import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.metrics.source.Role;
import org.apache.celeborn.common.network.TransportContext;
import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientBootstrap;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.client.*;
import org.apache.celeborn.common.network.protocol.*;
import org.apache.celeborn.common.network.protocol.SerdeVersion;
import org.apache.celeborn.common.network.sasl.SaslClientBootstrap;
@ -122,6 +119,8 @@ public class ShuffleClientImpl extends ShuffleClient {
private final boolean pushExcludeWorkerOnFailureEnabled;
private final boolean shuffleCompressionEnabled;
private final boolean shuffleIntegrityCheckEnabled;
private final Set<String> pushExcludedWorkers = ConcurrentHashMap.newKeySet();
private final ConcurrentHashMap<String, Long> fetchExcludedWorkers =
JavaUtils.newConcurrentHashMap();
@ -203,6 +202,7 @@ public class ShuffleClientImpl extends ShuffleClient {
shuffleCompressionEnabled = !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
pushReplicateEnabled = conf.clientPushReplicateEnabled();
fetchExcludeWorkerOnFailureEnabled = conf.clientFetchExcludeWorkerOnFailureEnabled();
shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled();
if (conf.clientPushReplicateEnabled()) {
pushDataTimeout = conf.pushDataTimeoutMs() * 2;
} else {
@ -648,6 +648,35 @@ public class ShuffleClientImpl extends ShuffleClient {
});
}
@Override
public void readReducerPartitionEnd(
int shuffleId,
int partitionId,
int startMapIndex,
int endMapIndex,
int crc32,
long bytesWritten)
throws IOException {
PbReadReducerPartitionEnd pbReadReducerPartitionEnd =
PbReadReducerPartitionEnd.newBuilder()
.setShuffleId(shuffleId)
.setPartitionId(partitionId)
.setStartMaxIndex(startMapIndex)
.setEndMapIndex(endMapIndex)
.setCrc32(crc32)
.setBytesWritten(bytesWritten)
.build();
PbReadReducerPartitionEndResponse pbReducerPartitionEndResponse =
lifecycleManagerRef.askSync(
pbReadReducerPartitionEnd,
conf.clientRpcRegisterShuffleAskTimeout(),
ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class));
if (pbReducerPartitionEndResponse.getStatus() != StatusCode.SUCCESS.getValue()) {
throw new CelebornIOException(pbReducerPartitionEndResponse.getErrorMsg());
}
}
@Override
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) {
PbReportShuffleFetchFailure pbReportShuffleFetchFailure =
@ -1012,6 +1041,12 @@ public class ShuffleClientImpl extends ShuffleClient {
// increment batchId
final int nextBatchId = pushState.nextBatchId();
// Track commit metadata if shuffle compression and integrity check are enabled and this request
// is not for pushing metadata itself.
if (shuffleIntegrityCheckEnabled) {
pushState.addDataWithOffsetAndLength(partitionId, data, offset, length);
}
if (shuffleCompressionEnabled && !skipCompress) {
// compress data
final Compressor compressor = compressorThreadLocal.get();
@ -1725,19 +1760,25 @@ public class ShuffleClientImpl extends ShuffleClient {
}
@Override
public void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers)
public void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions)
throws IOException {
mapEndInternal(shuffleId, mapId, attemptId, numMappers, -1);
mapEndInternal(shuffleId, mapId, attemptId, numMappers, numPartitions, -1);
}
@Override
public void mapPartitionMapperEnd(
int shuffleId, int mapId, int attemptId, int numMappers, int partitionId) throws IOException {
mapEndInternal(shuffleId, mapId, attemptId, numMappers, partitionId);
int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions, int partitionId)
throws IOException {
mapEndInternal(shuffleId, mapId, attemptId, numMappers, numPartitions, partitionId);
}
private void mapEndInternal(
int shuffleId, int mapId, int attemptId, int numMappers, Integer partitionId)
int shuffleId,
int mapId,
int attemptId,
int numMappers,
int numPartitions,
Integer partitionId)
throws IOException {
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
PushState pushState = getPushState(mapKey);
@ -1745,6 +1786,12 @@ public class ShuffleClientImpl extends ShuffleClient {
try {
limitZeroInFlight(mapKey, pushState);
// send CRC32 and num bytes per partition if e2e checks are enabled
int[] crc32PerPartition =
pushState.getCRC32PerPartition(shuffleIntegrityCheckEnabled, numPartitions);
long[] bytesPerPartition =
pushState.getBytesWrittenPerPartition(shuffleIntegrityCheckEnabled, numPartitions);
MapperEndResponse response =
lifecycleManagerRef.askSync(
new MapperEnd(
@ -1753,7 +1800,10 @@ public class ShuffleClientImpl extends ShuffleClient {
attemptId,
numMappers,
partitionId,
pushState.getFailedBatches()),
pushState.getFailedBatches(),
numPartitions,
crc32PerPartition,
bytesPerPartition),
rpcMaxRetries,
rpcRetryWait,
ClassTag$.MODULE$.apply(MapperEndResponse.class));

View File

@ -40,6 +40,7 @@ import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.compress.Decompressor;
import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.CommitMetadata;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
@ -103,7 +104,9 @@ public abstract class CelebornInputStream extends InputStream {
exceptionMaker,
true,
metricsCallback,
needDecompress);
needDecompress,
startMapIndex,
endMapIndex);
} else {
return new CelebornInputStreamImpl(
conf,
@ -126,7 +129,9 @@ public abstract class CelebornInputStream extends InputStream {
exceptionMaker,
false,
metricsCallback,
needDecompress);
needDecompress,
-1,
-1);
}
}
}
@ -168,6 +173,8 @@ public abstract class CelebornInputStream extends InputStream {
private final CelebornConf conf;
private final TransportClientFactory clientFactory;
private final String shuffleKey;
private final int numberOfSubPartitions;
private final int currentIndexOfSubPartition;
private ArrayList<PartitionLocation> locations;
private ArrayList<PbStreamHandler> streamHandlers;
private int[] attempts;
@ -206,6 +213,7 @@ public abstract class CelebornInputStream extends InputStream {
private final String localHostAddress;
private boolean shouldDecompress;
private boolean shuffleIntegrityCheckEnabled;
private long fetchExcludedWorkerExpireTimeout;
private ConcurrentHashMap<String, Long> fetchExcludedWorkers;
@ -216,6 +224,8 @@ public abstract class CelebornInputStream extends InputStream {
private int partitionId;
private ExceptionMaker exceptionMaker;
private boolean closed = false;
private boolean integrityChecked = false;
private final CommitMetadata aggregatedActualCommitMetadata = new CommitMetadata();
private final boolean readSkewPartitionWithoutMapRange;
@ -238,7 +248,9 @@ public abstract class CelebornInputStream extends InputStream {
ExceptionMaker exceptionMaker,
boolean splitSkewPartitionWithoutMapRange,
MetricsCallback metricsCallback,
boolean needDecompress)
boolean needDecompress,
int numberOfSubPartitions,
int currentIndexOfSubPartition)
throws IOException {
this(
conf,
@ -261,7 +273,9 @@ public abstract class CelebornInputStream extends InputStream {
exceptionMaker,
splitSkewPartitionWithoutMapRange,
metricsCallback,
needDecompress);
needDecompress,
numberOfSubPartitions,
currentIndexOfSubPartition);
}
CelebornInputStreamImpl(
@ -285,7 +299,9 @@ public abstract class CelebornInputStream extends InputStream {
ExceptionMaker exceptionMaker,
boolean readSkewPartitionWithoutMapRange,
MetricsCallback metricsCallback,
boolean needDecompress)
boolean needDecompress,
int numberOfSubPartitions,
int currentIndexOfSubPartition)
throws IOException {
this.conf = conf;
this.clientFactory = clientFactory;
@ -305,9 +321,12 @@ public abstract class CelebornInputStream extends InputStream {
this.localHostAddress = Utils.localHostName(conf);
this.shouldDecompress =
!conf.shuffleCompressionCodec().equals(CompressionCodec.NONE) && needDecompress;
this.shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled();
this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout();
this.failedBatches = failedBatchSet;
this.readSkewPartitionWithoutMapRange = readSkewPartitionWithoutMapRange;
this.numberOfSubPartitions = numberOfSubPartitions;
this.currentIndexOfSubPartition = currentIndexOfSubPartition;
this.fetchExcludedWorkers = fetchExcludedWorkers;
if (conf.clientPushReplicateEnabled()) {
@ -721,6 +740,36 @@ public abstract class CelebornInputStream extends InputStream {
}
}
/**
 * Reports this stream's aggregated CRC32 and byte count to the driver so it can
 * validate end-to-end data integrity for the partition. Runs at most once per
 * stream, and only when the integrity-check config is enabled.
 *
 * @throws IOException if the driver rejects the report (integrity mismatch)
 */
void validateIntegrity() throws IOException {
  // Nothing to do when the feature is off or the report was already sent.
  if (!shuffleIntegrityCheckEnabled || integrityChecked) {
    return;
  }
  final int checksum = aggregatedActualCommitMetadata.getChecksum();
  final long totalBytes = aggregatedActualCommitMetadata.getBytes();
  // For skew handling without map-range split, the two index slots carry the
  // sub-partition count and the current sub-partition index instead of an
  // actual [startMapIndex, endMapIndex) range.
  final int firstIndex =
      readSkewPartitionWithoutMapRange ? numberOfSubPartitions : startMapIndex;
  final int secondIndex =
      readSkewPartitionWithoutMapRange ? currentIndexOfSubPartition : endMapIndex;
  shuffleClient.readReducerPartitionEnd(
      shuffleId, partitionId, firstIndex, secondIndex, checksum, totalBytes);
  logger.info(
      "reducerPartitionEnd successful for shuffleId{}, partitionId{}. actual CommitMetadata: {}",
      shuffleId,
      partitionId,
      aggregatedActualCommitMetadata);
  integrityChecked = true;
}
private boolean moveToNextChunk() throws IOException {
if (currentChunk != null) {
currentChunk.release();
@ -760,6 +809,7 @@ public abstract class CelebornInputStream extends InputStream {
firstChunk = false;
}
if (currentChunk == null) {
validateIntegrity();
close();
return false;
}
@ -821,6 +871,9 @@ public abstract class CelebornInputStream extends InputStream {
} else {
limit = size;
}
if (shuffleIntegrityCheckEnabled) {
aggregatedActualCommitMetadata.addDataWithOffsetAndLength(rawDataBuf, 0, limit);
}
position = 0;
hasData = true;
break;
@ -835,6 +888,10 @@ public abstract class CelebornInputStream extends InputStream {
}
}
if (!hasData) {
validateIntegrity();
// TODO(gaurav): consider closing the stream
}
return hasData;
} catch (LZ4Exception | ZstdException | IOException e) {
logger.error(

View File

@ -31,7 +31,7 @@ import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
import org.apache.celeborn.client.commit.{CommitFilesParam, CommitHandler, MapPartitionCommitHandler, ReducePartitionCommitHandler}
import org.apache.celeborn.client.listener.{WorkersStatus, WorkerStatusListener}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.WorkerInfo
import org.apache.celeborn.common.network.protocol.SerdeVersion
@ -178,7 +178,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
def registerShuffle(
shuffleId: Int,
numMappers: Int,
isSegmentGranularityVisible: Boolean): Unit = {
isSegmentGranularityVisible: Boolean,
numPartitions: Int): Unit = {
committedPartitionInfo.put(
shuffleId,
ShuffleCommittedInfo(
@ -198,7 +199,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
getCommitHandler(shuffleId).registerShuffle(
shuffleId,
numMappers,
isSegmentGranularityVisible);
isSegmentGranularityVisible,
numPartitions)
}
def isSegmentGranularityVisible(shuffleId: Int): Boolean = {
@ -219,8 +221,10 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
attemptId: Int,
numMappers: Int,
partitionId: Int = -1,
pushFailedBatches: util.Map[String, LocationPushFailedBatches] = Collections.emptyMap())
: (Boolean, Boolean) = {
pushFailedBatches: util.Map[String, LocationPushFailedBatches] = Collections.emptyMap(),
numPartitions: Int = -1,
crc32PerPartition: Array[Int] = new Array[Int](0),
bytesWrittenPerPartition: Array[Long] = new Array[Long](0)): (Boolean, Boolean) = {
getCommitHandler(shuffleId).finishMapperAttempt(
shuffleId,
mapId,
@ -228,7 +232,10 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
numMappers,
partitionId,
pushFailedBatches,
r => lifecycleManager.workerStatusTracker.recordWorkerFailure(r))
r => lifecycleManager.workerStatusTracker.recordWorkerFailure(r),
numPartitions,
crc32PerPartition,
bytesWrittenPerPartition)
}
def releasePartitionResource(shuffleId: Int, partitionId: Int): Unit = {
@ -350,4 +357,18 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
}
}
}
/**
 * Forwards a reducer's end-of-read report to the shuffle's commit handler,
 * which validates the actual commit metadata (CRC32 + bytes) against what the
 * mappers reported for this partition.
 *
 * @return (isValid, errorMessage) from the commit handler's validation
 */
def finishPartition(
    shuffleId: Int,
    partitionId: Int,
    startMapIndex: Int,
    endMapIndex: Int,
    actualCommitMetadata: CommitMetadata): (Boolean, String) = {
  val handler = getCommitHandler(shuffleId)
  handler.finishPartition(shuffleId, partitionId, startMapIndex, endMapIndex, actualCommitMetadata)
}
}

View File

@ -21,7 +21,7 @@ import java.lang.{Byte => JByte}
import java.nio.ByteBuffer
import java.security.SecureRandom
import java.util
import java.util.{function, List => JList}
import java.util.{function, Collections, List => JList}
import java.util.concurrent._
import java.util.concurrent.atomic.{AtomicInteger, LongAdder}
import java.util.function.{BiConsumer, BiFunction, Consumer}
@ -40,7 +40,7 @@ import org.roaringbitmap.RoaringBitmap
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
import org.apache.celeborn.client.listener.WorkerStatusListener
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
import org.apache.celeborn.common.CelebornConf.ACTIVE_STORAGE_TYPES
import org.apache.celeborn.common.client.MasterClient
import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier}
@ -423,13 +423,31 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
oldPartition,
isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId))
case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) =>
case MapperEnd(
shuffleId,
mapId,
attemptId,
numMappers,
partitionId,
pushFailedBatch,
numPartitions,
crc32PerPartition,
bytesWrittenPerPartition) =>
logTrace(s"Received MapperEnd TaskEnd request, " +
s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}")
val partitionType = getPartitionType(shuffleId)
partitionType match {
case PartitionType.REDUCE =>
handleMapperEnd(context, shuffleId, mapId, attemptId, numMappers, pushFailedBatch)
handleMapperEnd(
context,
shuffleId,
mapId,
attemptId,
numMappers,
pushFailedBatch,
numPartitions,
crc32PerPartition,
bytesWrittenPerPartition)
case PartitionType.MAP =>
handleMapPartitionEnd(
context,
@ -442,6 +460,22 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
throw new UnsupportedOperationException(s"Not support $partitionType yet")
}
case pb: ReadReducerPartitionEnd =>
val partitionType = getPartitionType(pb.shuffleId)
partitionType match {
case PartitionType.REDUCE =>
handleReducerPartitionEnd(
context,
pb.shuffleId,
pb.partitionId,
pb.startMapIndex,
pb.endMapIndex,
pb.crc32,
pb.bytesWritten)
case _ =>
throw new UnsupportedOperationException(s"Not support $partitionType yet")
}
case GetReducerFileGroup(
shuffleId: Int,
isSegmentGranularityVisible: Boolean,
@ -493,6 +527,32 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
}
/**
 * Handles a ReadReducerPartitionEnd RPC: asks the CommitManager to validate the
 * reducer-side aggregated commit metadata (CRC32 + bytes) for the given
 * partition/map range, then replies SUCCESS or a failure status with the
 * validation error message.
 */
private def handleReducerPartitionEnd(
    context: RpcCallContext,
    shuffleId: Int,
    partitionId: Int,
    startMapIndex: Int,
    endMapIndex: Int,
    crc32: Int,
    bytesWritten: Long): Unit = {
  val (isValid, errorMessage) = commitManager.finishPartition(
    shuffleId,
    partitionId,
    startMapIndex,
    endMapIndex,
    new CommitMetadata(crc32, bytesWritten))
  // Build the reply as a single expression; the error message is only attached
  // on failure. (Removed a stray '+' that had leaked in from a diff marker.)
  val response =
    if (isValid) {
      PbReadReducerPartitionEndResponse.newBuilder()
        .setStatus(StatusCode.SUCCESS.getValue)
        .build()
    } else {
      PbReadReducerPartitionEndResponse.newBuilder()
        .setStatus(StatusCode.READ_REDUCER_PARTITION_END_FAILED.getValue)
        .setErrorMsg(errorMessage)
        .build()
    }
  context.reply(response)
}
def setupEndpoints(
workers: util.Set[WorkerInfo],
shuffleId: Int,
@ -771,7 +831,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
commitManager.registerShuffle(
shuffleId,
numMappers,
isSegmentGranularityVisible)
isSegmentGranularityVisible,
numPartitions)
// Fifth, reply the allocated partition location to ShuffleClient.
logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
@ -840,7 +901,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
mapId: Int,
attemptId: Int,
numMappers: Int,
pushFailedBatches: util.Map[String, LocationPushFailedBatches]): Unit = {
pushFailedBatches: util.Map[String, LocationPushFailedBatches],
numPartitions: Int,
crc32PerPartition: Array[Int],
bytesWrittenPerPartition: Array[Long]): Unit = {
val (mapperAttemptFinishedSuccess, allMapperFinished) =
commitManager.finishMapperAttempt(
@ -848,7 +912,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
mapId,
attemptId,
numMappers,
pushFailedBatches = pushFailedBatches)
pushFailedBatches = pushFailedBatches,
numPartitions = numPartitions,
crc32PerPartition = crc32PerPartition,
bytesWrittenPerPartition = bytesWrittenPerPartition)
if (mapperAttemptFinishedSuccess && allMapperFinished) {
// last mapper finished. call mapper end
logInfo(s"Last MapperEnd, call StageEnd with shuffleKey:" +

View File

@ -31,7 +31,7 @@ import scala.concurrent.duration.Duration
import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, ShuffleFileGroups, ShufflePushFailedBatches}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
import org.apache.celeborn.common.network.protocol.SerdeVersion
@ -216,12 +216,16 @@ abstract class CommitHandler(
numMappers: Int,
partitionId: Int,
pushFailedBatches: util.Map[String, LocationPushFailedBatches],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean)
recordWorkerFailure: ShuffleFailedWorkers => Unit,
numPartitions: Int,
crc32PerPartition: Array[Int],
bytesWrittenPerPartition: Array[Long]): (Boolean, Boolean)
def registerShuffle(
shuffleId: Int,
numMappers: Int,
isSegmentGranularityVisible: Boolean): Unit = {
isSegmentGranularityVisible: Boolean,
numPartitions: Int): Unit = {
// TODO: if isSegmentGranularityVisible is set to true, it is necessary to handle the pending
// get partition request of downstream reduce task here, in scenarios which support
// downstream task start early before the upstream task, e.g. flink hybrid shuffle.
@ -422,6 +426,16 @@ abstract class CommitHandler(
}
}
/**
* Invoked when a reduce partition finishes reading data to perform end to end integrity check validation
*/
def finishPartition(
shuffleId: Int,
partitionId: Int,
startMapIndex: Int,
endMapIndex: Int,
actualCommitMetadata: CommitMetadata): (Boolean, String)
def parallelCommitFiles(
shuffleId: Int,
allocatedWorkers: util.Map[String, ShufflePartitionLocationInfo],

View File

@ -0,0 +1,189 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.client.commit
import java.util
import java.util.Comparator
import java.util.function.BiFunction
import com.google.common.base.Preconditions.{checkArgument, checkState}
import org.apache.celeborn.common.CommitMetadata
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.util.JavaUtils
class LegacySkewHandlingPartitionValidator extends AbstractPartitionCompletenessValidator
with Logging {
private val subRangeToCommitMetadataPerReducer = {
JavaUtils.newConcurrentHashMap[Int, java.util.TreeMap[(Int, Int), CommitMetadata]]()
}
private val partitionToSubPartitionCount = JavaUtils.newConcurrentHashMap[Int, Int]()
private val currentCommitMetadataForReducer =
JavaUtils.newConcurrentHashMap[Int, CommitMetadata]()
private val currentTotalMapIdCountForReducer = JavaUtils.newConcurrentHashMap[Int, Int]()
// Orders (start, end) map-index ranges lexicographically: primarily by start
// index, falling back to end index to break ties.
private val comparator: java.util.Comparator[(Int, Int)] = new Comparator[(Int, Int)] {
  override def compare(o1: (Int, Int), o2: (Int, Int)): Int = {
    val byStart = Integer.compare(o1._1, o2._1)
    if (byStart == 0) Integer.compare(o1._2, o2._2) else byStart
  }
}
// Merge function for per-reducer map-id counters: accumulates counts by
// integer addition.
private val mapCountMergeBiFunction = new BiFunction[Int, Int, Int] {
  override def apply(existingCount: Int, incomingCount: Int): Int =
    Integer.sum(existingCount, incomingCount)
}
// Merge function for per-reducer CommitMetadata accumulation.
// Null handling:
//   - existing == null, incoming != null -> defensive copy of incoming, so the
//     caller's instance is never aliased into the accumulator
//   - existing == null, incoming == null -> null
//   - incoming == null                   -> existing unchanged
// Otherwise the incoming metadata is folded into `existing` IN PLACE via
// addCommitData and the same (mutated) instance is returned.
private val metadataMergeBiFunction: BiFunction[CommitMetadata, CommitMetadata, CommitMetadata] =
new BiFunction[CommitMetadata, CommitMetadata, CommitMetadata] {
override def apply(
existing: CommitMetadata,
incoming: CommitMetadata): CommitMetadata = {
if (existing == null) {
if (incoming != null) {
// Copy rather than alias the incoming instance.
return new CommitMetadata(incoming.getChecksum, incoming.getBytes)
} else {
return incoming
}
}
if (incoming == null) {
return existing
}
// Mutating merge: accumulate incoming checksum/bytes into `existing`.
existing.addCommitData(incoming)
existing
}
}
private def checkOverlappingRange(
treeMap: java.util.TreeMap[(Int, Int), CommitMetadata],
rangeKey: (Int, Int)): (Boolean, ((Int, Int), CommitMetadata)) = {
val floorEntry: util.Map.Entry[(Int, Int), CommitMetadata] =
treeMap.floorEntry(rangeKey)
val ceilingEntry: util.Map.Entry[(Int, Int), CommitMetadata] =
treeMap.ceilingEntry(rangeKey)
if (floorEntry != null) {
if (rangeKey._1 < floorEntry.getKey._2) {
return (true, ((floorEntry.getKey._1, floorEntry.getKey._2), floorEntry.getValue))
}
}
if (ceilingEntry != null) {
if (rangeKey._2 > ceilingEntry.getKey._1) {
return (true, ((ceilingEntry.getKey._1, ceilingEntry.getKey._2), ceilingEntry.getValue))
}
}
(false, ((0, 0), new CommitMetadata()))
}
override def processSubPartition(
partitionId: Int,
startMapIndex: Int,
endMapIndex: Int,
actualCommitMetadata: CommitMetadata,
expectedTotalMapperCount: Int): (Boolean, String) = {
checkArgument(
startMapIndex < endMapIndex,
"startMapIndex %s must be less than endMapIndex %s",
startMapIndex,
endMapIndex)
logDebug(
s"Validate partition invoked for partitionId $partitionId startMapIndex $startMapIndex endMapIndex $endMapIndex")
partitionToSubPartitionCount.put(partitionId, expectedTotalMapperCount)
val subRangeToCommitMetadataMap = subRangeToCommitMetadataPerReducer.computeIfAbsent(
partitionId,
new java.util.function.Function[Int, util.TreeMap[(Int, Int), CommitMetadata]] {
override def apply(key: Int): util.TreeMap[(Int, Int), CommitMetadata] =
new util.TreeMap[(Int, Int), CommitMetadata](comparator)
})
val rangeKey = (startMapIndex, endMapIndex)
subRangeToCommitMetadataMap.synchronized {
val existingMetadata = subRangeToCommitMetadataMap.get(rangeKey)
if (existingMetadata == null) {
val (isOverlapping, overlappingEntry) =
checkOverlappingRange(subRangeToCommitMetadataMap, rangeKey)
if (isOverlapping) {
val errorMessage = s"Encountered overlapping map range for partitionId: $partitionId " +
s" while processing range with startMapIndex: $startMapIndex and endMapIndex: $endMapIndex " +
s"existing range map: $subRangeToCommitMetadataMap " +
s"overlapped with Entry((startMapIndex, endMapIndex), count): $overlappingEntry"
logError(errorMessage)
return (false, errorMessage)
}
// Process new range
subRangeToCommitMetadataMap.put(rangeKey, actualCommitMetadata)
currentCommitMetadataForReducer.merge(
partitionId,
actualCommitMetadata,
metadataMergeBiFunction)
currentTotalMapIdCountForReducer.merge(
partitionId,
endMapIndex - startMapIndex,
mapCountMergeBiFunction)
} else if (existingMetadata != actualCommitMetadata) {
val errorMessage = s"Commit Metadata for partition: $partitionId " +
s"not matching for sub-partition with startMapIndex: $startMapIndex endMapIndex: $endMapIndex " +
s"previous count: $existingMetadata new count: $actualCommitMetadata"
logError(errorMessage)
return (false, errorMessage)
}
validateState(partitionId, startMapIndex, endMapIndex)
val sumOfMapRanges: Int = currentTotalMapIdCountForReducer.get(partitionId)
val currentCommitMetadata: CommitMetadata =
currentCommitMetadataForReducer.get(partitionId)
if (sumOfMapRanges > expectedTotalMapperCount) {
val errorMsg = s"AQE Partition $partitionId failed validation check " +
s"while processing startMapIndex: $startMapIndex endMapIndex: $endMapIndex " +
s"ActualCommitMetadata $currentCommitMetadata > ExpectedTotalMapperCount $expectedTotalMapperCount"
logError(errorMsg)
return (false, errorMsg)
}
}
(true, "")
}
private def validateState(partitionId: Int, startMapIndex: Int, endMapIndex: Int): Unit = {
if (!currentTotalMapIdCountForReducer.containsKey(partitionId)) {
checkState(!currentCommitMetadataForReducer.containsKey(
partitionId,
"mapper total count missing while processing partitionId %s startMapIndex %s endMapIndex %s",
partitionId,
startMapIndex,
endMapIndex))
currentTotalMapIdCountForReducer.put(partitionId, 0)
currentCommitMetadataForReducer.put(partitionId, new CommitMetadata())
}
checkState(
currentCommitMetadataForReducer.containsKey(partitionId),
"mapper written count missing while processing partitionId %s startMapIndex %s endMapIndex %s",
partitionId,
startMapIndex,
endMapIndex)
}
override def currentCommitMetadata(partitionId: Int): CommitMetadata = {
currentCommitMetadataForReducer.get(partitionId)
}
override def isPartitionComplete(partitionId: Int): Boolean = {
val sumOfMapRanges: Int = currentTotalMapIdCountForReducer.get(partitionId)
sumOfMapRanges == partitionToSubPartitionCount.get(partitionId)
}
}

View File

@ -20,15 +20,17 @@ package org.apache.celeborn.client.commit
import java.util
import java.util.Collections
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.{AtomicInteger, AtomicIntegerArray}
import scala.collection.JavaConverters._
import scala.collection.mutable
import org.roaringbitmap.RoaringBitmap
import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
import org.apache.celeborn.common.network.protocol.SerdeVersion
@ -187,7 +189,10 @@ class MapPartitionCommitHandler(
numMappers: Int,
partitionId: Int,
pushFailedBatches: util.Map[String, LocationPushFailedBatches],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = {
recordWorkerFailure: ShuffleFailedWorkers => Unit,
numPartitions: Int,
crc32PerPartition: Array[Int],
bytesWrittenPerPartition: Array[Long]): (Boolean, Boolean) = {
val inProcessingPartitionIds =
inProcessMapPartitionEndIds.computeIfAbsent(
shuffleId,
@ -222,8 +227,9 @@ class MapPartitionCommitHandler(
override def registerShuffle(
shuffleId: Int,
numMappers: Int,
isSegmentGranularityVisible: Boolean): Unit = {
super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible)
isSegmentGranularityVisible: Boolean,
numPartitions: Int): Unit = {
super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible, numPartitions)
shuffleIsSegmentGranularityVisible.put(shuffleId, isSegmentGranularityVisible)
}
@ -231,6 +237,15 @@ class MapPartitionCommitHandler(
shuffleIsSegmentGranularityVisible.get(shuffleId)
}
// End-to-end integrity checks are only implemented for ReducePartition
// (Spark) shuffles; fail with context instead of a bare, message-less UOE so
// a misrouted call is diagnosable from the stack trace alone.
override def finishPartition(
    shuffleId: Int,
    partitionId: Int,
    startMapIndex: Int,
    endMapIndex: Int,
    actualCommitMetadata: CommitMetadata): (Boolean, String) = {
  throw new UnsupportedOperationException(
    s"finishPartition is not supported by MapPartitionCommitHandler " +
      s"(shuffleId: $shuffleId, partitionId: $partitionId)")
}
override def handleGetReducerFileGroup(
context: RpcCallContext,
shuffleId: Int,

View File

@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.client.commit
import org.apache.celeborn.common.CommitMetadata
import org.apache.celeborn.common.internal.Logging
/**
 * Contract for validating that all data of a skew-handled partition has been
 * read exactly once. Implementations aggregate per-sub-partition
 * CommitMetadata reports and decide when the parent partition is complete.
 */
abstract class AbstractPartitionCompletenessValidator {
// Records one sub-partition report; returns (accepted, errorMessage).
// NOTE(review): the meaning of startMapIndex/endMapIndex differs per
// implementation (mapper range vs. sub-partition count/index) — see subclasses.
def processSubPartition(
partitionId: Int,
startMapIndex: Int,
endMapIndex: Int,
actualCommitMetadata: CommitMetadata,
expectedTotalMapperCount: Int): (Boolean, String)
// Aggregated metadata observed so far for partitionId.
def currentCommitMetadata(partitionId: Int): CommitMetadata
// True once every expected sub-partition / map range has been reported.
def isPartitionComplete(partitionId: Int): Boolean
}
/**
 * Driver-side entry point for validating sub-partition read reports.
 *
 * Dispatches to one of two validators depending on how skew handling split
 * the partition, then — once the partition is fully covered — compares the
 * aggregated metadata against the expected metadata computed from the
 * mappers' commit reports.
 */
class PartitionCompletenessValidator extends Logging {

  private val skewHandlingValidator: AbstractPartitionCompletenessValidator =
    new SkewHandlingWithoutMapRangeValidator
  private val legacySkewHandlingValidator: AbstractPartitionCompletenessValidator =
    new LegacySkewHandlingPartitionValidator

  /**
   * Validates one sub-partition report.
   *
   * Returns (false, reason) when the report conflicts with earlier reports or
   * the final aggregate does not match expectations; (true, status) otherwise,
   * where the status string indicates whether the partition is complete.
   */
  def validateSubPartition(
      partitionId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      actualCommitMetadata: CommitMetadata,
      expectedCommitMetadata: CommitMetadata,
      expectedTotalMapperCountForParent: Int,
      skewPartitionHandlingWithoutMapRange: Boolean): (Boolean, String) = {
    val validator =
      if (skewPartitionHandlingWithoutMapRange) {
        skewHandlingValidator
      } else {
        legacySkewHandlingValidator
      }
    val (successfullyProcessed, error) = validator.processSubPartition(
      partitionId,
      startMapIndex,
      endMapIndex,
      actualCommitMetadata,
      expectedTotalMapperCountForParent)
    if (!successfullyProcessed) return (false, error)
    if (!validator.isPartitionComplete(partitionId)) {
      return (true, "Partition is valid but still waiting for more data")
    }
    // Every sub-partition has reported: compare aggregate with expectation.
    val currentCommitMetadata = validator.currentCommitMetadata(partitionId)
    if (!CommitMetadata.checkCommitMetadata(expectedCommitMetadata, currentCommitMetadata)) {
      val errorMsg =
        s"AQE Partition $partitionId failed validation check " +
          s"while processing range startMapIndex: $startMapIndex endMapIndex: $endMapIndex " +
          s"ExpectedCommitMetadata $expectedCommitMetadata, " +
          s"ActualCommitMetadata $currentCommitMetadata"
      logError(errorMsg)
      return (false, errorMsg)
    }
    logInfo(
      s"AQE Partition $partitionId completed validation check , " +
        s"expectedCommitMetadata $expectedCommitMetadata, " +
        s"actualCommitMetadata $currentCommitMetadata")
    (true, "Partition is complete")
  }
}

View File

@ -25,13 +25,14 @@ import java.util.function
import scala.collection.JavaConverters._
import scala.collection.mutable
import com.google.common.base.Preconditions.checkState
import com.google.common.cache.{Cache, CacheBuilder}
import com.google.common.collect.Sets
import org.apache.celeborn.client.{ClientUtils, LifecycleManager, ShuffleCommittedInfo, WorkerStatusTracker}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.ShufflePartitionLocationInfo
import org.apache.celeborn.common.network.protocol.SerdeVersion
@ -82,6 +83,13 @@ class ReducePartitionCommitHandler(
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime
private val shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled
// partitionId-shuffleId -> number of mappers that have written to this reducer (partition + shuffle)
private val commitMetadataForReducer =
JavaUtils.newConcurrentHashMap[Integer, Array[CommitMetadata]]
private val skewPartitionCompletenessValidator =
JavaUtils.newConcurrentHashMap[Int, PartitionCompletenessValidator]()
private val getReducerFileGroupResponseBroadcastEnabled = conf.getReducerFileGroupBroadcastEnabled
private val getReducerFileGroupResponseBroadcastMiniSize =
conf.getReducerFileGroupBroadcastMiniSize
@ -153,6 +161,8 @@ class ReducePartitionCommitHandler(
stageEndShuffleSet.remove(shuffleId)
inProcessStageEndShuffleSet.remove(shuffleId)
shuffleMapperAttempts.remove(shuffleId)
commitMetadataForReducer.remove(shuffleId)
skewPartitionCompletenessValidator.remove(shuffleId)
super.removeExpiredShuffle(shuffleId)
}
@ -268,16 +278,20 @@ class ReducePartitionCommitHandler(
numMappers: Int,
partitionId: Int,
pushFailedBatches: util.Map[String, LocationPushFailedBatches],
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = {
shuffleMapperAttempts.synchronized {
recordWorkerFailure: ShuffleFailedWorkers => Unit,
numPartitions: Int,
crc32PerPartition: Array[Int],
bytesWrittenPerPartition: Array[Long]): (Boolean, Boolean) = {
val (mapperAttemptFinishedSuccess, allMapperFinished) = shuffleMapperAttempts.synchronized {
if (getMapperAttempts(shuffleId) == null) {
logDebug(s"[handleMapperEnd] $shuffleId not registered, create one.")
initMapperAttempts(shuffleId, numMappers)
initMapperAttempts(shuffleId, numMappers, numPartitions)
}
val attempts = shuffleMapperAttempts.get(shuffleId)
if (attempts(mapId) < 0) {
attempts(mapId) = attemptId
if (null != pushFailedBatches && !pushFailedBatches.isEmpty) {
val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent(
shuffleId,
@ -296,24 +310,93 @@ class ReducePartitionCommitHandler(
(false, false)
}
}
if (shuffleIntegrityCheckEnabled && mapperAttemptFinishedSuccess) {
val commitMetadataArray = commitMetadataForReducer.get(shuffleId)
checkState(
commitMetadataArray != null,
"commitMetadataArray can only be null if shuffleId %s is not registered!",
shuffleId)
for (i <- 0 until numPartitions) {
if (bytesWrittenPerPartition(i) != 0) {
commitMetadataArray(i).addCommitData(
crc32PerPartition(i),
bytesWrittenPerPartition(i))
}
}
}
(mapperAttemptFinishedSuccess, allMapperFinished)
}
// Registers the shuffle with the parent handler, then prepares the per-shuffle
// state this handler owns: the pending GetReducerFileGroup request set and the
// mapper-attempt / per-partition CommitMetadata bookkeeping (numPartitions
// sizes the CommitMetadata array used by the integrity checks).
override def registerShuffle(
shuffleId: Int,
numMappers: Int,
isSegmentGranularityVisible: Boolean,
numPartitions: Int): Unit = {
super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible, numPartitions)
getReducerFileGroupRequest.put(shuffleId, new util.HashSet[MultiSerdeVersionRpcContext]())
initMapperAttempts(shuffleId, numMappers, numPartitions)
}
private def initMapperAttempts(shuffleId: Int, numMappers: Int): Unit = {
/**
 * Handles a reader's report that it finished reading a (sub-)partition.
 *
 * When endMapIndex == Integer.MAX_VALUE the stream read the whole partition
 * and the reported metadata is compared directly against what the mappers
 * committed. Otherwise the report is one slice of a skew-split partition and
 * is routed through the PartitionCompletenessValidator for this shuffle.
 *
 * Returns (true, status) when the report is consistent so far, (false, reason)
 * on a mismatch.
 */
override def finishPartition(
    shuffleId: Int,
    partitionId: Int,
    startMapIndex: Int,
    endMapIndex: Int,
    actualCommitMetadata: CommitMetadata): (Boolean, String) = {
  logDebug(s"finish Partition call: shuffleId: $shuffleId, " +
    s"partitionId: $partitionId, " +
    s"startMapIndex: $startMapIndex " +
    s"endMapIndex: $endMapIndex, " +
    s"actualCommitMetadata: $actualCommitMetadata")
  val map = commitMetadataForReducer.get(shuffleId)
  // Guava Preconditions only substitutes %s placeholders — %d would be left
  // verbatim and the shuffleId dropped from the message.
  checkState(
    map != null,
    "CommitMetadata map cannot be null for a registered shuffleId: %s",
    shuffleId)
  val expectedCommitMetadata = map(partitionId)
  if (endMapIndex == Integer.MAX_VALUE) {
    // Complete partition was read on a single stream: compare directly.
    val matched = CommitMetadata.checkCommitMetadata(actualCommitMetadata, expectedCommitMetadata)
    var message = ""
    if (!matched) {
      message =
        s"CommitMetadata mismatch for shuffleId: $shuffleId partitionId: $partitionId expected: $expectedCommitMetadata actual: $actualCommitMetadata"
    } else {
      logInfo(
        s"CommitMetadata matched for shuffleID : $shuffleId, partitionId: $partitionId expected: $expectedCommitMetadata actual: $actualCommitMetadata")
    }
    return (matched, message)
  }
  // Skew-split partition: aggregate sub-partition reports until complete.
  val splitSkewPartitionWithoutMapRange =
    ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex)
  val validator = skewPartitionCompletenessValidator.computeIfAbsent(
    shuffleId,
    new java.util.function.Function[Int, PartitionCompletenessValidator] {
      override def apply(key: Int): PartitionCompletenessValidator =
        new PartitionCompletenessValidator()
    })
  validator.validateSubPartition(
    partitionId,
    startMapIndex,
    endMapIndex,
    actualCommitMetadata,
    expectedCommitMetadata,
    shuffleMapperAttempts.get(shuffleId).length,
    splitSkewPartitionWithoutMapRange)
}
// Lazily initializes the per-shuffle mapper-attempt array (-1 = no successful
// attempt recorded yet) and, when integrity checks are enabled, the
// per-partition CommitMetadata array used to accumulate mapper commit reports.
private def initMapperAttempts(shuffleId: Int, numMappers: Int, numPartitions: Int): Unit = {
  shuffleMapperAttempts.synchronized {
    if (!shuffleMapperAttempts.containsKey(shuffleId)) {
      val attempts = Array.fill(numMappers)(-1)
      shuffleMapperAttempts.put(shuffleId, attempts)
    }
    if (shuffleIntegrityCheckEnabled) {
      // putIfAbsent, symmetric with the attempts guard above: a repeated
      // init (re-registration or racing mapper-end) must not wipe the
      // CommitMetadata that earlier finished mappers already accumulated.
      commitMetadataForReducer.putIfAbsent(
        shuffleId,
        Array.fill(numPartitions)(new CommitMetadata()))
    }
  }
}

View File

@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.client.commit
import java.util
import com.google.common.base.Preconditions.{checkArgument, checkState}
import org.apache.celeborn.common.CommitMetadata
import org.apache.celeborn.common.util.JavaUtils
/**
 * Completeness validator for AQE skew handling *without* map-range splitting.
 *
 * In this mode the caller repurposes the range arguments:
 *   - startMapIndex carries the total number of sub-partitions of the parent
 *     partition
 *   - endMapIndex carries the index of the sub-partition being reported
 *
 * A partition is complete once every sub-partition index has been reported.
 */
class SkewHandlingWithoutMapRangeValidator extends AbstractPartitionCompletenessValidator {

  // partitionId -> (sub-partition index -> metadata reported for it)
  private val totalSubPartitionsProcessed =
    JavaUtils.newConcurrentHashMap[Int, util.HashMap[Int, CommitMetadata]]()

  // partitionId -> total sub-partition count (the startMapIndex of every report)
  private val partitionToSubPartitionCount = JavaUtils.newConcurrentHashMap[Int, Int]()

  // partitionId -> aggregate metadata over all distinct sub-partitions
  private val currentCommitMetadataForReducer =
    JavaUtils.newConcurrentHashMap[Int, CommitMetadata]()

  /**
   * Records one sub-partition report. Duplicate reports (task retries) are
   * accepted when their metadata matches the first report, but are not folded
   * into the aggregate a second time.
   */
  override def processSubPartition(
      partitionId: Int,
      startMapIndex: Int,
      endMapIndex: Int,
      actualCommitMetadata: CommitMetadata,
      expectedTotalMapperCount: Int): (Boolean, String) = {
    checkArgument(
      endMapIndex >= 0,
      "index of sub-partition %s must be greater than or equal to 0",
      endMapIndex)
    // In this mode the sub-partition index must lie below the sub-partition count.
    checkArgument(
      startMapIndex > endMapIndex,
      "startMapIndex %s must be greater than endMapIndex %s",
      startMapIndex,
      endMapIndex)
    totalSubPartitionsProcessed.synchronized {
      if (totalSubPartitionsProcessed.containsKey(partitionId)) {
        // Every report for a partition must agree on the sub-partition count.
        val currentSubPartitionCount = partitionToSubPartitionCount.getOrDefault(partitionId, -1)
        checkState(
          currentSubPartitionCount == startMapIndex,
          "total subpartition count mismatch for partitionId %s existing %s new %s",
          partitionId,
          currentSubPartitionCount,
          startMapIndex)
      } else {
        totalSubPartitionsProcessed.put(partitionId, new util.HashMap[Int, CommitMetadata]())
        partitionToSubPartitionCount.put(partitionId, startMapIndex)
      }
      val subPartitionsProcessed = totalSubPartitionsProcessed.get(partitionId)
      if (subPartitionsProcessed.containsKey(endMapIndex)) {
        val existingCommitMetadata = subPartitionsProcessed.get(endMapIndex)
        if (existingCommitMetadata != actualCommitMetadata) {
          return (
            false,
            s"Mismatch in metadata for the same chunk range on retry: $endMapIndex existing: $existingCommitMetadata new: $actualCommitMetadata")
        }
        // Duplicate report of a sub-partition we already aggregated:
        // acknowledge it WITHOUT folding it into the aggregate again,
        // otherwise checksum/bytes would be double counted.
        return (true, "")
      }
      subPartitionsProcessed.put(endMapIndex, actualCommitMetadata)
      val partitionProcessed = getTotalNumberOfSubPartitionsProcessed(partitionId)
      checkState(
        partitionProcessed <= startMapIndex,
        "Number of sub-partitions processed %s should be less than or equal to total number of sub-partitions %s",
        partitionProcessed,
        startMapIndex)
      // Aggregate inside the lock so the processed-count and the aggregate
      // metadata can never be observed out of sync.
      updateCommitMetadata(partitionId, actualCommitMetadata)
    }
    (true, "")
  }

  // Folds one report into the per-partition aggregate, creating it on demand.
  private def updateCommitMetadata(partitionId: Int, actualCommitMetadata: CommitMetadata): Unit = {
    val currentCommitMetadata =
      currentCommitMetadataForReducer.computeIfAbsent(
        partitionId,
        new java.util.function.Function[Int, CommitMetadata] {
          override def apply(partitionId: Int): CommitMetadata = {
            new CommitMetadata()
          }
        })
    currentCommitMetadata.addCommitData(actualCommitMetadata)
  }

  private def getTotalNumberOfSubPartitionsProcessed(partitionId: Int) = {
    totalSubPartitionsProcessed.get(partitionId).size()
  }

  override def currentCommitMetadata(partitionId: Int): CommitMetadata = {
    currentCommitMetadataForReducer.get(partitionId)
  }

  // Complete once every sub-partition index [0, count) has been reported.
  override def isPartitionComplete(partitionId: Int): Boolean = {
    getTotalNumberOfSubPartitionsProcessed(partitionId) == partitionToSubPartitionCount.get(
      partitionId)
  }
}

View File

@ -18,8 +18,7 @@
package org.apache.celeborn.client;
import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -34,6 +33,9 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.util.concurrent.Future;
@ -41,6 +43,7 @@ import io.netty.util.concurrent.GenericFutureListener;
import org.apache.commons.lang3.RandomStringUtils;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.apache.celeborn.client.compress.Compressor;
import org.apache.celeborn.common.CelebornConf;
@ -51,6 +54,8 @@ import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.SerdeVersion;
import org.apache.celeborn.common.protocol.CompressionCodec;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbReadReducerPartitionEnd;
import org.apache.celeborn.common.protocol.PbReadReducerPartitionEndResponse;
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse$;
import org.apache.celeborn.common.protocol.message.ControlMessages.RegisterShuffleResponse$;
import org.apache.celeborn.common.protocol.message.StatusCode;
@ -68,6 +73,7 @@ public class ShuffleClientSuiteJ {
private static final int TEST_SHUFFLE_ID = 1;
private static final int TEST_ATTEMPT_ID = 0;
private static final int TEST_REDUCRE_ID = 0;
private static final int TEST_MAP_ID = 0;
private static final int PRIMARY_RPC_PORT = 1234;
private static final int PRIMARY_PUSH_PORT = 1235;
@ -621,4 +627,86 @@ public class ShuffleClientSuiteJ {
Exception exception = exceptionRef.get();
Assert.assertTrue(exception.getCause() instanceof TimeoutException);
}
@Test
public void testSuccessfulReadReducePartitionEnd() throws IOException {
  // Wire a client to the mocked lifecycle-manager endpoint.
  CelebornConf conf = new CelebornConf();
  shuffleClient =
      new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock"));
  shuffleClient.setupLifecycleManagerRef(endpointRef);

  // Stub the RPC so any PbReadReducerPartitionEnd gets a SUCCESS status back.
  ClassTag<PbReadReducerPartitionEndResponse> responseTag =
      ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class);
  PbReadReducerPartitionEndResponse successResponse =
      PbReadReducerPartitionEndResponse.newBuilder()
          .setStatus(StatusCode.SUCCESS.getValue())
          .build();
  when(endpointRef.askSync(any(PbReadReducerPartitionEnd.class), any(), eq(responseTag)))
      .thenReturn(successResponse);

  // A SUCCESS response must complete without throwing.
  shuffleClient.readReducerPartitionEnd(1, 2, 3, 4, 12345, 10000L);
}
@Test
public void testFailedReadReducerPartitionEnd() throws IOException {
  // Wire a client to the mocked lifecycle-manager endpoint.
  CelebornConf conf = new CelebornConf();
  shuffleClient =
      new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock"));
  shuffleClient.setupLifecycleManagerRef(endpointRef);
  ClassTag<PbReadReducerPartitionEndResponse> classTag =
      ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class);
  String errorMsg = "Test error message";
  // Stub the RPC to answer with a failure status and an error message.
  PbReadReducerPartitionEndResponse mockResponse =
      PbReadReducerPartitionEndResponse.newBuilder()
          .setStatus(StatusCode.READ_REDUCER_PARTITION_END_FAILED.getValue())
          .setErrorMsg(errorMsg)
          .build();
  when(endpointRef.askSync(any(PbReadReducerPartitionEnd.class), any(), eq(classTag)))
      .thenReturn(mockResponse);
  try {
    shuffleClient.readReducerPartitionEnd(1, 2, 3, 4, 12345, 10000L);
    // Without this, the test would silently pass when no exception is thrown.
    Assert.fail("Expected CelebornIOException for READ_REDUCER_PARTITION_END_FAILED response");
  } catch (CelebornIOException e) {
    Assert.assertEquals(errorMsg, e.getMessage());
  }
}
@Test
// Verifies that readReducerPartitionEnd forwards every argument unchanged
// into the PbReadReducerPartitionEnd RPC payload, captured via Mockito.
public void testCorrectParametersPassedInRequest() throws IOException {
CelebornConf conf = new CelebornConf();
shuffleClient =
new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock"));
shuffleClient.setupLifecycleManagerRef(endpointRef);
// Arbitrary but distinct values so a swapped argument would be detected.
int shuffleId = 123;
int partitionId = 456;
int startMapIndex = 0;
int endMapIndex = 10;
int crc32 = 98765;
long bytesWritten = 54321L;
PbReadReducerPartitionEndResponse mockResponse =
PbReadReducerPartitionEndResponse.newBuilder()
.setStatus(StatusCode.SUCCESS.getValue())
.build();
// Capture the request object the client hands to askSync.
ArgumentCaptor<PbReadReducerPartitionEnd> requestCaptor =
ArgumentCaptor.forClass(PbReadReducerPartitionEnd.class);
ClassTag<PbReadReducerPartitionEndResponse> classTag =
ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class);
when(endpointRef.askSync(requestCaptor.capture(), any(), eq(classTag)))
.thenReturn(mockResponse);
shuffleClient.readReducerPartitionEnd(
shuffleId, partitionId, startMapIndex, endMapIndex, crc32, bytesWritten);
PbReadReducerPartitionEnd capturedRequest = requestCaptor.getValue();
assertEquals(shuffleId, capturedRequest.getShuffleId());
assertEquals(partitionId, capturedRequest.getPartitionId());
// NOTE(review): getStartMaxIndex() looks like a typo for getStartMapIndex()
// in the generated protobuf accessor — confirm the field name in the .proto.
assertEquals(startMapIndex, capturedRequest.getStartMaxIndex());
assertEquals(endMapIndex, capturedRequest.getEndMapIndex());
assertEquals(crc32, capturedRequest.getCrc32());
assertEquals(bytesWritten, capturedRequest.getBytesWritten());
}
}

View File

@ -197,21 +197,29 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
}
private def registerAndFinishPartition(shuffleId: Int): Unit = {
val numPartitions = 9
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId, attemptId, 1)
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId + 1, attemptId, 2)
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId + 2, attemptId, 3)
// task number incr to numMappers + 1
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId, attemptId + 1, 9)
shuffleClient.mapPartitionMapperEnd(shuffleId, mapId, attemptId, numMappers, 1)
shuffleClient.mapPartitionMapperEnd(shuffleId, mapId, attemptId, numMappers, numPartitions, 1)
// another attempt
shuffleClient.mapPartitionMapperEnd(
shuffleId,
mapId,
attemptId + 1,
numMappers,
numPartitions,
9)
// another mapper
shuffleClient.mapPartitionMapperEnd(shuffleId, mapId + 1, attemptId, numMappers, mapId + 1)
shuffleClient.mapPartitionMapperEnd(
shuffleId,
mapId + 1,
attemptId,
numMappers,
numPartitions,
mapId + 1)
}
}

View File

@ -0,0 +1,301 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.client.commit
import org.scalatest.matchers.must.Matchers.{be, include}
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
import org.apache.celeborn.CelebornFunSuite
import org.apache.celeborn.common.CommitMetadata
/** Unit tests for [[PartitionCompletenessValidator]] (legacy map-range mode). */
class PartitionValidatorTest extends CelebornFunSuite {

  // Re-created inside every test so no state leaks between cases.
  var validator: PartitionCompletenessValidator = _
  // Zero-valued metadata for cases where concrete checksum/bytes are irrelevant.
  var mockCommitMetadata: CommitMetadata = new CommitMetadata()

  test("AQEPartitionCompletenessValidator should validate a new sub-partition correctly when there are no overlapping ranges") {
    validator = new PartitionCompletenessValidator
    val (isValid, message) =
      validator.validateSubPartition(1, 0, 10, mockCommitMetadata, mockCommitMetadata, 20, false)
    isValid shouldBe (true)
    message shouldBe ("Partition is valid but still waiting for more data")
  }

  test("AQEPartitionCompletenessValidator should fail validation for overlapping map ranges") {
    validator = new PartitionCompletenessValidator
    // Register [0, 10) first.
    validator.validateSubPartition(1, 0, 10, mockCommitMetadata, mockCommitMetadata, 20, false)
    // [5, 15) overlaps the existing range and must be rejected.
    val (isValid, message) =
      validator.validateSubPartition(1, 5, 15, mockCommitMetadata, mockCommitMetadata, 20, false)
    isValid shouldBe (false)
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  // (renamed) This case asserts ACCEPTANCE: [1, 2) sits exactly between
  // [0, 1) and [2, 3); adjacent ranges do not overlap.
  test(
    "AQEPartitionCompletenessValidator should pass validation for adjacent non-overlapping map ranges") {
    validator = new PartitionCompletenessValidator
    validator.validateSubPartition(1, 0, 1, mockCommitMetadata, mockCommitMetadata, 20, false)
    validator.validateSubPartition(1, 2, 3, mockCommitMetadata, mockCommitMetadata, 20, false)
    val (isValid, _) =
      validator.validateSubPartition(1, 1, 2, mockCommitMetadata, mockCommitMetadata, 20, false)
    isValid shouldBe (true)
  }

  test("AQEPartitionCompletenessValidator should fail validation for overlapping map ranges - edge cases") {
    validator = new PartitionCompletenessValidator
    // [0, 10) registered; [0, 5) shares the same start and must be rejected.
    validator.validateSubPartition(1, 0, 10, mockCommitMetadata, mockCommitMetadata, 20, false)
    val (isValid, message) =
      validator.validateSubPartition(1, 0, 5, mockCommitMetadata, mockCommitMetadata, 20, false)
    isValid shouldBe (false)
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test("AQEPartitionCompletenessValidator should fail validation for one map range subsuming another map range") {
    validator = new PartitionCompletenessValidator
    // [5, 10) registered; [0, 15) fully contains it and must be rejected.
    validator.validateSubPartition(1, 5, 10, mockCommitMetadata, mockCommitMetadata, 20, false)
    val (isValid, message) =
      validator.validateSubPartition(1, 0, 15, mockCommitMetadata, mockCommitMetadata, 20, false)
    isValid shouldBe (false)
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test("AQEPartitionCompletenessValidator should fail validation for one map range subsuming another map range - 2") {
    validator = new PartitionCompletenessValidator
    // [0, 15) registered; [5, 10) is fully contained and must be rejected.
    validator.validateSubPartition(1, 0, 15, mockCommitMetadata, mockCommitMetadata, 20, false)
    val (isValid, message) =
      validator.validateSubPartition(1, 5, 10, mockCommitMetadata, mockCommitMetadata, 20, false)
    isValid shouldBe (false)
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test(
    "AQEPartitionCompletenessValidator should fail validation if commit metadata does not match") {
    val expectedCommitMetadata = new CommitMetadata()
    val actualCommitMetadata = new CommitMetadata()
    validator = new PartitionCompletenessValidator
    // First report for [0, 10).
    validator.validateSubPartition(
      1,
      0,
      10,
      actualCommitMetadata,
      expectedCommitMetadata,
      20,
      false)
    // Retry of the same range with different metadata must be rejected.
    val (isValid, message) =
      validator.validateSubPartition(
        1,
        0,
        10,
        new CommitMetadata(3, 3),
        expectedCommitMetadata,
        20,
        false)
    isValid should be(false)
    message should include("Commit Metadata for partition: 1 not matching for sub-partition")
  }

  test("pass validation if written counts are correct after all updates") {
    validator = new PartitionCompletenessValidator
    val expectedCommitMetadata = new CommitMetadata(0, 10)
    // Two halves, [0, 10) and [10, 20), each contributing 5 bytes.
    validator.validateSubPartition(
      1,
      0,
      10,
      new CommitMetadata(0, 5),
      expectedCommitMetadata,
      20,
      false)
    val (isValid, message) = validator.validateSubPartition(
      1,
      10,
      20,
      new CommitMetadata(0, 5),
      expectedCommitMetadata,
      20,
      false)
    isValid should be(true)
    message should be("Partition is complete")
  }

  test("handle multiple partitions correctly") {
    validator = new PartitionCompletenessValidator
    // Two partitions progress independently; state must not bleed across them.
    val expectedCommitMetadataForPartition1 = new CommitMetadata(0, 10)
    val expectedCommitMetadataForPartition2 = new CommitMetadata(0, 2)
    validator.validateSubPartition(
      1,
      0,
      10,
      new CommitMetadata(0, 5),
      expectedCommitMetadataForPartition1,
      20,
      false)
    validator.validateSubPartition(
      2,
      0,
      5,
      new CommitMetadata(0, 2),
      expectedCommitMetadataForPartition2,
      10,
      false)
    // Duplicate reports of the same ranges with identical metadata are accepted.
    val (isValid1, message1) = validator.validateSubPartition(
      1,
      0,
      10,
      new CommitMetadata(0, 5),
      expectedCommitMetadataForPartition1,
      20,
      false)
    val (isValid2, message2) = validator.validateSubPartition(
      2,
      0,
      5,
      new CommitMetadata(0, 2),
      expectedCommitMetadataForPartition2,
      10,
      false)
    isValid1 should be(true)
    isValid2 should be(true)
    message1 should be("Partition is valid but still waiting for more data")
    message2 should be("Partition is valid but still waiting for more data")
    // Closing ranges complete both partitions.
    val (isValid3, message3) = validator.validateSubPartition(
      1,
      10,
      20,
      new CommitMetadata(0, 5),
      expectedCommitMetadataForPartition1,
      20,
      false)
    val (isValid4, message4) = validator.validateSubPartition(
      2,
      5,
      10,
      new CommitMetadata(),
      expectedCommitMetadataForPartition2,
      10,
      false)
    isValid3 should be(true)
    isValid4 should be(true)
    message3 should be("Partition is complete")
    message4 should be("Partition is complete")
  }
}

View File

@ -0,0 +1,251 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.client.commit
import org.scalatest.matchers.must.Matchers.{be, include}
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
import org.apache.celeborn.CelebornFunSuite
import org.apache.celeborn.common.CommitMetadata
// Unit tests for SkewHandlingWithoutMapRangeValidator -- the driver-side validator used when
// Spark skew handling runs WITHOUT map-range splits. In this mode the reader reports the
// total number of sub-partitions and the index of the sub-partition it read, carried in the
// startMapIndex/endMapIndex slots respectively -- TODO confirm against the validator's
// parameter names.
class SkewHandlingWithoutMapRangeValidatorTest extends CelebornFunSuite {
// Re-created at the start of every test so cases stay isolated.
private var validator: SkewHandlingWithoutMapRangeValidator = _
// First report for a partition is accepted and tracked, but the partition is not
// complete until all sub-partitions have reported.
test("testProcessSubPartitionFirstTime") {
// Setup
validator = new SkewHandlingWithoutMapRangeValidator
val partitionId = 1
val startMapIndex = 10
val endMapIndex = 0
val metadata = new CommitMetadata
val expectedMapperCount = 20
// Execute
val (success, message) = validator.processSubPartition(
partitionId,
startMapIndex,
endMapIndex,
metadata,
expectedMapperCount)
success shouldBe true
message shouldBe ""
validator.isPartitionComplete(partitionId) shouldBe false
// Verify current metadata
val currentMetadata = validator.currentCommitMetadata(partitionId)
currentMetadata shouldNot be(null)
currentMetadata shouldBe metadata
}
// Metadata from different sub-partitions of the same partition must be aggregated
// (bytes summed, checksums combined).
test("testProcessMultipleSubPartitions") {
// Setup
validator = new SkewHandlingWithoutMapRangeValidator
val partitionId = 1
val startMapIndex = 10
val metadata1 = new CommitMetadata(5, 1000)
val metadata2 = new CommitMetadata(3, 500)
// Process first sub-partition (index 0 of 10)
validator.processSubPartition(partitionId, startMapIndex, 0, metadata1, startMapIndex)
// Process second sub-partition (index 1 of 10)
val (success, message) = validator.processSubPartition(
partitionId,
startMapIndex,
1,
metadata2,
startMapIndex)
// Verify
success shouldBe true
message shouldBe ""
validator.isPartitionComplete(partitionId) shouldBe false
// Verify current metadata equals the aggregate of both reports. Note that
// addCommitData mutates metadata1 in place to build the expected value.
val currentMetadata = validator.currentCommitMetadata(partitionId)
currentMetadata shouldNot be(null)
metadata1.addCommitData(metadata2)
currentMetadata shouldBe metadata1
}
// Once every sub-partition index has reported, the partition is complete.
test("testPartitionComplete") {
// Setup - processing all sub-partitions from 0 to 9 for a total of 10
validator = new SkewHandlingWithoutMapRangeValidator
val partitionId = 1
val startMapIndex = 10
// Process all sub-partitions
for (i <- 0 until startMapIndex) {
val metadata = new CommitMetadata()
validator.processSubPartition(partitionId, startMapIndex, i, metadata, startMapIndex)
}
validator.isPartitionComplete(partitionId) shouldBe true
}
// The sub-partition index (endMapIndex slot) must be strictly less than the
// sub-partition count (startMapIndex slot).
test("testInvalidStartEndMapIndex") {
// Setup - invalid case where startMapIndex <= endMapIndex
validator = new SkewHandlingWithoutMapRangeValidator
val partitionId = 1
val startMapIndex = 5
val endMapIndex = 5 // Equal, which should fail
val metadata = new CommitMetadata()
// This should throw IllegalArgumentException
intercept[IllegalArgumentException] {
validator.processSubPartition(partitionId, startMapIndex, endMapIndex, metadata, 20)
}
}
// All reports for one partition must agree on the total sub-partition count.
test("testMismatchSubPartitionCount") {
// Setup
validator = new SkewHandlingWithoutMapRangeValidator
val partitionId = 1
val metadata = new CommitMetadata()
// First call with startMapIndex = 10
validator.processSubPartition(partitionId, 10, 0, metadata, 10)
// Second call with different startMapIndex, which should fail
intercept[IllegalStateException] {
validator.processSubPartition(partitionId, 12, 1, metadata, 12)
}
}
// A duplicate report (e.g. a retried reader) with identical metadata is harmless and
// must be accepted without double-counting.
test("testDuplicateSubPartitionWithSameMetadata") {
// Setup
validator = new SkewHandlingWithoutMapRangeValidator
val partitionId = 1
val startMapIndex = 10
val endMapIndex = 5
val metadata = new CommitMetadata(3, 100)
// Process sub-partition first time
validator.processSubPartition(partitionId, startMapIndex, endMapIndex, metadata, startMapIndex)
// Process same sub-partition again with identical metadata
val (success, message) = validator.processSubPartition(
partitionId,
startMapIndex,
endMapIndex,
metadata,
startMapIndex)
success shouldBe true
message shouldBe ""
validator.isPartitionComplete(partitionId) shouldBe false
}
// A duplicate report with DIFFERENT metadata indicates corruption or a bad retry and
// must fail validation.
test("testDuplicateSubPartitionWithDifferentMetadata") {
// Setup
validator = new SkewHandlingWithoutMapRangeValidator
val partitionId = 1
val startMapIndex = 10
val endMapIndex = 5
val metadata1 = new CommitMetadata(3, 100)
val metadata2 = new CommitMetadata(4, 200)
// Process sub-partition first time
validator.processSubPartition(partitionId, startMapIndex, endMapIndex, metadata1, startMapIndex)
// Process same sub-partition again with different metadata
val (success, message) = validator.processSubPartition(
partitionId,
startMapIndex,
endMapIndex,
metadata2,
startMapIndex)
success shouldBe false
message should include("Mismatch in metadata")
validator.isPartitionComplete(partitionId) shouldBe false
}
// A sub-partition index equal to the declared count is out of range [0, count) and
// must be rejected.
test("testProcessTooManySubPartitions") {
// Setup
validator = new SkewHandlingWithoutMapRangeValidator
val partitionId = 1
val startMapIndex = 10
// Process all sub-partitions
for (i <- 0 until startMapIndex) {
val metadata = new CommitMetadata()
validator.processSubPartition(partitionId, startMapIndex, i, metadata, startMapIndex)
}
// Try to process one more sub-partition, which should exceed the total count
val extraMetadata = new CommitMetadata()
// This should throw IllegalArgumentException (index == count is out of range)
intercept[IllegalArgumentException] {
validator.processSubPartition(
partitionId,
startMapIndex,
startMapIndex,
extraMetadata,
startMapIndex)
}
}
// Several partitions with different sub-partition counts are tracked independently;
// completeness and the aggregated metadata are per-partition.
test("multiple complete partitions") {
// Setup - we'll use 3 different partitions
validator = new SkewHandlingWithoutMapRangeValidator
val partition1 = 1
val partition2 = 2
val partition3 = 3
// Different sizes for each partition
val subPartitions1 = 5
val subPartitions2 = 8
val subPartitions3 = 3
// Process all sub-partitions for partition1
val expectedMetadata1 = new CommitMetadata()
for (i <- 0 until subPartitions1) {
val metadata = new CommitMetadata(i, 100 + i)
validator.processSubPartition(partition1, subPartitions1, i, metadata, subPartitions1)
expectedMetadata1.addCommitData(metadata)
}
// Process all sub-partitions for partition2
val expectedMetadata2 = new CommitMetadata()
for (i <- 0 until subPartitions2) {
val metadata = new CommitMetadata(i + 1, 200 + i)
validator.processSubPartition(partition2, subPartitions2, i, metadata, subPartitions2)
expectedMetadata2.addCommitData(metadata)
}
// Process only some sub-partitions for partition3
val expectedMetadata3 = new CommitMetadata()
for (i <- 0 until subPartitions3 - 1) { // Deliberately leave one out
val metadata = new CommitMetadata(i + 2, 300 + i)
validator.processSubPartition(partition3, subPartitions3, i, metadata, subPartitions3)
expectedMetadata3.addCommitData(metadata)
}
// Verify completion status for each partition
validator.isPartitionComplete(partition1) shouldBe true
validator.isPartitionComplete(partition2) shouldBe true
validator.isPartitionComplete(partition3) shouldBe false
validator.currentCommitMetadata(partition1) shouldBe expectedMetadata1
validator.currentCommitMetadata(partition2) shouldBe expectedMetadata2
validator.currentCommitMetadata(partition3) shouldBe expectedMetadata3
}
}

View File

@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.CRC32;
public class CelebornCRC32 {
private final AtomicInteger current;
CelebornCRC32(int i) {
this.current = new AtomicInteger(i);
}
CelebornCRC32() {
this(0);
}
static int compute(byte[] bytes) {
CRC32 hashFunction = new CRC32();
hashFunction.update(bytes);
return (int) hashFunction.getValue();
}
static int compute(byte[] bytes, int offset, int length) {
CRC32 hashFunction = new CRC32();
hashFunction.update(bytes, offset, length);
return (int) hashFunction.getValue();
}
static int combine(int first, int second) {
first =
(((byte) second + (byte) first) & 0xFF)
| ((((byte) (second >> 8) + (byte) (first >> 8)) & 0xFF) << 8)
| ((((byte) (second >> 16) + (byte) (first >> 16)) & 0xFF) << 16)
| (((byte) (second >> 24) + (byte) (first >> 24)) << 24);
return first;
}
public void addChecksum(int checksum) {
while (true) {
int val = current.get();
int newVal = combine(checksum, val);
if (current.compareAndSet(val, newVal)) {
break;
}
}
}
void addData(byte[] bytes, int offset, int length) {
addChecksum(compute(bytes, offset, length));
}
int get() {
return current.get();
}
@Override
public String toString() {
return "CelebornCRC32{" + "current=" + current + '}';
}
@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
CelebornCRC32 that = (CelebornCRC32) o;
return current.get() == that.current.get();
}
@Override
public int hashCode() {
return Objects.hashCode(current);
}
}

View File

@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;
/**
 * Thread-safe accumulator of shuffle commit statistics for a partition: the total number of
 * bytes written plus an order-independent CRC32 aggregate (see {@link CelebornCRC32}).
 * Writers fold raw data chunks in via {@link #addDataWithOffsetAndLength}; the driver merges
 * per-mapper instances with {@link #addCommitData} and compares writer-side and reader-side
 * aggregates for end-to-end integrity checks.
 */
public class CommitMetadata {
  // Running byte count for the partition.
  private final AtomicLong bytes;
  // Running order-independent CRC32 aggregate for the partition.
  private final CelebornCRC32 crc;

  public CommitMetadata() {
    this.bytes = new AtomicLong();
    this.crc = new CelebornCRC32();
  }

  public CommitMetadata(int checksum, long numBytes) {
    this.bytes = new AtomicLong(numBytes);
    this.crc = new CelebornCRC32(checksum);
  }

  /** Folds {@code length} bytes of {@code rawDataBuf} starting at {@code offset} into this record. */
  public void addDataWithOffsetAndLength(byte[] rawDataBuf, int offset, int length) {
    this.bytes.addAndGet(length);
    this.crc.addData(rawDataBuf, offset, length);
  }

  /** Merges another partial aggregate into this one; merge order does not matter. */
  public void addCommitData(CommitMetadata commitMetadata) {
    addCommitData(commitMetadata.getChecksum(), commitMetadata.getBytes());
  }

  /** Adds a pre-computed (checksum, byte count) contribution to this aggregate. */
  public void addCommitData(int checksum, long numBytes) {
    this.bytes.addAndGet(numBytes);
    this.crc.addChecksum(checksum);
  }

  public int getChecksum() {
    return crc.get();
  }

  public long getBytes() {
    return bytes.get();
  }

  /** Returns true iff both the byte counts and the aggregated checksums match. */
  public static boolean checkCommitMetadata(CommitMetadata expected, CommitMetadata actual) {
    return expected.getBytes() == actual.getBytes()
        && expected.getChecksum() == actual.getChecksum();
  }

  @Override
  public String toString() {
    return "CommitMetadata{" + "bytes=" + bytes.get() + ", crc=" + crc + '}';
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    CommitMetadata that = (CommitMetadata) o;
    return bytes.get() == that.bytes.get() && Objects.equals(crc, that.crc);
  }

  @Override
  public int hashCode() {
    // Hash the VALUES, not the containers: AtomicLong does not override hashCode(), so the
    // previous Objects.hash(bytes, crc) was identity-based for the byte count and therefore
    // inconsistent with equals() -- equal instances could hash differently.
    return Objects.hash(getBytes(), getChecksum());
  }
}

View File

@ -92,7 +92,8 @@ public enum StatusCode {
SEGMENT_START_FAIL_REPLICA(52),
SEGMENT_START_FAIL_PRIMARY(53),
NO_SPLIT(54),
WORKER_UNRESPONSIVE(55);
WORKER_UNRESPONSIVE(55),
READ_REDUCER_PARTITION_END_FAILED(56);
private final byte value;

View File

@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.CommitMetadata;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.util.JavaUtils;
@ -33,6 +34,9 @@ public class PushState {
private final int pushBufferMaxSize;
public AtomicReference<IOException> exception = new AtomicReference<>();
private final InFlightRequestTracker inFlightRequestTracker;
// partition id -> CommitMetadata
private final ConcurrentHashMap<Integer, CommitMetadata> commitMetadataMap =
new ConcurrentHashMap<>();
private final Map<String, LocationPushFailedBatches> failedBatchMap;
@ -102,4 +106,34 @@ public class PushState {
public Map<String, LocationPushFailedBatches> getFailedBatches() {
return this.failedBatchMap;
}
/**
 * Snapshots this mapper's per-partition CRC32 aggregates, indexed by partition id, for
 * inclusion in the MapperEnd report. Returns an empty array when integrity checks are
 * disabled so the report stays cheap.
 *
 * NOTE(review): assumes every key in commitMetadataMap is < numPartitions; an out-of-range
 * partition id would throw ArrayIndexOutOfBoundsException -- confirm callers guarantee this.
 */
public int[] getCRC32PerPartition(boolean shuffleIntegrityCheckEnabled, int numPartitions) {
if (!shuffleIntegrityCheckEnabled) {
return new int[0];
}
int[] crc32PerPartition = new int[numPartitions];
for (Map.Entry<Integer, CommitMetadata> entry : commitMetadataMap.entrySet()) {
// Partitions this mapper never wrote to keep the default value 0.
crc32PerPartition[entry.getKey()] = entry.getValue().getChecksum();
}
return crc32PerPartition;
}
/**
 * Snapshots this mapper's per-partition written byte counts, indexed by partition id, for
 * inclusion in the MapperEnd report. Returns an empty array when integrity checks are
 * disabled.
 *
 * NOTE(review): like getCRC32PerPartition, assumes every tracked partition id is
 * < numPartitions -- confirm callers guarantee this.
 */
public long[] getBytesWrittenPerPartition(
boolean shuffleIntegrityCheckEnabled, int numPartitions) {
if (!shuffleIntegrityCheckEnabled) {
return new long[0];
}
long[] bytesWrittenPerPartition = new long[numPartitions];
for (Map.Entry<Integer, CommitMetadata> entry : commitMetadataMap.entrySet()) {
// Partitions this mapper never wrote to keep the default value 0.
bytesWrittenPerPartition[entry.getKey()] = entry.getValue().getBytes();
}
return bytesWrittenPerPartition;
}
/**
 * Folds a pushed data chunk into the partition's running commit metadata (byte count and
 * CRC32), creating the accumulator on first use. Safe for concurrent pushers: the map is a
 * ConcurrentHashMap and CommitMetadata's internal counters are atomic.
 */
public void addDataWithOffsetAndLength(int partitionId, byte[] data, int offset, int length) {
CommitMetadata commitMetadata =
commitMetadataMap.computeIfAbsent(partitionId, id -> new CommitMetadata());
commitMetadata.addDataWithOffsetAndLength(data, offset, length);
}
}

View File

@ -114,6 +114,8 @@ enum MessageType {
PUSH_MERGED_DATA_SPLIT_PARTITION_INFO = 91;
GET_STAGE_END = 92;
GET_STAGE_END_RESPONSE = 93;
READ_REDUCER_PARTITION_END = 94;
READ_REDUCER_PARTITION_END_RESPONSE = 95;
}
enum StreamType {
@ -360,6 +362,9 @@ message PbMapperEnd {
int32 numMappers = 4;
int32 partitionId = 5;
map<string, PbLocationPushFailedBatches> pushFailureBatches= 6;
int32 numPartitions = 7;
repeated int32 crc32PerPartition = 8;
repeated int64 bytesWrittenPerPartition = 9;
}
message PbLocationPushFailedBatches {
@ -920,3 +925,17 @@ message PbGetStageEndResponse {
message PbChunkOffsets {
repeated int64 chunkOffset = 1;
}
// Reader-side report sent to the driver when a CelebornInputStream finishes reading a
// reducer partition (or a sub-range of it).
message PbReadReducerPartitionEnd {
int32 shuffleId = 1;
int32 partitionId = 2;
// NOTE(review): field name looks like a typo for "startMapIndex" (the Scala case class
// field is startMapIndex). Renaming is wire-compatible -- protobuf identifies fields by
// number -- but requires updating the generated getStartMaxIndex() call sites.
int32 startMaxIndex = 3;
int32 endMapIndex = 4;
int32 crc32 = 5;
int64 bytesWritten = 6;
}
// Driver's acknowledgement for PbReadReducerPartitionEnd.
message PbReadReducerPartitionEndResponse {
// Presumably a StatusCode value (the Scala response wraps a StatusCode) -- confirm.
int32 status = 1;
string errorMsg = 2;
}

View File

@ -944,6 +944,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
get(CLIENT_EXCLUDE_PEER_WORKER_ON_FAILURE_ENABLED)
def clientMrMaxPushData: Long = get(CLIENT_MR_PUSH_DATA_MAX)
def clientApplicationUUIDSuffixEnabled: Boolean = get(CLIENT_APPLICATION_UUID_SUFFIX_ENABLED)
def clientShuffleIntegrityCheckEnabled: Boolean =
get(CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED)
def appUniqueIdWithUUIDSuffix(appId: String): String = {
if (clientApplicationUUIDSuffixEnabled) {
@ -5358,6 +5360,14 @@ object CelebornConf extends Logging {
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("512k")
val CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.shuffle.integrityCheck.enabled")
.categories("client")
.version("0.6.1")
.doc("When `true`, enables end-to-end integrity checks for Spark workloads.")
.booleanConf
.createWithDefault(false)
val SPARK_SHUFFLE_WRITER_MODE: ConfigEntry[String] =
buildConf("celeborn.client.spark.shuffle.writer")
.withAlternative("celeborn.shuffle.writer")

View File

@ -19,9 +19,11 @@ package org.apache.celeborn.common.protocol.message
import java.util
import java.util.{Collections, UUID}
import java.util.concurrent.atomic.AtomicIntegerArray
import scala.collection.JavaConverters._
import com.google.common.base.Preconditions.checkState
import com.google.protobuf.ByteString
import org.roaringbitmap.RoaringBitmap
@ -274,11 +276,25 @@ object ControlMessages extends Logging {
attemptId: Int,
numMappers: Int,
partitionId: Int,
failedBatchSet: util.Map[String, LocationPushFailedBatches])
failedBatchSet: util.Map[String, LocationPushFailedBatches],
numPartitions: Int,
crc32PerPartition: Array[Int],
bytesWrittenPerPartition: Array[Long])
extends MasterMessage
case class ReadReducerPartitionEnd(
shuffleId: Int,
partitionId: Int,
startMapIndex: Int,
endMapIndex: Int,
crc32: Int,
bytesWritten: Long)
extends MasterMessage
case class MapperEndResponse(status: StatusCode) extends MasterMessage
case class ReadReducerPartitionEndResponse(status: StatusCode) extends MasterMessage
case class GetReducerFileGroup(
shuffleId: Int,
isSegmentGranularityVisible: Boolean,
@ -606,6 +622,12 @@ object ControlMessages extends Logging {
case pb: PbReviseLostShufflesResponse =>
new TransportMessage(MessageType.REVISE_LOST_SHUFFLES_RESPONSE, pb.toByteArray)
case pb: PbReadReducerPartitionEnd =>
new TransportMessage(MessageType.READ_REDUCER_PARTITION_END, pb.toByteArray)
case pb: PbReadReducerPartitionEndResponse =>
new TransportMessage(MessageType.READ_REDUCER_PARTITION_END_RESPONSE, pb.toByteArray)
case pb: PbReportBarrierStageAttemptFailure =>
new TransportMessage(MessageType.REPORT_BARRIER_STAGE_ATTEMPT_FAILURE, pb.toByteArray)
@ -739,7 +761,16 @@ object ControlMessages extends Logging {
case pb: PbChangeLocationResponse =>
new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray)
case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) =>
case MapperEnd(
shuffleId,
mapId,
attemptId,
numMappers,
partitionId,
pushFailedBatch,
numPartitions,
crc32PerPartition,
bytesWrittenPerPartition) =>
val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v)
(k, resultValue)
@ -751,6 +782,10 @@ object ControlMessages extends Logging {
.setNumMappers(numMappers)
.setPartitionId(partitionId)
.putAllPushFailureBatches(pushFailedMap)
.setNumPartitions(numPartitions)
.addAllCrc32PerPartition(crc32PerPartition.map(Integer.valueOf).toSeq.asJava)
.addAllBytesWrittenPerPartition(bytesWrittenPerPartition.map(
java.lang.Long.valueOf).toSeq.asJava)
.build().toByteArray
new TransportMessage(MessageType.MAPPER_END, payload)
@ -1176,6 +1211,15 @@ object ControlMessages extends Logging {
case MAPPER_END_VALUE =>
val pbMapperEnd = PbMapperEnd.parseFrom(message.getPayload)
val partitionCount = pbMapperEnd.getCrc32PerPartitionCount
checkState(partitionCount == pbMapperEnd.getBytesWrittenPerPartitionCount)
val crc32Array = new Array[Int](partitionCount)
val bytesWrittenPerPartitionArray =
new Array[Long](pbMapperEnd.getBytesWrittenPerPartitionCount)
for (i <- 0 until partitionCount) {
crc32Array(i) = pbMapperEnd.getCrc32PerPartition(i)
bytesWrittenPerPartitionArray(i) = pbMapperEnd.getBytesWrittenPerPartition(i)
}
MapperEnd(
pbMapperEnd.getShuffleId,
pbMapperEnd.getMapId,
@ -1185,7 +1229,23 @@ object ControlMessages extends Logging {
pbMapperEnd.getPushFailureBatchesMap.asScala.map {
case (partitionId, pushFailedBatchSet) =>
(partitionId, PbSerDeUtils.fromPbLocationPushFailedBatches(pushFailedBatchSet))
}.toMap.asJava)
}.toMap.asJava,
pbMapperEnd.getNumPartitions,
crc32Array,
bytesWrittenPerPartitionArray)
case READ_REDUCER_PARTITION_END_VALUE =>
val pbReadReducerPartitionEnd = PbReadReducerPartitionEnd.parseFrom(message.getPayload)
ReadReducerPartitionEnd(
pbReadReducerPartitionEnd.getShuffleId,
pbReadReducerPartitionEnd.getPartitionId,
pbReadReducerPartitionEnd.getStartMaxIndex,
pbReadReducerPartitionEnd.getEndMapIndex,
pbReadReducerPartitionEnd.getCrc32,
pbReadReducerPartitionEnd.getBytesWritten)
case READ_REDUCER_PARTITION_END_RESPONSE_VALUE =>
PbReadReducerPartitionEndResponse.parseFrom(message.getPayload)
case MAPPER_END_RESPONSE_VALUE =>
val pbMapperEndResponse = PbMapperEndResponse.parseFrom(message.getPayload)

View File

@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import org.junit.Test;
// Test data is generated from https://crccalc.com/ using "CRC-32/ISO-HDLC".
public class CelebornCRC32Test {

  /** "test" has a known CRC-32/ISO-HDLC value; compare in unsigned space. */
  @Test
  public void testCompute() {
    int checksum = CelebornCRC32.compute("test".getBytes());
    assertEquals(3632233996L, checksum & 0xFFFFFFFFL);
  }

  /** The offset/length overload must hash only the selected window ("data"). */
  @Test
  public void testComputeWithOffset() {
    byte[] buffer = "testdata".getBytes();
    assertEquals(2918445923L, CelebornCRC32.compute(buffer, 4, 4) & 0xFFFFFFFFL);
  }

  /** Combining two distinct checksums yields a value different from both inputs. */
  @Test
  public void testCombine() {
    int left = 123456789;
    int right = 987654321;
    int folded = CelebornCRC32.combine(left, right);
    assertNotEquals(left, folded);
    assertNotEquals(right, folded);
  }

  /** Folding into a fresh (zero) accumulator yields the checksum itself. */
  @Test
  public void testAddChecksum() {
    CelebornCRC32 accumulator = new CelebornCRC32();
    accumulator.addChecksum(123456789);
    assertEquals(123456789, accumulator.get());
  }

  /** addData with offset/length behaves like compute on the same window. */
  @Test
  public void testAddDataWithOffset() {
    byte[] buffer = "testdata".getBytes();
    CelebornCRC32 accumulator = new CelebornCRC32();
    accumulator.addData(buffer, 4, 4);
    assertEquals(2918445923L, accumulator.get() & 0xFFFFFFFFL);
  }

  @Test
  public void testToString() {
    assertEquals("CelebornCRC32{current=123456789}", new CelebornCRC32(123456789).toString());
  }
}

View File

@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common;
import static org.junit.Assert.assertEquals;
import org.junit.Assert;
import org.junit.Test;
public class CommitMetadataTest {

  /** Only the [offset, offset + length) window contributes to bytes and checksum. */
  @Test
  public void testAddDataWithOffsetAndLength() {
    byte[] buffer = "testdata".getBytes();
    CommitMetadata commitMetadata = new CommitMetadata();
    commitMetadata.addDataWithOffsetAndLength(buffer, 4, 4);
    assertEquals(4, commitMetadata.getBytes());
    assertEquals(CelebornCRC32.compute(buffer, 4, 4), commitMetadata.getChecksum());
  }

  /** Merging two aggregates sums the byte counts and combines the checksums. */
  @Test
  public void testAddCommitData() {
    byte[] first = "test".getBytes();
    byte[] second = "data".getBytes();
    CommitMetadata target = new CommitMetadata();
    target.addDataWithOffsetAndLength(first, 0, first.length);
    CommitMetadata source = new CommitMetadata();
    source.addDataWithOffsetAndLength(second, 0, second.length);
    target.addCommitData(source);
    // target now holds the combined byte count and checksum of both buffers
    assertEquals(first.length + second.length, target.getBytes());
    assertEquals(
        CelebornCRC32.combine(CelebornCRC32.compute(first), CelebornCRC32.compute(second)),
        target.getChecksum());
  }

  /** checkCommitMetadata must require BOTH the bytes and the checksum to match. */
  @Test
  public void testCheckCommitMetadata() {
    CommitMetadata expected = new CommitMetadata(CelebornCRC32.compute("testdata".getBytes()), 8);
    CommitMetadata matching = new CommitMetadata(CelebornCRC32.compute("testdata".getBytes()), 8);
    CommitMetadata wrongBytesOnly =
        new CommitMetadata(CelebornCRC32.compute("testdata".getBytes()), 12);
    CommitMetadata wrongChecksumOnly = new CommitMetadata(CelebornCRC32.compute("foo".getBytes()), 8);
    CommitMetadata wrongBoth = new CommitMetadata(CelebornCRC32.compute("bar".getBytes()), 16);
    Assert.assertTrue(CommitMetadata.checkCommitMetadata(expected, matching));
    Assert.assertFalse(CommitMetadata.checkCommitMetadata(expected, wrongBytesOnly));
    Assert.assertFalse(CommitMetadata.checkCommitMetadata(expected, wrongChecksumOnly));
    Assert.assertFalse(CommitMetadata.checkCommitMetadata(expected, wrongBoth));
  }
}

View File

@ -20,6 +20,9 @@ package org.apache.celeborn.common.util
import java.util
import java.util.Collections
import org.scalatest.matchers.must.Matchers.contain
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
import org.apache.celeborn.CelebornFunSuite
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.client.{MasterEndpointResolver, StaticMasterEndpointResolver}
@ -145,10 +148,19 @@ class UtilsSuite extends CelebornFunSuite {
}
test("MapperEnd class convert with pb") {
val mapperEnd = MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap())
val mapperEnd =
MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap(), 1, Array.emptyIntArray, Array.emptyLongArray)
val mapperEndTrans =
Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
assert(mapperEnd == mapperEndTrans)
assert(mapperEnd.shuffleId == mapperEndTrans.shuffleId)
assert(mapperEnd.mapId == mapperEndTrans.mapId)
assert(mapperEnd.attemptId == mapperEndTrans.attemptId)
assert(mapperEnd.numMappers == mapperEndTrans.numMappers)
assert(mapperEnd.partitionId == mapperEndTrans.partitionId)
assert(mapperEnd.failedBatchSet == mapperEndTrans.failedBatchSet)
assert(mapperEnd.numPartitions == mapperEndTrans.numPartitions)
mapperEnd.crc32PerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.crc32PerPartition
mapperEnd.bytesWrittenPerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.bytesWrittenPerPartition
}
test("validate HDFS compatible fs path") {

View File

@ -105,6 +105,7 @@ license: |
| celeborn.client.shuffle.dynamicResourceEnabled | false | false | When enabled, the ChangePartitionManager will obtain candidate workers from the availableWorkers pool during heartbeats when worker resource change. | 0.6.0 | |
| celeborn.client.shuffle.dynamicResourceFactor | 0.5 | false | The ChangePartitionManager will check whether (unavailable workers / shuffle allocated workers) is more than the factor before obtaining candidate workers from the requestSlots RPC response when `celeborn.client.shuffle.dynamicResourceEnabled` set true | 0.6.0 | |
| celeborn.client.shuffle.expired.checkInterval | 60s | false | Interval for client to check expired shuffles. | 0.3.0 | celeborn.shuffle.expired.checkInterval |
| celeborn.client.shuffle.integrityCheck.enabled | false | false | When `true`, enables end-to-end integrity checks for Spark workloads. | 0.6.1 | |
| celeborn.client.shuffle.manager.port | 0 | false | Port used by the LifecycleManager on the Driver. | 0.3.0 | celeborn.shuffle.manager.port |
| celeborn.client.shuffle.partition.type | REDUCE | false | Type of shuffle's partition. | 0.3.0 | celeborn.shuffle.partition.type |
| celeborn.client.shuffle.partitionSplit.mode | SOFT | false | soft: the shuffle file size might be larger than split threshold. hard: the shuffle file size will be limited to split threshold. | 0.3.0 | celeborn.shuffle.partitionSplit.mode |

View File

@ -70,8 +70,9 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
res.workerResource,
updateEpoch = false)
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
0 until 10 foreach { partitionId =>
val numPartitions = 10
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions)
0 until numPartitions foreach { partitionId =>
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
}
@ -126,8 +127,9 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
res.workerResource,
updateEpoch = false)
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
0 until 10 foreach { partitionId =>
val numPartitions = 10
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions)
0 until numPartitions foreach { partitionId =>
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
}
@ -196,8 +198,9 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
res.workerResource,
updateEpoch = false)
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
0 until 1000 foreach { partitionId =>
val numPartitions = 1000
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions)
0 until numPartitions foreach { partitionId =>
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
}
@ -256,13 +259,14 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
res.workerResource,
updateEpoch = false)
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
val numPartitions = 3
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions)
val buffer = "hello world".getBytes(StandardCharsets.UTF_8)
var bufferLength = -1
0 until 3 foreach { partitionId =>
0 until numPartitions foreach { partitionId =>
bufferLength =
shuffleClient.pushData(shuffleId, 0, 0, partitionId, buffer, 0, buffer.length, 1, 3)
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)

View File

@ -157,7 +157,7 @@ class LifecycleManagerReserveSlotsSuite extends AnyFunSuite
// push merged data, we expect that partition(0) will be split, while partition(1) will not be split
shuffleClient1.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID)
shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM)
shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM)
// partition(1) will not be split
assert(partitionLocationMap1.get(partitions(1)).getEpoch == 0)

View File

@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.tests.spark
import org.apache.spark._
import org.apache.spark.shuffle.celeborn.{SparkUtils, TestCelebornShuffleManager}
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.slf4j.{Logger, LoggerFactory}
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.protocol.{CompressionCodec, ShuffleMode}
/**
 * End-to-end tests for the client-side shuffle integrity checks
 * (`celeborn.client.shuffle.integrityCheck.enabled`).
 *
 * Each test plants a single-bit corruption in one shuffle data file via
 * [[ShuffleReaderGetHookForCorruptedData]] and then verifies the expected
 * outcome:
 *   - checks disabled: the job succeeds but may return corrupted data;
 *   - checks enabled, no stage rerun: the job aborts with a CommitMetadata
 *     mismatch;
 *   - checks enabled, stage rerun: the stage is retried and the job produces
 *     correct data.
 */
class CelebornIntegrityCheckSuite extends AnyFunSuite
  with SparkTestBase
  with BeforeAndAfterEach {

  private val logger = LoggerFactory.getLogger(classOf[CelebornIntegrityCheckSuite])

  override def beforeEach(): Unit = {
    // Start every test from a clean client so shuffle state cannot leak
    // between tests in the same JVM.
    ShuffleClient.reset()
  }

  override def afterEach(): Unit = {
    System.gc()
  }

  test("celeborn spark integration test - corrupted data, no integrity check - app succeeds but data is different") {
    if (Spark3OrNewer) {
      val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
      val sparkSession = SparkSession.builder()
        .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
        .config("spark.sql.shuffle.partitions", 2)
        .config(
          "spark.shuffle.manager",
          "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
        .config(
          s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}",
          CompressionCodec.NONE.toString)
        .getOrCreate()
      // Introduce data corruption: flip a single bit in one partition
      // location file when the shuffle reader is first created.
      val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
      val hook = new ShuffleReaderGetHookForCorruptedData(celebornConf, workerDirs)
      TestCelebornShuffleManager.registerReaderGetHook(hook)
      try {
        val value = Range(1, 10000).mkString(",")
        val tuples = sparkSession.sparkContext.parallelize(1 to 1000, 2)
          .map { i => (i, value) }.groupByKey(16).collect()
        // Without integrity checks the job is expected to finish even though
        // one partition file was corrupted on disk.
        assert(tuples.length == 1000)
        try {
          for (elem <- tuples) {
            assert(elem._2.mkString(",").equals(value))
          }
        } catch {
          case e: Throwable =>
            // A value mismatch is the expected symptom of the planted
            // corruption. The original code discarded the result of
            // `contains`, silently swallowing any failure; assert it so an
            // unrelated exception cannot slip through unnoticed.
            assert(e.getMessage.contains("elem._2.mkString(\",\").equals(value) was false"))
        }
      } finally {
        // Always stop the session, even when collect() itself throws.
        sparkSession.stop()
      }
    }
  }

  test("celeborn spark integration test - corrupted data, integrity checks enabled, no stage rerun - app fails") {
    if (Spark3OrNewer) {
      val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
      val sparkSession = SparkSession.builder()
        .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
        .config("spark.sql.shuffle.partitions", 2)
        .config(
          s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}",
          CompressionCodec.NONE.toString)
        .config(s"spark.${CelebornConf.CLIENT_STAGE_RERUN_ENABLED.key}", "false")
        .config("spark.celeborn.client.shuffle.integrityCheck.enabled", "true")
        .config(
          "spark.shuffle.manager",
          "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
        .getOrCreate()
      // Introduce data corruption: flip a single bit in one partition
      // location file when the shuffle reader is first created.
      val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
      val hook = new ShuffleReaderGetHookForCorruptedData(celebornConf, workerDirs)
      TestCelebornShuffleManager.registerReaderGetHook(hook)
      try {
        val value = Range(1, 10000).mkString(",")
        val tuples = sparkSession.sparkContext.parallelize(1 to 1000, 2)
          .map { i => (i, value) }.groupByKey(16).collect()
        fail("App should abort prior to this step and throw an exception")
      } catch {
        // With integrity checks on and no stage rerun, the CommitMetadata
        // mismatch must abort the job.
        case e: Throwable => {
          logger.error("Expected exception, logging the full exception", e)
          assert(e.getMessage.contains("Job aborted"))
          assert(e.getMessage.contains("CommitMetadata mismatch"))
        }
      } finally {
        sparkSession.stop()
      }
    }
  }

  test("celeborn spark integration test - corrupted data, integrity checks enabled, stage rerun enabled - app succeeds") {
    if (Spark3OrNewer) {
      val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
      val sparkSession = SparkSession.builder()
        .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
        .config("spark.sql.shuffle.partitions", 2)
        .config(
          s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}",
          CompressionCodec.NONE.toString)
        .config(s"spark.${CelebornConf.CLIENT_STAGE_RERUN_ENABLED.key}", "true")
        .config("spark.celeborn.client.shuffle.integrityCheck.enabled", "true")
        .config(
          "spark.shuffle.manager",
          "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
        .getOrCreate()
      // Introduce data corruption: flip a single bit in one partition
      // location file when the shuffle reader is first created.
      val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
      val hook = new ShuffleReaderGetHookForCorruptedData(celebornConf, workerDirs)
      TestCelebornShuffleManager.registerReaderGetHook(hook)
      try {
        val value = Range(1, 10000).mkString(",")
        val tuples = sparkSession.sparkContext.parallelize(1 to 1000, 2)
          .map { i => (i, value) }.groupByKey(16).collect()
        // The stage rerun must recover from the corruption, so the final
        // result has to be complete and correct.
        assert(tuples.length == 1000)
        for (elem <- tuples) {
          assert(elem._2.mkString(",").equals(value))
        }
      } finally {
        sparkSession.stop()
      }
    }
  }
}

View File

@ -0,0 +1,252 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.tests.spark
import java.io.{File, RandomAccessFile}
import java.util.Random
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.control.Breaks.{break, breakable}
import org.apache.spark.TaskContext
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkCommonUtils, SparkUtils}
import org.slf4j.LoggerFactory
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.unsafe.Platform
/**
 * A [[ShuffleManagerHook]] that, the first time a shuffle reader is obtained,
 * flips a single bit inside one on-disk shuffle data file to simulate data
 * corruption for integrity-check tests.
 *
 * @param conf                 Celeborn client configuration used to obtain the shuffle client.
 * @param workerDirs           root storage directories of the local test workers to scan for data files.
 * @param shuffleIdToBeModified optional explicit Celeborn shuffle ids to corrupt; when empty,
 *                             the shuffle id resolved from the handle/context is used.
 * @param triggerStageId       optional stage id gate; when set, corruption only happens for that stage.
 */
class ShuffleReaderGetHookForCorruptedData(
    conf: CelebornConf,
    workerDirs: Seq[String],
    shuffleIdToBeModified: Seq[Int] = Seq(),
    triggerStageId: Option[Int] = None)
  extends ShuffleManagerHook {

  private val logger = LoggerFactory.getLogger(classOf[ShuffleReaderGetHookForCorruptedData])
  // Ensures the corruption is applied at most once across all reader creations.
  var executed: AtomicBoolean = new AtomicBoolean(false)
  val lock = new Object
  // Number of files corrupted so far (0 or 1).
  var corruptedCount = 0
  val random = new Random()

  /**
   * Picks one data file belonging to the given app/shuffle under the worker
   * directories and flips a single bit in it. No-op if a file was already
   * corrupted by an earlier invocation.
   */
  private def modifyDataFileWithSingleBitFlip(appUniqueId: String, celebornShuffleId: Int): Unit = {
    if (corruptedCount > 0) {
      return
    }
    val minFileSize = 16 // Minimum file size to be considered for corruption
    // Find all potential data files across all worker directories
    val dataFiles =
      workerDirs.flatMap(dir => {
        val shuffleDir =
          new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
        if (shuffleDir.exists()) {
          shuffleDir.listFiles().filter(_.isFile).filter(_.length() >= minFileSize)
        } else {
          Array.empty[File]
        }
      })
    if (dataFiles.isEmpty) {
      logger.error(
        s"No suitable data files found for appUniqueId=$appUniqueId, shuffleId=$celebornShuffleId")
      return
    }
    // Sort files by size (descending) to prioritize larger files that are more likely to have data
    val sortedFiles = dataFiles.sortBy(-_.length())
    logger.info(s"Found ${sortedFiles.length} data files for corruption testing.")
    // Try to corrupt files one by one until successful
    breakable {
      for (file <- sortedFiles) {
        if (tryCorruptFile(file)) {
          corruptedCount = 1
          break
        }
      }
    }
    // If we couldn't corrupt any file through the normal process, use the fallback method
    if (corruptedCount == 0 && sortedFiles.nonEmpty) {
      logger.warn(
        "Could not find a valid data section in any file. Using safer fallback corruption method.")
      val fileToCorrupt = sortedFiles.head // Take the largest file for fallback
      if (fallbackCorruption(fileToCorrupt)) {
        corruptedCount = 1
      }
    }
  }

  /**
   * Try to corrupt a specific file by walking its record headers
   * (mapId/attemptId/batchId/dataSize, 4 ints = 16 bytes) and flipping one
   * bit inside the first valid data section found.
   *
   * @return true if corruption was successful, false otherwise
   */
  private def tryCorruptFile(file: File): Boolean = {
    // Use the shared class-level `random`; the previous local shadow was
    // redundant and inconsistent with fallbackCorruption.
    logger.info(s"Attempting to corrupt file: ${file.getPath()}, size: ${file.length()} bytes")
    val raf = new RandomAccessFile(file, "rw")
    try {
      val fileSize = raf.length()
      // Find a valid data section in the file
      var position: Long = 0
      val headerSize = 16
      breakable {
        while (position + headerSize <= fileSize) {
          raf.seek(position)
          val sizeBuf = new Array[Byte](headerSize)
          raf.readFully(sizeBuf)
          // Parse header
          val mapId: Int = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET)
          val attemptId: Int = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 4)
          val batchId: Int = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 8)
          val dataSize: Int = Platform.getInt(sizeBuf, Platform.BYTE_ARRAY_OFFSET + 12)
          logger.info(s"Found header at position $position: mapId=$mapId, attemptId=$attemptId, " +
            s"batchId=$batchId, dataSize=$dataSize")
          // Validate dataSize - should be positive and within file bounds
          if (dataSize > 0 && position + headerSize + dataSize <= fileSize) {
            val dataStartPosition = position + headerSize
            // Choose a random position within the data section
            val offsetInData = random.nextInt(dataSize)
            val corruptPosition = dataStartPosition + offsetInData
            // Read the byte at the position we want to corrupt
            raf.seek(corruptPosition)
            val originalByte = raf.readByte()
            // Flip one bit in the byte (avoid the sign bit for numeric values)
            val bitToFlip = 1 << random.nextInt(7) // bits 0-6 only, avoiding the sign bit
            val corruptedByte = (originalByte ^ bitToFlip).toByte
            // Write back the corrupted byte
            raf.seek(corruptPosition)
            raf.writeByte(corruptedByte)
            logger.info(s"Successfully corrupted byte in file ${file.getName()} at position " +
              s"$corruptPosition: ${originalByte & 0xFF} -> ${corruptedByte & 0xFF} (flipped bit $bitToFlip)")
            return true // Corruption successful
          }
          // Skip to next record: header size + data size (even if data size is 0)
          position += headerSize + Math.max(0, dataSize)
        }
      }
      false // No valid data section found
    } catch {
      case e: Exception =>
        logger.error(s"Error while attempting to corrupt file ${file.getPath()}", e)
        false
    } finally {
      raf.close()
    }
  }

  /**
   * Fallback method for corruption when we can't find a valid data section.
   * This method simply corrupts a byte in the middle of the file, staying
   * clear of the first bytes to reduce the chance of hitting a header.
   */
  private def fallbackCorruption(file: File): Boolean = {
    val raf = new RandomAccessFile(file, "rw")
    try {
      val fileSize = raf.length()
      // Skip first 64 bytes to avoid headers, and corrupt somewhere in the middle of the file
      val safeStartPosition = Math.min(64, fileSize / 4)
      val safeEndPosition = Math.max(safeStartPosition + 1, fileSize - 16)
      // Ensure we have a valid range
      if (safeEndPosition <= safeStartPosition) {
        logger.error(s"File ${file.getName()} too small for safe corruption: $fileSize bytes")
        return false
      }
      val corruptPosition = safeStartPosition +
        random.nextInt((safeEndPosition - safeStartPosition).toInt)
      raf.seek(corruptPosition)
      val originalByte = raf.readByte()
      val bitToFlip = 1 << random.nextInt(7)
      val corruptedByte = (originalByte ^ bitToFlip).toByte
      raf.seek(corruptPosition)
      raf.writeByte(corruptedByte)
      logger.info(s"Used fallback corruption approach on file ${file.getName()}. " +
        s"Corrupted byte at position $corruptPosition: ${originalByte & 0xFF} -> ${corruptedByte & 0xFF}")
      true
    } catch {
      case e: Exception =>
        logger.error(s"Error during fallback corruption of file ${file.getPath()}", e)
        false
    } finally {
      raf.close()
    }
  }

  /**
   * Invoked by TestCelebornShuffleManager whenever a reader is requested.
   * Resolves the Celeborn shuffle id for the handle and applies the one-time
   * corruption, optionally gated on [[triggerStageId]].
   */
  override def exec(
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): Unit = {
    if (executed.get()) {
      return
    }
    lock.synchronized {
      handle match {
        case h: CelebornShuffleHandle[_, _, _] => {
          val appUniqueId = h.appUniqueId
          val shuffleClient = ShuffleClient.get(
            h.appUniqueId,
            h.lifecycleManagerHost,
            h.lifecycleManagerPort,
            conf,
            h.userIdentifier,
            h.extension)
          val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
          // Identifier is encoded as "<appShuffleId>-<stageId>-<attempt>";
          // NOTE(review): format assumed from encodeAppShuffleIdentifier — confirm.
          val appShuffleIdentifier =
            SparkCommonUtils.encodeAppShuffleIdentifier(handle.shuffleId, context)
          val Array(_, stageId, _) = appShuffleIdentifier.split('-')
          if (triggerStageId.isEmpty || triggerStageId.get == stageId.toInt) {
            if (shuffleIdToBeModified.isEmpty) {
              modifyDataFileWithSingleBitFlip(appUniqueId, celebornShuffleId)
            } else {
              shuffleIdToBeModified.foreach { shuffleId =>
                modifyDataFileWithSingleBitFlip(appUniqueId, shuffleId)
              }
            }
            executed.set(true)
          }
        }
        // Fixed stale message: the supported handle type is CelebornShuffleHandle,
        // not the legacy RssShuffleHandle.
        case x => throw new RuntimeException(s"unexpected, only support CelebornShuffleHandle here," +
          s" but get ${x.getClass.getCanonicalName}")
      }
    }
  }
}

View File

@ -49,70 +49,74 @@ class SkewJoinSuite extends AnyFunSuite
CompressionCodec.values.foreach { codec =>
Seq(false, true).foreach { enabled =>
test(
s"celeborn spark integration test - skew join - with $codec - with client skew $enabled") {
val sparkConf = new SparkConf().setAppName("celeborn-demo")
.setMaster("local[2]")
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
.set("spark.sql.adaptive.skewJoin.enabled", "true")
.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB")
.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1")
.set(SQLConf.PARQUET_COMPRESSION.key, "gzip")
.set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name)
.set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true")
.set(
s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}",
s"$enabled")
Seq(false, true).foreach { integrityChecksEnabled =>
test(
s"celeborn spark integration test - skew join - with $codec - with client skew $enabled - with integrity checks $integrityChecksEnabled") {
val sparkConf = new SparkConf().setAppName("celeborn-demo")
.setMaster("local[2]")
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
.set("spark.sql.adaptive.skewJoin.enabled", "true")
.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB")
.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1")
.set(SQLConf.PARQUET_COMPRESSION.key, "gzip")
.set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name)
.set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true")
.set(
s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}",
s"$enabled")
.set(
s"spark.${CelebornConf.CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED.key}",
s"$integrityChecksEnabled")
enableCeleborn(sparkConf)
enableCeleborn(sparkConf)
val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
if (sparkSession.version.startsWith("3")) {
import sparkSession.implicits._
val df = sparkSession.sparkContext.parallelize(1 to 120000, 8)
.map(i => {
val random = new Random()
val oriKey = random.nextInt(64)
val key = if (oriKey < 32) 1 else oriKey
val fas = random.nextInt(1200000)
val fa = Range(fas, fas + 100).mkString(",")
val fbs = random.nextInt(1200000)
val fb = Range(fbs, fbs + 100).mkString(",")
val fcs = random.nextInt(1200000)
val fc = Range(fcs, fcs + 100).mkString(",")
val fds = random.nextInt(1200000)
val fd = Range(fds, fds + 100).mkString(",")
val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
if (sparkSession.version.startsWith("3")) {
import sparkSession.implicits._
val df = sparkSession.sparkContext.parallelize(1 to 120000, 8)
.map(i => {
val random = new Random()
val oriKey = random.nextInt(64)
val key = if (oriKey < 32) 1 else oriKey
val fas = random.nextInt(1200000)
val fa = Range(fas, fas + 100).mkString(",")
val fbs = random.nextInt(1200000)
val fb = Range(fbs, fbs + 100).mkString(",")
val fcs = random.nextInt(1200000)
val fc = Range(fcs, fcs + 100).mkString(",")
val fds = random.nextInt(1200000)
val fd = Range(fds, fds + 100).mkString(",")
(key, fa, fb, fc, fd)
})
.toDF("fa", "f1", "f2", "f3", "f4")
df.createOrReplaceTempView("view1")
val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8)
.map(i => {
val random = new Random()
val oriKey = random.nextInt(64)
val key = if (oriKey < 32) 1 else oriKey
val fas = random.nextInt(1200000)
val fa = Range(fas, fas + 100).mkString(",")
val fbs = random.nextInt(1200000)
val fb = Range(fbs, fbs + 100).mkString(",")
val fcs = random.nextInt(1200000)
val fc = Range(fcs, fcs + 100).mkString(",")
val fds = random.nextInt(1200000)
val fd = Range(fds, fds + 100).mkString(",")
(key, fa, fb, fc, fd)
})
.toDF("fb", "f6", "f7", "f8", "f9")
df2.createOrReplaceTempView("view2")
sparkSession.sql("drop table if exists fres")
sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ")
sparkSession.sql("drop table fres")
sparkSession.stop()
(key, fa, fb, fc, fd)
})
.toDF("fa", "f1", "f2", "f3", "f4")
df.createOrReplaceTempView("view1")
val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8)
.map(i => {
val random = new Random()
val oriKey = random.nextInt(64)
val key = if (oriKey < 32) 1 else oriKey
val fas = random.nextInt(1200000)
val fa = Range(fas, fas + 100).mkString(",")
val fbs = random.nextInt(1200000)
val fb = Range(fbs, fbs + 100).mkString(",")
val fcs = random.nextInt(1200000)
val fc = Range(fcs, fcs + 100).mkString(",")
val fds = random.nextInt(1200000)
val fd = Range(fds, fds + 100).mkString(",")
(key, fa, fb, fc, fd)
})
.toDF("fb", "f6", "f7", "f8", "f9")
df2.createOrReplaceTempView("view2")
sparkSession.sql("drop table if exists fres")
sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ")
sparkSession.sql("drop table fres")
sparkSession.stop()
}
}
}
}
}
}

View File

@ -109,7 +109,7 @@ trait JavaReadCppWriteTestBase extends AnyFunSuite
}
}
shuffleClient.pushMergedData(shuffleId, mapId, attemptId)
shuffleClient.mapperEnd(shuffleId, mapId, attemptId, numMappers)
shuffleClient.mapperEnd(shuffleId, mapId, attemptId, numMappers, numPartitions)
}
// Launch cpp reader to read data, calculate result and write to specific result file.

View File

@ -116,7 +116,7 @@ class LocalReadByChunkOffsetsTest extends AnyFunSuite
shuffleClient.pushMergedData(1, 0, 0)
Thread.sleep(1000)
shuffleClient.mapperEnd(1, 0, 0, 1)
shuffleClient.mapperEnd(1, 0, 0, 1, 0)
val metricsCallback = new MetricsCallback {
override def incBytesRead(bytesWritten: Long): Unit = {}

View File

@ -153,7 +153,7 @@ class PushMergedDataSplitSuite extends AnyFunSuite
// push merged data, we expect that partition(0) will be split, while partition(1) will not be split
shuffleClient.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID)
shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM)
shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM)
assert(
partitionLocationMap.get(partitions(1)).getEpoch == 0
) // means partition(1) will not be split

View File

@ -98,7 +98,7 @@ trait ReadWriteTestBase extends AnyFunSuite
shuffleClient.pushMergedData(1, 0, 0)
Thread.sleep(1000)
shuffleClient.mapperEnd(1, 0, 0, 1)
shuffleClient.mapperEnd(1, 0, 0, 1, 1)
val metricsCallback = new MetricsCallback {
override def incBytesRead(bytesWritten: Long): Unit = {}

View File

@ -92,7 +92,7 @@ class ReadWriteTestWithFailures extends AnyFunSuite
shuffleClient.pushMergedData(1, 0, 0)
Thread.sleep(1000)
shuffleClient.mapperEnd(1, 0, 0, 1)
shuffleClient.mapperEnd(1, 0, 0, 1, 1)
var duplicateBytesRead = new AtomicLong(0)
val metricsCallback = new MetricsCallback {