[CELEBORN-894] End to End Integrity Checks
### What changes were proposed in this pull request? Design doc - https://docs.google.com/document/d/1YqK0kua-5rMufJw57kEIrHHGbLnAF9iXM5GdDweMzzg/edit?tab=t.0#heading=h.n5ldma432qnd - End-to-end integrity checks provide additional confidence that Celeborn is producing complete as well as correct data - The checks are hidden behind a client-side config that is false by default. This gives users the option to enable them if required on a per-app basis - Only compatible with Spark at the moment - No support for Flink (can be considered in the future) - No support for Columnar Shuffle (can be considered in the future) Writer - Whenever a mapper completes, it reports the CRC32 and bytes written on a per-partition basis to the driver Driver - The driver aggregates the mapper reports and computes the aggregated CRC32 and bytes written on a per-partitionID basis Reader - Each CelebornInputStream will report (int shuffleId, int partitionId, int startMapIndex, int endMapIndex, int crc32, long bytes) to the driver when it finishes reading all data on the stream - On every report - The driver will aggregate the CRC32 and bytesRead for the partitionID - The driver will aggregate the map ranges to determine when all sub-partitions for the partitionID have been read - It will then compare the aggregated CRC32 and bytes read with the expected CRC32 and bytes written for the partition - There is special handling for the skew-handling-without-map-range-split scenario as well - In this case, we report the number of sub-partitions and the index of the sub-partition instead of startMapIndex and endMapIndex There is separate handling for skew handling with and without map-range split As a follow-up, I will do another PR that will harden the checks and add bookkeeping to verify that every CelebornInputStream performs the required checks ### Why are the changes needed?
https://issues.apache.org/jira/browse/CELEBORN-894 Note: I am putting up this PR even though some tests are failing, since I want to get some early feedback on the code changes. ### Does this PR introduce _any_ user-facing change? Not sure how to answer this. A new client side config is available to enable the checks if required ### How was this patch tested? Unit tests + Integration tests Closes #3261 from gauravkm/gaurav/e2e_checks_v3. Lead-authored-by: Gaurav Mittal <gaurav@stripe.com> Co-authored-by: Gaurav Mittal <gauravkm@gmail.com> Co-authored-by: Fei Wang <cn.feiwang@gmail.com> Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
This commit is contained in:
parent
7a0eee332a
commit
cde33d953b
@ -197,7 +197,7 @@ public class RemoteShuffleOutputGate {
|
||||
/** Indicates the writing/spilling is finished. */
|
||||
public void finish() throws InterruptedException, IOException {
|
||||
flinkShuffleClient.mapPartitionMapperEnd(
|
||||
shuffleId, mapId, attemptId, numMappers, partitionLocation.getId());
|
||||
shuffleId, mapId, attemptId, numMappers, numSubs, partitionLocation.getId());
|
||||
}
|
||||
|
||||
/** Close the transportation gate. */
|
||||
|
||||
@ -107,7 +107,7 @@ public class RemoteShuffleOutputGateSuiteJ {
|
||||
|
||||
doNothing()
|
||||
.when(remoteShuffleOutputGate.flinkShuffleClient)
|
||||
.mapperEnd(anyInt(), anyInt(), anyInt(), anyInt());
|
||||
.mapperEnd(anyInt(), anyInt(), anyInt(), anyInt(), anyInt());
|
||||
remoteShuffleOutputGate.finish();
|
||||
|
||||
doNothing().when(remoteShuffleOutputGate.flinkShuffleClient).shutdown();
|
||||
|
||||
@ -234,7 +234,12 @@ public class CelebornTierProducerAgent implements TierProducerAgent {
|
||||
try {
|
||||
if (hasRegisteredShuffle && partitionLocation != null) {
|
||||
flinkShuffleClient.mapPartitionMapperEnd(
|
||||
shuffleId, mapId, attemptId, numPartitions, partitionLocation.getId());
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
numPartitions,
|
||||
numSubPartitions,
|
||||
partitionLocation.getId());
|
||||
}
|
||||
} catch (Exception e) {
|
||||
Utils.rethrowAsRuntimeException(e);
|
||||
|
||||
@ -234,7 +234,12 @@ public class CelebornTierProducerAgent implements TierProducerAgent {
|
||||
try {
|
||||
if (hasRegisteredShuffle && partitionLocation != null) {
|
||||
flinkShuffleClient.mapPartitionMapperEnd(
|
||||
shuffleId, mapId, attemptId, numPartitions, partitionLocation.getId());
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
numPartitions,
|
||||
numSubPartitions,
|
||||
partitionLocation.getId());
|
||||
}
|
||||
} catch (Exception e) {
|
||||
Utils.rethrowAsRuntimeException(e);
|
||||
|
||||
@ -316,7 +316,7 @@ public class CelebornSortBasedPusher<K, V> extends OutputStream {
|
||||
mapId,
|
||||
attempt,
|
||||
numMappers);
|
||||
shuffleClient.mapperEnd(0, mapId, attempt, numMappers);
|
||||
shuffleClient.mapperEnd(0, mapId, attempt, numMappers, numReducers);
|
||||
} catch (IOException e) {
|
||||
exception.compareAndSet(null, e);
|
||||
}
|
||||
|
||||
@ -369,7 +369,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
|
||||
sendOffsets = null;
|
||||
|
||||
long waitStartTime = System.nanoTime();
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions);
|
||||
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
|
||||
|
||||
BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
|
||||
|
||||
@ -316,7 +316,7 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
|
||||
updateMapStatus();
|
||||
|
||||
long waitStartTime = System.nanoTime();
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions);
|
||||
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
|
||||
}
|
||||
|
||||
|
||||
@ -378,7 +378,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
|
||||
updateRecordsWrittenMetrics();
|
||||
|
||||
long waitStartTime = System.nanoTime();
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions);
|
||||
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
|
||||
|
||||
BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
|
||||
|
||||
@ -381,7 +381,7 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
|
||||
writeMetrics.incRecordsWritten(tmpRecordsWritten);
|
||||
|
||||
long waitStartTime = System.nanoTime();
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers, numPartitions);
|
||||
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
|
||||
}
|
||||
|
||||
|
||||
@ -120,7 +120,7 @@ public class CelebornTezWriter {
|
||||
try {
|
||||
dataPusher.waitOnTermination();
|
||||
shuffleClient.pushMergedData(shuffleId, mapId, attemptNumber);
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, attemptNumber, numMappers);
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, attemptNumber, numMappers, numPartitions);
|
||||
} catch (InterruptedException e) {
|
||||
throw new IOInterruptedException(e);
|
||||
}
|
||||
|
||||
@ -115,11 +115,17 @@ public class DummyShuffleClient extends ShuffleClient {
|
||||
public void pushMergedData(int shuffleId, int mapId, int attemptId) {}
|
||||
|
||||
@Override
|
||||
public void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers) {}
|
||||
public void mapperEnd(
|
||||
int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions) {}
|
||||
|
||||
@Override
|
||||
public void readReducerPartitionEnd(
|
||||
int shuffleId, int partitionId, int startMapIndex, int endMapIndex, int crc32, long bytes)
|
||||
throws IOException {}
|
||||
|
||||
@Override
|
||||
public void mapPartitionMapperEnd(
|
||||
int shuffleId, int mapId, int attemptId, int numMappers, int partitionId)
|
||||
int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions, int partitionId)
|
||||
throws IOException {}
|
||||
|
||||
@Override
|
||||
|
||||
@ -203,13 +203,19 @@ public abstract class ShuffleClient {
|
||||
|
||||
public abstract void pushMergedData(int shuffleId, int mapId, int attemptId) throws IOException;
|
||||
|
||||
// Report partition locations written by the completed map task of ReducePartition Shuffle Type
|
||||
public abstract void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers)
|
||||
// Report partition locations written by the completed map task of ReducePartition Shuffle Type.
|
||||
public abstract void mapperEnd(
|
||||
int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions)
|
||||
throws IOException;
|
||||
|
||||
// Report partition locations written by the completed map task of MapPartition Shuffle Type
|
||||
public abstract void readReducerPartitionEnd(
|
||||
int shuffleId, int partitionId, int startMapIndex, int endMapIndex, int crc32, long bytes)
|
||||
throws IOException;
|
||||
|
||||
// Report partition locations written by the completed map task of MapPartition Shuffle Type.
|
||||
public abstract void mapPartitionMapperEnd(
|
||||
int shuffleId, int mapId, int attemptId, int numMappers, int partitionId) throws IOException;
|
||||
int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions, int partitionId)
|
||||
throws IOException;
|
||||
|
||||
// Cleanup states of the map task
|
||||
public abstract void cleanup(int shuffleId, int mapId, int attemptId);
|
||||
|
||||
@ -50,10 +50,7 @@ import org.apache.celeborn.common.identity.UserIdentifier;
|
||||
import org.apache.celeborn.common.metrics.source.Role;
|
||||
import org.apache.celeborn.common.network.TransportContext;
|
||||
import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
|
||||
import org.apache.celeborn.common.network.client.RpcResponseCallback;
|
||||
import org.apache.celeborn.common.network.client.TransportClient;
|
||||
import org.apache.celeborn.common.network.client.TransportClientBootstrap;
|
||||
import org.apache.celeborn.common.network.client.TransportClientFactory;
|
||||
import org.apache.celeborn.common.network.client.*;
|
||||
import org.apache.celeborn.common.network.protocol.*;
|
||||
import org.apache.celeborn.common.network.protocol.SerdeVersion;
|
||||
import org.apache.celeborn.common.network.sasl.SaslClientBootstrap;
|
||||
@ -122,6 +119,8 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
|
||||
private final boolean pushExcludeWorkerOnFailureEnabled;
|
||||
private final boolean shuffleCompressionEnabled;
|
||||
private final boolean shuffleIntegrityCheckEnabled;
|
||||
|
||||
private final Set<String> pushExcludedWorkers = ConcurrentHashMap.newKeySet();
|
||||
private final ConcurrentHashMap<String, Long> fetchExcludedWorkers =
|
||||
JavaUtils.newConcurrentHashMap();
|
||||
@ -203,6 +202,7 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
shuffleCompressionEnabled = !conf.shuffleCompressionCodec().equals(CompressionCodec.NONE);
|
||||
pushReplicateEnabled = conf.clientPushReplicateEnabled();
|
||||
fetchExcludeWorkerOnFailureEnabled = conf.clientFetchExcludeWorkerOnFailureEnabled();
|
||||
shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled();
|
||||
if (conf.clientPushReplicateEnabled()) {
|
||||
pushDataTimeout = conf.pushDataTimeoutMs() * 2;
|
||||
} else {
|
||||
@ -648,6 +648,35 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readReducerPartitionEnd(
|
||||
int shuffleId,
|
||||
int partitionId,
|
||||
int startMapIndex,
|
||||
int endMapIndex,
|
||||
int crc32,
|
||||
long bytesWritten)
|
||||
throws IOException {
|
||||
PbReadReducerPartitionEnd pbReadReducerPartitionEnd =
|
||||
PbReadReducerPartitionEnd.newBuilder()
|
||||
.setShuffleId(shuffleId)
|
||||
.setPartitionId(partitionId)
|
||||
.setStartMaxIndex(startMapIndex)
|
||||
.setEndMapIndex(endMapIndex)
|
||||
.setCrc32(crc32)
|
||||
.setBytesWritten(bytesWritten)
|
||||
.build();
|
||||
|
||||
PbReadReducerPartitionEndResponse pbReducerPartitionEndResponse =
|
||||
lifecycleManagerRef.askSync(
|
||||
pbReadReducerPartitionEnd,
|
||||
conf.clientRpcRegisterShuffleAskTimeout(),
|
||||
ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class));
|
||||
if (pbReducerPartitionEndResponse.getStatus() != StatusCode.SUCCESS.getValue()) {
|
||||
throw new CelebornIOException(pbReducerPartitionEndResponse.getErrorMsg());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) {
|
||||
PbReportShuffleFetchFailure pbReportShuffleFetchFailure =
|
||||
@ -1012,6 +1041,12 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
// increment batchId
|
||||
final int nextBatchId = pushState.nextBatchId();
|
||||
|
||||
// Track commit metadata if shuffle compression and integrity check are enabled and this request
|
||||
// is not for pushing metadata itself.
|
||||
if (shuffleIntegrityCheckEnabled) {
|
||||
pushState.addDataWithOffsetAndLength(partitionId, data, offset, length);
|
||||
}
|
||||
|
||||
if (shuffleCompressionEnabled && !skipCompress) {
|
||||
// compress data
|
||||
final Compressor compressor = compressorThreadLocal.get();
|
||||
@ -1725,19 +1760,25 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers)
|
||||
public void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions)
|
||||
throws IOException {
|
||||
mapEndInternal(shuffleId, mapId, attemptId, numMappers, -1);
|
||||
mapEndInternal(shuffleId, mapId, attemptId, numMappers, numPartitions, -1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mapPartitionMapperEnd(
|
||||
int shuffleId, int mapId, int attemptId, int numMappers, int partitionId) throws IOException {
|
||||
mapEndInternal(shuffleId, mapId, attemptId, numMappers, partitionId);
|
||||
int shuffleId, int mapId, int attemptId, int numMappers, int numPartitions, int partitionId)
|
||||
throws IOException {
|
||||
mapEndInternal(shuffleId, mapId, attemptId, numMappers, numPartitions, partitionId);
|
||||
}
|
||||
|
||||
private void mapEndInternal(
|
||||
int shuffleId, int mapId, int attemptId, int numMappers, Integer partitionId)
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
int numMappers,
|
||||
int numPartitions,
|
||||
Integer partitionId)
|
||||
throws IOException {
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
PushState pushState = getPushState(mapKey);
|
||||
@ -1745,6 +1786,12 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
try {
|
||||
limitZeroInFlight(mapKey, pushState);
|
||||
|
||||
// send CRC32 and num bytes per partition if e2e checks are enabled
|
||||
int[] crc32PerPartition =
|
||||
pushState.getCRC32PerPartition(shuffleIntegrityCheckEnabled, numPartitions);
|
||||
long[] bytesPerPartition =
|
||||
pushState.getBytesWrittenPerPartition(shuffleIntegrityCheckEnabled, numPartitions);
|
||||
|
||||
MapperEndResponse response =
|
||||
lifecycleManagerRef.askSync(
|
||||
new MapperEnd(
|
||||
@ -1753,7 +1800,10 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
attemptId,
|
||||
numMappers,
|
||||
partitionId,
|
||||
pushState.getFailedBatches()),
|
||||
pushState.getFailedBatches(),
|
||||
numPartitions,
|
||||
crc32PerPartition,
|
||||
bytesPerPartition),
|
||||
rpcMaxRetries,
|
||||
rpcRetryWait,
|
||||
ClassTag$.MODULE$.apply(MapperEndResponse.class));
|
||||
|
||||
@ -40,6 +40,7 @@ import org.apache.celeborn.client.ShuffleClient;
|
||||
import org.apache.celeborn.client.compress.Decompressor;
|
||||
import org.apache.celeborn.client.read.checkpoint.PartitionReaderCheckpointMetadata;
|
||||
import org.apache.celeborn.common.CelebornConf;
|
||||
import org.apache.celeborn.common.CommitMetadata;
|
||||
import org.apache.celeborn.common.exception.CelebornIOException;
|
||||
import org.apache.celeborn.common.network.client.TransportClient;
|
||||
import org.apache.celeborn.common.network.client.TransportClientFactory;
|
||||
@ -103,7 +104,9 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
exceptionMaker,
|
||||
true,
|
||||
metricsCallback,
|
||||
needDecompress);
|
||||
needDecompress,
|
||||
startMapIndex,
|
||||
endMapIndex);
|
||||
} else {
|
||||
return new CelebornInputStreamImpl(
|
||||
conf,
|
||||
@ -126,7 +129,9 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
exceptionMaker,
|
||||
false,
|
||||
metricsCallback,
|
||||
needDecompress);
|
||||
needDecompress,
|
||||
-1,
|
||||
-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -168,6 +173,8 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
private final CelebornConf conf;
|
||||
private final TransportClientFactory clientFactory;
|
||||
private final String shuffleKey;
|
||||
private final int numberOfSubPartitions;
|
||||
private final int currentIndexOfSubPartition;
|
||||
private ArrayList<PartitionLocation> locations;
|
||||
private ArrayList<PbStreamHandler> streamHandlers;
|
||||
private int[] attempts;
|
||||
@ -206,6 +213,7 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
private final String localHostAddress;
|
||||
|
||||
private boolean shouldDecompress;
|
||||
private boolean shuffleIntegrityCheckEnabled;
|
||||
private long fetchExcludedWorkerExpireTimeout;
|
||||
private ConcurrentHashMap<String, Long> fetchExcludedWorkers;
|
||||
|
||||
@ -216,6 +224,8 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
private int partitionId;
|
||||
private ExceptionMaker exceptionMaker;
|
||||
private boolean closed = false;
|
||||
private boolean integrityChecked = false;
|
||||
private final CommitMetadata aggregatedActualCommitMetadata = new CommitMetadata();
|
||||
|
||||
private final boolean readSkewPartitionWithoutMapRange;
|
||||
|
||||
@ -238,7 +248,9 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
ExceptionMaker exceptionMaker,
|
||||
boolean splitSkewPartitionWithoutMapRange,
|
||||
MetricsCallback metricsCallback,
|
||||
boolean needDecompress)
|
||||
boolean needDecompress,
|
||||
int numberOfSubPartitions,
|
||||
int currentIndexOfSubPartition)
|
||||
throws IOException {
|
||||
this(
|
||||
conf,
|
||||
@ -261,7 +273,9 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
exceptionMaker,
|
||||
splitSkewPartitionWithoutMapRange,
|
||||
metricsCallback,
|
||||
needDecompress);
|
||||
needDecompress,
|
||||
numberOfSubPartitions,
|
||||
currentIndexOfSubPartition);
|
||||
}
|
||||
|
||||
CelebornInputStreamImpl(
|
||||
@ -285,7 +299,9 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
ExceptionMaker exceptionMaker,
|
||||
boolean readSkewPartitionWithoutMapRange,
|
||||
MetricsCallback metricsCallback,
|
||||
boolean needDecompress)
|
||||
boolean needDecompress,
|
||||
int numberOfSubPartitions,
|
||||
int currentIndexOfSubPartition)
|
||||
throws IOException {
|
||||
this.conf = conf;
|
||||
this.clientFactory = clientFactory;
|
||||
@ -305,9 +321,12 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
this.localHostAddress = Utils.localHostName(conf);
|
||||
this.shouldDecompress =
|
||||
!conf.shuffleCompressionCodec().equals(CompressionCodec.NONE) && needDecompress;
|
||||
this.shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled();
|
||||
this.fetchExcludedWorkerExpireTimeout = conf.clientFetchExcludedWorkerExpireTimeout();
|
||||
this.failedBatches = failedBatchSet;
|
||||
this.readSkewPartitionWithoutMapRange = readSkewPartitionWithoutMapRange;
|
||||
this.numberOfSubPartitions = numberOfSubPartitions;
|
||||
this.currentIndexOfSubPartition = currentIndexOfSubPartition;
|
||||
this.fetchExcludedWorkers = fetchExcludedWorkers;
|
||||
|
||||
if (conf.clientPushReplicateEnabled()) {
|
||||
@ -721,6 +740,36 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
}
|
||||
}
|
||||
|
||||
void validateIntegrity() throws IOException {
|
||||
if (integrityChecked || !shuffleIntegrityCheckEnabled) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (readSkewPartitionWithoutMapRange) {
|
||||
shuffleClient.readReducerPartitionEnd(
|
||||
shuffleId,
|
||||
partitionId,
|
||||
numberOfSubPartitions,
|
||||
currentIndexOfSubPartition,
|
||||
aggregatedActualCommitMetadata.getChecksum(),
|
||||
aggregatedActualCommitMetadata.getBytes());
|
||||
} else {
|
||||
shuffleClient.readReducerPartitionEnd(
|
||||
shuffleId,
|
||||
partitionId,
|
||||
startMapIndex,
|
||||
endMapIndex,
|
||||
aggregatedActualCommitMetadata.getChecksum(),
|
||||
aggregatedActualCommitMetadata.getBytes());
|
||||
}
|
||||
logger.info(
|
||||
"reducerPartitionEnd successful for shuffleId{}, partitionId{}. actual CommitMetadata: {}",
|
||||
shuffleId,
|
||||
partitionId,
|
||||
aggregatedActualCommitMetadata);
|
||||
integrityChecked = true;
|
||||
}
|
||||
|
||||
private boolean moveToNextChunk() throws IOException {
|
||||
if (currentChunk != null) {
|
||||
currentChunk.release();
|
||||
@ -760,6 +809,7 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
firstChunk = false;
|
||||
}
|
||||
if (currentChunk == null) {
|
||||
validateIntegrity();
|
||||
close();
|
||||
return false;
|
||||
}
|
||||
@ -821,6 +871,9 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
} else {
|
||||
limit = size;
|
||||
}
|
||||
if (shuffleIntegrityCheckEnabled) {
|
||||
aggregatedActualCommitMetadata.addDataWithOffsetAndLength(rawDataBuf, 0, limit);
|
||||
}
|
||||
position = 0;
|
||||
hasData = true;
|
||||
break;
|
||||
@ -835,6 +888,10 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasData) {
|
||||
validateIntegrity();
|
||||
// TODO(gaurav): consider closing the stream
|
||||
}
|
||||
return hasData;
|
||||
} catch (LZ4Exception | ZstdException | IOException e) {
|
||||
logger.error(
|
||||
|
||||
@ -31,7 +31,7 @@ import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
|
||||
import org.apache.celeborn.client.LifecycleManager.ShuffleFailedWorkers
|
||||
import org.apache.celeborn.client.commit.{CommitFilesParam, CommitHandler, MapPartitionCommitHandler, ReducePartitionCommitHandler}
|
||||
import org.apache.celeborn.client.listener.{WorkersStatus, WorkerStatusListener}
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
|
||||
import org.apache.celeborn.common.internal.Logging
|
||||
import org.apache.celeborn.common.meta.WorkerInfo
|
||||
import org.apache.celeborn.common.network.protocol.SerdeVersion
|
||||
@ -178,7 +178,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
|
||||
def registerShuffle(
|
||||
shuffleId: Int,
|
||||
numMappers: Int,
|
||||
isSegmentGranularityVisible: Boolean): Unit = {
|
||||
isSegmentGranularityVisible: Boolean,
|
||||
numPartitions: Int): Unit = {
|
||||
committedPartitionInfo.put(
|
||||
shuffleId,
|
||||
ShuffleCommittedInfo(
|
||||
@ -198,7 +199,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
|
||||
getCommitHandler(shuffleId).registerShuffle(
|
||||
shuffleId,
|
||||
numMappers,
|
||||
isSegmentGranularityVisible);
|
||||
isSegmentGranularityVisible,
|
||||
numPartitions)
|
||||
}
|
||||
|
||||
def isSegmentGranularityVisible(shuffleId: Int): Boolean = {
|
||||
@ -219,8 +221,10 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
|
||||
attemptId: Int,
|
||||
numMappers: Int,
|
||||
partitionId: Int = -1,
|
||||
pushFailedBatches: util.Map[String, LocationPushFailedBatches] = Collections.emptyMap())
|
||||
: (Boolean, Boolean) = {
|
||||
pushFailedBatches: util.Map[String, LocationPushFailedBatches] = Collections.emptyMap(),
|
||||
numPartitions: Int = -1,
|
||||
crc32PerPartition: Array[Int] = new Array[Int](0),
|
||||
bytesWrittenPerPartition: Array[Long] = new Array[Long](0)): (Boolean, Boolean) = {
|
||||
getCommitHandler(shuffleId).finishMapperAttempt(
|
||||
shuffleId,
|
||||
mapId,
|
||||
@ -228,7 +232,10 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
|
||||
numMappers,
|
||||
partitionId,
|
||||
pushFailedBatches,
|
||||
r => lifecycleManager.workerStatusTracker.recordWorkerFailure(r))
|
||||
r => lifecycleManager.workerStatusTracker.recordWorkerFailure(r),
|
||||
numPartitions,
|
||||
crc32PerPartition,
|
||||
bytesWrittenPerPartition)
|
||||
}
|
||||
|
||||
def releasePartitionResource(shuffleId: Int, partitionId: Int): Unit = {
|
||||
@ -350,4 +357,18 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def finishPartition(
|
||||
shuffleId: Int,
|
||||
partitionId: Int,
|
||||
startMapIndex: Int,
|
||||
endMapIndex: Int,
|
||||
actualCommitMetadata: CommitMetadata): (Boolean, String) = {
|
||||
getCommitHandler(shuffleId).finishPartition(
|
||||
shuffleId,
|
||||
partitionId,
|
||||
startMapIndex,
|
||||
endMapIndex,
|
||||
actualCommitMetadata)
|
||||
}
|
||||
}
|
||||
|
||||
@ -21,7 +21,7 @@ import java.lang.{Byte => JByte}
|
||||
import java.nio.ByteBuffer
|
||||
import java.security.SecureRandom
|
||||
import java.util
|
||||
import java.util.{function, List => JList}
|
||||
import java.util.{function, Collections, List => JList}
|
||||
import java.util.concurrent._
|
||||
import java.util.concurrent.atomic.{AtomicInteger, LongAdder}
|
||||
import java.util.function.{BiConsumer, BiFunction, Consumer}
|
||||
@ -40,7 +40,7 @@ import org.roaringbitmap.RoaringBitmap
|
||||
|
||||
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
|
||||
import org.apache.celeborn.client.listener.WorkerStatusListener
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
|
||||
import org.apache.celeborn.common.CelebornConf.ACTIVE_STORAGE_TYPES
|
||||
import org.apache.celeborn.common.client.MasterClient
|
||||
import org.apache.celeborn.common.identity.{IdentityProvider, UserIdentifier}
|
||||
@ -423,13 +423,31 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
oldPartition,
|
||||
isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId))
|
||||
|
||||
case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) =>
|
||||
case MapperEnd(
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
numMappers,
|
||||
partitionId,
|
||||
pushFailedBatch,
|
||||
numPartitions,
|
||||
crc32PerPartition,
|
||||
bytesWrittenPerPartition) =>
|
||||
logTrace(s"Received MapperEnd TaskEnd request, " +
|
||||
s"${Utils.makeMapKey(shuffleId, mapId, attemptId)}")
|
||||
val partitionType = getPartitionType(shuffleId)
|
||||
partitionType match {
|
||||
case PartitionType.REDUCE =>
|
||||
handleMapperEnd(context, shuffleId, mapId, attemptId, numMappers, pushFailedBatch)
|
||||
handleMapperEnd(
|
||||
context,
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
numMappers,
|
||||
pushFailedBatch,
|
||||
numPartitions,
|
||||
crc32PerPartition,
|
||||
bytesWrittenPerPartition)
|
||||
case PartitionType.MAP =>
|
||||
handleMapPartitionEnd(
|
||||
context,
|
||||
@ -442,6 +460,22 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
throw new UnsupportedOperationException(s"Not support $partitionType yet")
|
||||
}
|
||||
|
||||
case pb: ReadReducerPartitionEnd =>
|
||||
val partitionType = getPartitionType(pb.shuffleId)
|
||||
partitionType match {
|
||||
case PartitionType.REDUCE =>
|
||||
handleReducerPartitionEnd(
|
||||
context,
|
||||
pb.shuffleId,
|
||||
pb.partitionId,
|
||||
pb.startMapIndex,
|
||||
pb.endMapIndex,
|
||||
pb.crc32,
|
||||
pb.bytesWritten)
|
||||
case _ =>
|
||||
throw new UnsupportedOperationException(s"Not support $partitionType yet")
|
||||
}
|
||||
|
||||
case GetReducerFileGroup(
|
||||
shuffleId: Int,
|
||||
isSegmentGranularityVisible: Boolean,
|
||||
@ -493,6 +527,32 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
}
|
||||
}
|
||||
|
||||
private def handleReducerPartitionEnd(
|
||||
context: RpcCallContext,
|
||||
shuffleId: Int,
|
||||
partitionId: Int,
|
||||
startMapIndex: Int,
|
||||
endMapIndex: Int,
|
||||
crc32: Int,
|
||||
bytesWritten: Long): Unit = {
|
||||
val (isValid, errorMessage) = commitManager.finishPartition(
|
||||
shuffleId,
|
||||
partitionId,
|
||||
startMapIndex,
|
||||
endMapIndex,
|
||||
new CommitMetadata(crc32, bytesWritten))
|
||||
var response: PbReadReducerPartitionEndResponse = null
|
||||
if (isValid) {
|
||||
response =
|
||||
PbReadReducerPartitionEndResponse.newBuilder().setStatus(
|
||||
StatusCode.SUCCESS.getValue).build()
|
||||
} else {
|
||||
response = PbReadReducerPartitionEndResponse.newBuilder().setStatus(
|
||||
+StatusCode.READ_REDUCER_PARTITION_END_FAILED.getValue).setErrorMsg(errorMessage).build()
|
||||
}
|
||||
context.reply(response)
|
||||
}
|
||||
|
||||
def setupEndpoints(
|
||||
workers: util.Set[WorkerInfo],
|
||||
shuffleId: Int,
|
||||
@ -771,7 +831,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
commitManager.registerShuffle(
|
||||
shuffleId,
|
||||
numMappers,
|
||||
isSegmentGranularityVisible)
|
||||
isSegmentGranularityVisible,
|
||||
numPartitions)
|
||||
|
||||
// Fifth, reply the allocated partition location to ShuffleClient.
|
||||
logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
|
||||
@ -840,7 +901,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
mapId: Int,
|
||||
attemptId: Int,
|
||||
numMappers: Int,
|
||||
pushFailedBatches: util.Map[String, LocationPushFailedBatches]): Unit = {
|
||||
pushFailedBatches: util.Map[String, LocationPushFailedBatches],
|
||||
numPartitions: Int,
|
||||
crc32PerPartition: Array[Int],
|
||||
bytesWrittenPerPartition: Array[Long]): Unit = {
|
||||
|
||||
val (mapperAttemptFinishedSuccess, allMapperFinished) =
|
||||
commitManager.finishMapperAttempt(
|
||||
@ -848,7 +912,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
mapId,
|
||||
attemptId,
|
||||
numMappers,
|
||||
pushFailedBatches = pushFailedBatches)
|
||||
pushFailedBatches = pushFailedBatches,
|
||||
numPartitions = numPartitions,
|
||||
crc32PerPartition = crc32PerPartition,
|
||||
bytesWrittenPerPartition = bytesWrittenPerPartition)
|
||||
if (mapperAttemptFinishedSuccess && allMapperFinished) {
|
||||
// last mapper finished. call mapper end
|
||||
logInfo(s"Last MapperEnd, call StageEnd with shuffleKey:" +
|
||||
|
||||
@ -31,7 +31,7 @@ import scala.concurrent.duration.Duration
|
||||
import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker}
|
||||
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
|
||||
import org.apache.celeborn.client.LifecycleManager.{ShuffleFailedWorkers, ShuffleFileGroups, ShufflePushFailedBatches}
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
|
||||
import org.apache.celeborn.common.internal.Logging
|
||||
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
|
||||
import org.apache.celeborn.common.network.protocol.SerdeVersion
|
||||
@ -216,12 +216,16 @@ abstract class CommitHandler(
|
||||
numMappers: Int,
|
||||
partitionId: Int,
|
||||
pushFailedBatches: util.Map[String, LocationPushFailedBatches],
|
||||
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean)
|
||||
recordWorkerFailure: ShuffleFailedWorkers => Unit,
|
||||
numPartitions: Int,
|
||||
crc32PerPartition: Array[Int],
|
||||
bytesWrittenPerPartition: Array[Long]): (Boolean, Boolean)
|
||||
|
||||
def registerShuffle(
|
||||
shuffleId: Int,
|
||||
numMappers: Int,
|
||||
isSegmentGranularityVisible: Boolean): Unit = {
|
||||
isSegmentGranularityVisible: Boolean,
|
||||
numPartitions: Int): Unit = {
|
||||
// TODO: if isSegmentGranularityVisible is set to true, it is necessary to handle the pending
|
||||
// get partition request of downstream reduce task here, in scenarios which support
|
||||
// downstream task start early before the upstream task, e.g. flink hybrid shuffle.
|
||||
@ -422,6 +426,16 @@ abstract class CommitHandler(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invoked when a reduce partition finishes reading data to perform end to end integrity check validation
|
||||
*/
|
||||
def finishPartition(
|
||||
shuffleId: Int,
|
||||
partitionId: Int,
|
||||
startMapIndex: Int,
|
||||
endMapIndex: Int,
|
||||
actualCommitMetadata: CommitMetadata): (Boolean, String)
|
||||
|
||||
def parallelCommitFiles(
|
||||
shuffleId: Int,
|
||||
allocatedWorkers: util.Map[String, ShufflePartitionLocationInfo],
|
||||
|
||||
@ -0,0 +1,189 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.client.commit
|
||||
|
||||
import java.util
|
||||
import java.util.Comparator
|
||||
import java.util.function.BiFunction
|
||||
|
||||
import com.google.common.base.Preconditions.{checkArgument, checkState}
|
||||
|
||||
import org.apache.celeborn.common.CommitMetadata
|
||||
import org.apache.celeborn.common.internal.Logging
|
||||
import org.apache.celeborn.common.util.JavaUtils
|
||||
|
||||
class LegacySkewHandlingPartitionValidator extends AbstractPartitionCompletenessValidator
|
||||
with Logging {
|
||||
private val subRangeToCommitMetadataPerReducer = {
|
||||
JavaUtils.newConcurrentHashMap[Int, java.util.TreeMap[(Int, Int), CommitMetadata]]()
|
||||
}
|
||||
private val partitionToSubPartitionCount = JavaUtils.newConcurrentHashMap[Int, Int]()
|
||||
private val currentCommitMetadataForReducer =
|
||||
JavaUtils.newConcurrentHashMap[Int, CommitMetadata]()
|
||||
private val currentTotalMapIdCountForReducer = JavaUtils.newConcurrentHashMap[Int, Int]()
|
||||
private val comparator: java.util.Comparator[(Int, Int)] = new Comparator[(Int, Int)] {
|
||||
override def compare(o1: (Int, Int), o2: (Int, Int)): Int = {
|
||||
val comparator = Integer.compare(o1._1, o2._1)
|
||||
if (comparator != 0)
|
||||
comparator
|
||||
else
|
||||
Integer.compare(o1._2, o2._2)
|
||||
}
|
||||
}
|
||||
private val mapCountMergeBiFunction = new BiFunction[Int, Int, Int] {
|
||||
override def apply(t: Int, u: Int): Int =
|
||||
Integer.sum(t, u)
|
||||
}
|
||||
private val metadataMergeBiFunction: BiFunction[CommitMetadata, CommitMetadata, CommitMetadata] =
|
||||
new BiFunction[CommitMetadata, CommitMetadata, CommitMetadata] {
|
||||
override def apply(
|
||||
existing: CommitMetadata,
|
||||
incoming: CommitMetadata): CommitMetadata = {
|
||||
if (existing == null) {
|
||||
if (incoming != null) {
|
||||
return new CommitMetadata(incoming.getChecksum, incoming.getBytes)
|
||||
} else {
|
||||
return incoming
|
||||
}
|
||||
}
|
||||
if (incoming == null) {
|
||||
return existing
|
||||
}
|
||||
existing.addCommitData(incoming)
|
||||
existing
|
||||
}
|
||||
}
|
||||
|
||||
private def checkOverlappingRange(
|
||||
treeMap: java.util.TreeMap[(Int, Int), CommitMetadata],
|
||||
rangeKey: (Int, Int)): (Boolean, ((Int, Int), CommitMetadata)) = {
|
||||
val floorEntry: util.Map.Entry[(Int, Int), CommitMetadata] =
|
||||
treeMap.floorEntry(rangeKey)
|
||||
val ceilingEntry: util.Map.Entry[(Int, Int), CommitMetadata] =
|
||||
treeMap.ceilingEntry(rangeKey)
|
||||
|
||||
if (floorEntry != null) {
|
||||
if (rangeKey._1 < floorEntry.getKey._2) {
|
||||
return (true, ((floorEntry.getKey._1, floorEntry.getKey._2), floorEntry.getValue))
|
||||
}
|
||||
}
|
||||
if (ceilingEntry != null) {
|
||||
if (rangeKey._2 > ceilingEntry.getKey._1) {
|
||||
return (true, ((ceilingEntry.getKey._1, ceilingEntry.getKey._2), ceilingEntry.getValue))
|
||||
}
|
||||
}
|
||||
(false, ((0, 0), new CommitMetadata()))
|
||||
}
|
||||
|
||||
override def processSubPartition(
|
||||
partitionId: Int,
|
||||
startMapIndex: Int,
|
||||
endMapIndex: Int,
|
||||
actualCommitMetadata: CommitMetadata,
|
||||
expectedTotalMapperCount: Int): (Boolean, String) = {
|
||||
checkArgument(
|
||||
startMapIndex < endMapIndex,
|
||||
"startMapIndex %s must be less than endMapIndex %s",
|
||||
startMapIndex,
|
||||
endMapIndex)
|
||||
logDebug(
|
||||
s"Validate partition invoked for partitionId $partitionId startMapIndex $startMapIndex endMapIndex $endMapIndex")
|
||||
partitionToSubPartitionCount.put(partitionId, expectedTotalMapperCount)
|
||||
val subRangeToCommitMetadataMap = subRangeToCommitMetadataPerReducer.computeIfAbsent(
|
||||
partitionId,
|
||||
new java.util.function.Function[Int, util.TreeMap[(Int, Int), CommitMetadata]] {
|
||||
override def apply(key: Int): util.TreeMap[(Int, Int), CommitMetadata] =
|
||||
new util.TreeMap[(Int, Int), CommitMetadata](comparator)
|
||||
})
|
||||
val rangeKey = (startMapIndex, endMapIndex)
|
||||
subRangeToCommitMetadataMap.synchronized {
|
||||
val existingMetadata = subRangeToCommitMetadataMap.get(rangeKey)
|
||||
if (existingMetadata == null) {
|
||||
val (isOverlapping, overlappingEntry) =
|
||||
checkOverlappingRange(subRangeToCommitMetadataMap, rangeKey)
|
||||
if (isOverlapping) {
|
||||
val errorMessage = s"Encountered overlapping map range for partitionId: $partitionId " +
|
||||
s" while processing range with startMapIndex: $startMapIndex and endMapIndex: $endMapIndex " +
|
||||
s"existing range map: $subRangeToCommitMetadataMap " +
|
||||
s"overlapped with Entry((startMapIndex, endMapIndex), count): $overlappingEntry"
|
||||
logError(errorMessage)
|
||||
return (false, errorMessage)
|
||||
}
|
||||
|
||||
// Process new range
|
||||
subRangeToCommitMetadataMap.put(rangeKey, actualCommitMetadata)
|
||||
currentCommitMetadataForReducer.merge(
|
||||
partitionId,
|
||||
actualCommitMetadata,
|
||||
metadataMergeBiFunction)
|
||||
currentTotalMapIdCountForReducer.merge(
|
||||
partitionId,
|
||||
endMapIndex - startMapIndex,
|
||||
mapCountMergeBiFunction)
|
||||
} else if (existingMetadata != actualCommitMetadata) {
|
||||
val errorMessage = s"Commit Metadata for partition: $partitionId " +
|
||||
s"not matching for sub-partition with startMapIndex: $startMapIndex endMapIndex: $endMapIndex " +
|
||||
s"previous count: $existingMetadata new count: $actualCommitMetadata"
|
||||
logError(errorMessage)
|
||||
return (false, errorMessage)
|
||||
}
|
||||
|
||||
validateState(partitionId, startMapIndex, endMapIndex)
|
||||
val sumOfMapRanges: Int = currentTotalMapIdCountForReducer.get(partitionId)
|
||||
val currentCommitMetadata: CommitMetadata =
|
||||
currentCommitMetadataForReducer.get(partitionId)
|
||||
|
||||
if (sumOfMapRanges > expectedTotalMapperCount) {
|
||||
val errorMsg = s"AQE Partition $partitionId failed validation check " +
|
||||
s"while processing startMapIndex: $startMapIndex endMapIndex: $endMapIndex " +
|
||||
s"ActualCommitMetadata $currentCommitMetadata > ExpectedTotalMapperCount $expectedTotalMapperCount"
|
||||
logError(errorMsg)
|
||||
return (false, errorMsg)
|
||||
}
|
||||
}
|
||||
(true, "")
|
||||
}
|
||||
|
||||
private def validateState(partitionId: Int, startMapIndex: Int, endMapIndex: Int): Unit = {
|
||||
if (!currentTotalMapIdCountForReducer.containsKey(partitionId)) {
|
||||
checkState(!currentCommitMetadataForReducer.containsKey(
|
||||
partitionId,
|
||||
"mapper total count missing while processing partitionId %s startMapIndex %s endMapIndex %s",
|
||||
partitionId,
|
||||
startMapIndex,
|
||||
endMapIndex))
|
||||
currentTotalMapIdCountForReducer.put(partitionId, 0)
|
||||
currentCommitMetadataForReducer.put(partitionId, new CommitMetadata())
|
||||
}
|
||||
checkState(
|
||||
currentCommitMetadataForReducer.containsKey(partitionId),
|
||||
"mapper written count missing while processing partitionId %s startMapIndex %s endMapIndex %s",
|
||||
partitionId,
|
||||
startMapIndex,
|
||||
endMapIndex)
|
||||
}
|
||||
|
||||
override def currentCommitMetadata(partitionId: Int): CommitMetadata = {
|
||||
currentCommitMetadataForReducer.get(partitionId)
|
||||
}
|
||||
|
||||
override def isPartitionComplete(partitionId: Int): Boolean = {
|
||||
val sumOfMapRanges: Int = currentTotalMapIdCountForReducer.get(partitionId)
|
||||
sumOfMapRanges == partitionToSubPartitionCount.get(partitionId)
|
||||
}
|
||||
}
|
||||
@ -20,15 +20,17 @@ package org.apache.celeborn.client.commit
|
||||
import java.util
|
||||
import java.util.Collections
|
||||
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor}
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import java.util.concurrent.atomic.{AtomicInteger, AtomicIntegerArray}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.roaringbitmap.RoaringBitmap
|
||||
|
||||
import org.apache.celeborn.client.{ShuffleCommittedInfo, WorkerStatusTracker}
|
||||
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
|
||||
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
|
||||
import org.apache.celeborn.common.internal.Logging
|
||||
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
|
||||
import org.apache.celeborn.common.network.protocol.SerdeVersion
|
||||
@ -187,7 +189,10 @@ class MapPartitionCommitHandler(
|
||||
numMappers: Int,
|
||||
partitionId: Int,
|
||||
pushFailedBatches: util.Map[String, LocationPushFailedBatches],
|
||||
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = {
|
||||
recordWorkerFailure: ShuffleFailedWorkers => Unit,
|
||||
numPartitions: Int,
|
||||
crc32PerPartition: Array[Int],
|
||||
bytesWrittenPerPartition: Array[Long]): (Boolean, Boolean) = {
|
||||
val inProcessingPartitionIds =
|
||||
inProcessMapPartitionEndIds.computeIfAbsent(
|
||||
shuffleId,
|
||||
@ -222,8 +227,9 @@ class MapPartitionCommitHandler(
|
||||
override def registerShuffle(
|
||||
shuffleId: Int,
|
||||
numMappers: Int,
|
||||
isSegmentGranularityVisible: Boolean): Unit = {
|
||||
super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible)
|
||||
isSegmentGranularityVisible: Boolean,
|
||||
numPartitions: Int): Unit = {
|
||||
super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible, numPartitions)
|
||||
shuffleIsSegmentGranularityVisible.put(shuffleId, isSegmentGranularityVisible)
|
||||
}
|
||||
|
||||
@ -231,6 +237,15 @@ class MapPartitionCommitHandler(
|
||||
shuffleIsSegmentGranularityVisible.get(shuffleId)
|
||||
}
|
||||
|
||||
override def finishPartition(
|
||||
shuffleId: Int,
|
||||
partitionId: Int,
|
||||
startMapIndex: Int,
|
||||
endMapIndex: Int,
|
||||
actualCommitMetadata: CommitMetadata): (Boolean, String) = {
|
||||
throw new UnsupportedOperationException()
|
||||
}
|
||||
|
||||
override def handleGetReducerFileGroup(
|
||||
context: RpcCallContext,
|
||||
shuffleId: Int,
|
||||
|
||||
@ -0,0 +1,86 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.client.commit
|
||||
|
||||
import org.apache.celeborn.common.CommitMetadata
|
||||
import org.apache.celeborn.common.internal.Logging
|
||||
|
||||
abstract class AbstractPartitionCompletenessValidator {
|
||||
def processSubPartition(
|
||||
partitionId: Int,
|
||||
startMapIndex: Int,
|
||||
endMapIndex: Int,
|
||||
actualCommitMetadata: CommitMetadata,
|
||||
expectedTotalMapperCount: Int): (Boolean, String)
|
||||
|
||||
def currentCommitMetadata(partitionId: Int): CommitMetadata
|
||||
|
||||
def isPartitionComplete(partitionId: Int): Boolean
|
||||
}
|
||||
|
||||
class PartitionCompletenessValidator extends Logging {
|
||||
|
||||
private val skewHandlingValidator: AbstractPartitionCompletenessValidator =
|
||||
new SkewHandlingWithoutMapRangeValidator
|
||||
private val legacySkewHandlingValidator: AbstractPartitionCompletenessValidator =
|
||||
new LegacySkewHandlingPartitionValidator
|
||||
|
||||
def validateSubPartition(
|
||||
partitionId: Int,
|
||||
startMapIndex: Int,
|
||||
endMapIndex: Int,
|
||||
actualCommitMetadata: CommitMetadata,
|
||||
expectedCommitMetadata: CommitMetadata,
|
||||
expectedTotalMapperCountForParent: Int,
|
||||
skewPartitionHandlingWithoutMapRange: Boolean): (Boolean, String) = {
|
||||
val validator =
|
||||
if (skewPartitionHandlingWithoutMapRange) {
|
||||
skewHandlingValidator
|
||||
} else {
|
||||
legacySkewHandlingValidator
|
||||
}
|
||||
|
||||
val (successfullyProcessed, error) = validator.processSubPartition(
|
||||
partitionId,
|
||||
startMapIndex,
|
||||
endMapIndex,
|
||||
actualCommitMetadata,
|
||||
expectedTotalMapperCountForParent)
|
||||
if (!successfullyProcessed) return (false, error)
|
||||
|
||||
if (!validator.isPartitionComplete(partitionId)) {
|
||||
return (true, "Partition is valid but still waiting for more data")
|
||||
}
|
||||
|
||||
val currentCommitMetadata = validator.currentCommitMetadata(partitionId)
|
||||
if (!CommitMetadata.checkCommitMetadata(expectedCommitMetadata, currentCommitMetadata)) {
|
||||
val errorMsg =
|
||||
s"AQE Partition $partitionId failed validation check" +
|
||||
s"while processing range startMapIndex: $startMapIndex endMapIndex: $endMapIndex" +
|
||||
s"ExpectedCommitMetadata $expectedCommitMetadata, " +
|
||||
s"ActualCommitMetadata $currentCommitMetadata, "
|
||||
logError(errorMsg)
|
||||
return (false, errorMsg)
|
||||
}
|
||||
logInfo(
|
||||
s"AQE Partition $partitionId completed validation check , " +
|
||||
s"expectedCommitMetadata $expectedCommitMetadata, " +
|
||||
s"actualCommitMetadata $currentCommitMetadata")
|
||||
(true, "Partition is complete")
|
||||
}
|
||||
}
|
||||
@ -25,13 +25,14 @@ import java.util.function
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
|
||||
import com.google.common.base.Preconditions.checkState
|
||||
import com.google.common.cache.{Cache, CacheBuilder}
|
||||
import com.google.common.collect.Sets
|
||||
|
||||
import org.apache.celeborn.client.{ClientUtils, LifecycleManager, ShuffleCommittedInfo, WorkerStatusTracker}
|
||||
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
|
||||
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.{CelebornConf, CommitMetadata}
|
||||
import org.apache.celeborn.common.internal.Logging
|
||||
import org.apache.celeborn.common.meta.ShufflePartitionLocationInfo
|
||||
import org.apache.celeborn.common.network.protocol.SerdeVersion
|
||||
@ -82,6 +83,13 @@ class ReducePartitionCommitHandler(
|
||||
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
|
||||
private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime
|
||||
|
||||
private val shuffleIntegrityCheckEnabled = conf.clientShuffleIntegrityCheckEnabled
|
||||
// partitionId-shuffleId -> number of mappers that have written to this reducer (partition + shuffle)
|
||||
private val commitMetadataForReducer =
|
||||
JavaUtils.newConcurrentHashMap[Integer, Array[CommitMetadata]]
|
||||
private val skewPartitionCompletenessValidator =
|
||||
JavaUtils.newConcurrentHashMap[Int, PartitionCompletenessValidator]()
|
||||
|
||||
private val getReducerFileGroupResponseBroadcastEnabled = conf.getReducerFileGroupBroadcastEnabled
|
||||
private val getReducerFileGroupResponseBroadcastMiniSize =
|
||||
conf.getReducerFileGroupBroadcastMiniSize
|
||||
@ -153,6 +161,8 @@ class ReducePartitionCommitHandler(
|
||||
stageEndShuffleSet.remove(shuffleId)
|
||||
inProcessStageEndShuffleSet.remove(shuffleId)
|
||||
shuffleMapperAttempts.remove(shuffleId)
|
||||
commitMetadataForReducer.remove(shuffleId)
|
||||
skewPartitionCompletenessValidator.remove(shuffleId)
|
||||
super.removeExpiredShuffle(shuffleId)
|
||||
}
|
||||
|
||||
@ -268,16 +278,20 @@ class ReducePartitionCommitHandler(
|
||||
numMappers: Int,
|
||||
partitionId: Int,
|
||||
pushFailedBatches: util.Map[String, LocationPushFailedBatches],
|
||||
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean) = {
|
||||
shuffleMapperAttempts.synchronized {
|
||||
recordWorkerFailure: ShuffleFailedWorkers => Unit,
|
||||
numPartitions: Int,
|
||||
crc32PerPartition: Array[Int],
|
||||
bytesWrittenPerPartition: Array[Long]): (Boolean, Boolean) = {
|
||||
val (mapperAttemptFinishedSuccess, allMapperFinished) = shuffleMapperAttempts.synchronized {
|
||||
if (getMapperAttempts(shuffleId) == null) {
|
||||
logDebug(s"[handleMapperEnd] $shuffleId not registered, create one.")
|
||||
initMapperAttempts(shuffleId, numMappers)
|
||||
initMapperAttempts(shuffleId, numMappers, numPartitions)
|
||||
}
|
||||
|
||||
val attempts = shuffleMapperAttempts.get(shuffleId)
|
||||
if (attempts(mapId) < 0) {
|
||||
attempts(mapId) = attemptId
|
||||
|
||||
if (null != pushFailedBatches && !pushFailedBatches.isEmpty) {
|
||||
val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent(
|
||||
shuffleId,
|
||||
@ -296,24 +310,93 @@ class ReducePartitionCommitHandler(
|
||||
(false, false)
|
||||
}
|
||||
}
|
||||
if (shuffleIntegrityCheckEnabled && mapperAttemptFinishedSuccess) {
|
||||
val commitMetadataArray = commitMetadataForReducer.get(shuffleId)
|
||||
checkState(
|
||||
commitMetadataArray != null,
|
||||
"commitMetadataArray can only be null if shuffleId %s is not registered!",
|
||||
shuffleId)
|
||||
for (i <- 0 until numPartitions) {
|
||||
if (bytesWrittenPerPartition(i) != 0) {
|
||||
commitMetadataArray(i).addCommitData(
|
||||
crc32PerPartition(i),
|
||||
bytesWrittenPerPartition(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
(mapperAttemptFinishedSuccess, allMapperFinished)
|
||||
}
|
||||
|
||||
override def registerShuffle(
|
||||
shuffleId: Int,
|
||||
numMappers: Int,
|
||||
isSegmentGranularityVisible: Boolean): Unit = {
|
||||
super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible)
|
||||
isSegmentGranularityVisible: Boolean,
|
||||
numPartitions: Int): Unit = {
|
||||
super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible, numPartitions)
|
||||
getReducerFileGroupRequest.put(shuffleId, new util.HashSet[MultiSerdeVersionRpcContext]())
|
||||
initMapperAttempts(shuffleId, numMappers)
|
||||
initMapperAttempts(shuffleId, numMappers, numPartitions)
|
||||
}
|
||||
|
||||
private def initMapperAttempts(shuffleId: Int, numMappers: Int): Unit = {
|
||||
override def finishPartition(
|
||||
shuffleId: Int,
|
||||
partitionId: Int,
|
||||
startMapIndex: Int,
|
||||
endMapIndex: Int,
|
||||
actualCommitMetadata: CommitMetadata): (Boolean, String) = {
|
||||
logDebug(s"finish Partition call: shuffleId: $shuffleId, " +
|
||||
s"partitionId: $partitionId, " +
|
||||
s"startMapIndex: $startMapIndex " +
|
||||
s"endMapIndex: $endMapIndex, " +
|
||||
s"actualCommitMetadata: $actualCommitMetadata")
|
||||
val map = commitMetadataForReducer.get(shuffleId)
|
||||
checkState(
|
||||
map != null,
|
||||
"CommitMetadata map cannot be null for a registered shuffleId: %d",
|
||||
shuffleId)
|
||||
val expectedCommitMetadata = map(partitionId)
|
||||
if (endMapIndex == Integer.MAX_VALUE) {
|
||||
// complete partition available
|
||||
val bool = CommitMetadata.checkCommitMetadata(actualCommitMetadata, expectedCommitMetadata)
|
||||
var message = ""
|
||||
if (!bool) {
|
||||
message =
|
||||
s"CommitMetadata mismatch for shuffleId: $shuffleId partitionId: $partitionId expected: $expectedCommitMetadata actual: $actualCommitMetadata"
|
||||
} else {
|
||||
logInfo(
|
||||
s"CommitMetadata matched for shuffleID : $shuffleId, partitionId: $partitionId expected: $expectedCommitMetadata actual: $actualCommitMetadata")
|
||||
}
|
||||
return (bool, message)
|
||||
}
|
||||
|
||||
val splitSkewPartitionWithoutMapRange =
|
||||
ClientUtils.readSkewPartitionWithoutMapRange(conf, startMapIndex, endMapIndex)
|
||||
|
||||
val validator = skewPartitionCompletenessValidator.computeIfAbsent(
|
||||
shuffleId,
|
||||
new java.util.function.Function[Int, PartitionCompletenessValidator] {
|
||||
override def apply(key: Int): PartitionCompletenessValidator =
|
||||
new PartitionCompletenessValidator()
|
||||
})
|
||||
validator.validateSubPartition(
|
||||
partitionId,
|
||||
startMapIndex,
|
||||
endMapIndex,
|
||||
actualCommitMetadata,
|
||||
expectedCommitMetadata,
|
||||
shuffleMapperAttempts.get(shuffleId).length,
|
||||
splitSkewPartitionWithoutMapRange)
|
||||
}
|
||||
|
||||
private def initMapperAttempts(shuffleId: Int, numMappers: Int, numPartitions: Int): Unit = {
|
||||
shuffleMapperAttempts.synchronized {
|
||||
if (!shuffleMapperAttempts.containsKey(shuffleId)) {
|
||||
val attempts = new Array[Int](numMappers)
|
||||
0 until numMappers foreach (idx => attempts(idx) = -1)
|
||||
shuffleMapperAttempts.put(shuffleId, attempts)
|
||||
}
|
||||
if (shuffleIntegrityCheckEnabled) {
|
||||
commitMetadataForReducer.put(shuffleId, Array.fill(numPartitions)(new CommitMetadata()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,109 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.client.commit
|
||||
|
||||
import java.util
|
||||
|
||||
import com.google.common.base.Preconditions.{checkArgument, checkState}
|
||||
|
||||
import org.apache.celeborn.common.CommitMetadata
|
||||
import org.apache.celeborn.common.util.JavaUtils
|
||||
|
||||
class SkewHandlingWithoutMapRangeValidator extends AbstractPartitionCompletenessValidator {
|
||||
|
||||
private val totalSubPartitionsProcessed =
|
||||
JavaUtils.newConcurrentHashMap[Int, util.HashMap[Int, CommitMetadata]]()
|
||||
private val partitionToSubPartitionCount = JavaUtils.newConcurrentHashMap[Int, Int]()
|
||||
private val currentCommitMetadataForReducer =
|
||||
JavaUtils.newConcurrentHashMap[Int, CommitMetadata]()
|
||||
|
||||
override def processSubPartition(
|
||||
partitionId: Int,
|
||||
startMapIndex: Int,
|
||||
endMapIndex: Int,
|
||||
actualCommitMetadata: CommitMetadata,
|
||||
expectedTotalMapperCount: Int): (Boolean, String) = {
|
||||
checkArgument(
|
||||
endMapIndex >= 0,
|
||||
"index of sub-partition %s must be greater than or equal to 0",
|
||||
endMapIndex)
|
||||
checkArgument(
|
||||
startMapIndex > endMapIndex,
|
||||
"startMapIndex %s must be greater than endMapIndex %s",
|
||||
startMapIndex,
|
||||
endMapIndex)
|
||||
totalSubPartitionsProcessed.synchronized {
|
||||
if (totalSubPartitionsProcessed.containsKey(partitionId)) {
|
||||
val currentSubPartitionCount = partitionToSubPartitionCount.getOrDefault(partitionId, -1)
|
||||
checkState(
|
||||
currentSubPartitionCount == startMapIndex,
|
||||
"total subpartition count mismatch for partitionId %s existing %s new %s",
|
||||
partitionId,
|
||||
currentSubPartitionCount,
|
||||
startMapIndex)
|
||||
} else {
|
||||
totalSubPartitionsProcessed.put(partitionId, new util.HashMap[Int, CommitMetadata]())
|
||||
partitionToSubPartitionCount.put(partitionId, startMapIndex)
|
||||
}
|
||||
val subPartitionsProcessed = totalSubPartitionsProcessed.get(partitionId)
|
||||
if (subPartitionsProcessed.containsKey(endMapIndex)) {
|
||||
// check if previous entry matches
|
||||
val existingCommitMetadata = subPartitionsProcessed.get(endMapIndex)
|
||||
if (existingCommitMetadata != actualCommitMetadata) {
|
||||
return (
|
||||
false,
|
||||
s"Mismatch in metadata for the same chunk range on retry: $endMapIndex existing: $existingCommitMetadata new: $actualCommitMetadata")
|
||||
}
|
||||
}
|
||||
subPartitionsProcessed.put(endMapIndex, actualCommitMetadata)
|
||||
val partitionProcessed = getTotalNumberOfSubPartitionsProcessed(partitionId)
|
||||
checkState(
|
||||
partitionProcessed <= startMapIndex,
|
||||
"Number of sub-partitions processed %s should less than total number of sub-partitions %s",
|
||||
partitionProcessed,
|
||||
startMapIndex)
|
||||
}
|
||||
updateCommitMetadata(partitionId, actualCommitMetadata)
|
||||
(true, "")
|
||||
}
|
||||
|
||||
private def updateCommitMetadata(partitionId: Int, actualCommitMetadata: CommitMetadata): Unit = {
|
||||
val currentCommitMetadata =
|
||||
currentCommitMetadataForReducer.computeIfAbsent(
|
||||
partitionId,
|
||||
new java.util.function.Function[Int, CommitMetadata] {
|
||||
override def apply(partitionId: Int): CommitMetadata = {
|
||||
new CommitMetadata()
|
||||
}
|
||||
})
|
||||
currentCommitMetadata.addCommitData(actualCommitMetadata)
|
||||
}
|
||||
|
||||
private def getTotalNumberOfSubPartitionsProcessed(partitionId: Int) = {
|
||||
totalSubPartitionsProcessed.get(partitionId).size()
|
||||
}
|
||||
|
||||
override def currentCommitMetadata(partitionId: Int): CommitMetadata = {
|
||||
currentCommitMetadataForReducer.get(partitionId)
|
||||
}
|
||||
|
||||
override def isPartitionComplete(partitionId: Int): Boolean = {
|
||||
getTotalNumberOfSubPartitionsProcessed(partitionId) == partitionToSubPartitionCount.get(
|
||||
partitionId)
|
||||
}
|
||||
}
|
||||
@ -18,8 +18,7 @@
|
||||
package org.apache.celeborn.client;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.ArgumentMatchers.*;
|
||||
import static org.mockito.BDDMockito.willAnswer;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
@ -34,6 +33,9 @@ import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import scala.reflect.ClassTag;
|
||||
import scala.reflect.ClassTag$;
|
||||
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.util.concurrent.Future;
|
||||
@ -41,6 +43,7 @@ import io.netty.util.concurrent.GenericFutureListener;
|
||||
import org.apache.commons.lang3.RandomStringUtils;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
import org.apache.celeborn.client.compress.Compressor;
|
||||
import org.apache.celeborn.common.CelebornConf;
|
||||
@ -51,6 +54,8 @@ import org.apache.celeborn.common.network.client.TransportClientFactory;
|
||||
import org.apache.celeborn.common.network.protocol.SerdeVersion;
|
||||
import org.apache.celeborn.common.protocol.CompressionCodec;
|
||||
import org.apache.celeborn.common.protocol.PartitionLocation;
|
||||
import org.apache.celeborn.common.protocol.PbReadReducerPartitionEnd;
|
||||
import org.apache.celeborn.common.protocol.PbReadReducerPartitionEndResponse;
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse$;
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages.RegisterShuffleResponse$;
|
||||
import org.apache.celeborn.common.protocol.message.StatusCode;
|
||||
@ -68,6 +73,7 @@ public class ShuffleClientSuiteJ {
|
||||
private static final int TEST_SHUFFLE_ID = 1;
|
||||
private static final int TEST_ATTEMPT_ID = 0;
|
||||
private static final int TEST_REDUCRE_ID = 0;
|
||||
private static final int TEST_MAP_ID = 0;
|
||||
|
||||
private static final int PRIMARY_RPC_PORT = 1234;
|
||||
private static final int PRIMARY_PUSH_PORT = 1235;
|
||||
@ -621,4 +627,86 @@ public class ShuffleClientSuiteJ {
|
||||
Exception exception = exceptionRef.get();
|
||||
Assert.assertTrue(exception.getCause() instanceof TimeoutException);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSuccessfulReadReducePartitionEnd() throws IOException {
|
||||
CelebornConf conf = new CelebornConf();
|
||||
shuffleClient =
|
||||
new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock"));
|
||||
shuffleClient.setupLifecycleManagerRef(endpointRef);
|
||||
|
||||
ClassTag<PbReadReducerPartitionEndResponse> classTag =
|
||||
ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class);
|
||||
PbReadReducerPartitionEndResponse mockResponse =
|
||||
PbReadReducerPartitionEndResponse.newBuilder()
|
||||
.setStatus(StatusCode.SUCCESS.getValue())
|
||||
.build();
|
||||
when(endpointRef.askSync(any(PbReadReducerPartitionEnd.class), any(), eq(classTag)))
|
||||
.thenReturn(mockResponse);
|
||||
|
||||
shuffleClient.readReducerPartitionEnd(1, 2, 3, 4, 12345, 10000L);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFailedReadReducerPartitionEnd() throws IOException {
|
||||
CelebornConf conf = new CelebornConf();
|
||||
shuffleClient =
|
||||
new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock"));
|
||||
shuffleClient.setupLifecycleManagerRef(endpointRef);
|
||||
ClassTag<PbReadReducerPartitionEndResponse> classTag =
|
||||
ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class);
|
||||
|
||||
String errorMsg = "Test error message";
|
||||
PbReadReducerPartitionEndResponse mockResponse =
|
||||
PbReadReducerPartitionEndResponse.newBuilder()
|
||||
.setStatus(StatusCode.READ_REDUCER_PARTITION_END_FAILED.getValue())
|
||||
.setErrorMsg(errorMsg)
|
||||
.build();
|
||||
|
||||
when(endpointRef.askSync(any(PbReadReducerPartitionEnd.class), any(), eq(classTag)))
|
||||
.thenReturn(mockResponse);
|
||||
|
||||
try {
|
||||
shuffleClient.readReducerPartitionEnd(1, 2, 3, 4, 12345, 10000L);
|
||||
} catch (CelebornIOException e) {
|
||||
Assert.assertEquals(errorMsg, e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCorrectParametersPassedInRequest() throws IOException {
|
||||
CelebornConf conf = new CelebornConf();
|
||||
shuffleClient =
|
||||
new ShuffleClientImpl(TEST_APPLICATION_ID, conf, new UserIdentifier("mock", "mock"));
|
||||
shuffleClient.setupLifecycleManagerRef(endpointRef);
|
||||
|
||||
int shuffleId = 123;
|
||||
int partitionId = 456;
|
||||
int startMapIndex = 0;
|
||||
int endMapIndex = 10;
|
||||
int crc32 = 98765;
|
||||
long bytesWritten = 54321L;
|
||||
|
||||
PbReadReducerPartitionEndResponse mockResponse =
|
||||
PbReadReducerPartitionEndResponse.newBuilder()
|
||||
.setStatus(StatusCode.SUCCESS.getValue())
|
||||
.build();
|
||||
ArgumentCaptor<PbReadReducerPartitionEnd> requestCaptor =
|
||||
ArgumentCaptor.forClass(PbReadReducerPartitionEnd.class);
|
||||
ClassTag<PbReadReducerPartitionEndResponse> classTag =
|
||||
ClassTag$.MODULE$.apply(PbReadReducerPartitionEndResponse.class);
|
||||
|
||||
when(endpointRef.askSync(requestCaptor.capture(), any(), eq(classTag)))
|
||||
.thenReturn(mockResponse);
|
||||
|
||||
shuffleClient.readReducerPartitionEnd(
|
||||
shuffleId, partitionId, startMapIndex, endMapIndex, crc32, bytesWritten);
|
||||
PbReadReducerPartitionEnd capturedRequest = requestCaptor.getValue();
|
||||
assertEquals(shuffleId, capturedRequest.getShuffleId());
|
||||
assertEquals(partitionId, capturedRequest.getPartitionId());
|
||||
assertEquals(startMapIndex, capturedRequest.getStartMaxIndex());
|
||||
assertEquals(endMapIndex, capturedRequest.getEndMapIndex());
|
||||
assertEquals(crc32, capturedRequest.getCrc32());
|
||||
assertEquals(bytesWritten, capturedRequest.getBytesWritten());
|
||||
}
|
||||
}
|
||||
|
||||
@ -197,21 +197,29 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
|
||||
}
|
||||
|
||||
private def registerAndFinishPartition(shuffleId: Int): Unit = {
|
||||
val numPartitions = 9
|
||||
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId, attemptId, 1)
|
||||
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId + 1, attemptId, 2)
|
||||
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId + 2, attemptId, 3)
|
||||
|
||||
// task number incr to numMappers + 1
|
||||
shuffleClient.registerMapPartitionTask(shuffleId, numMappers, mapId, attemptId + 1, 9)
|
||||
shuffleClient.mapPartitionMapperEnd(shuffleId, mapId, attemptId, numMappers, 1)
|
||||
shuffleClient.mapPartitionMapperEnd(shuffleId, mapId, attemptId, numMappers, numPartitions, 1)
|
||||
// another attempt
|
||||
shuffleClient.mapPartitionMapperEnd(
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId + 1,
|
||||
numMappers,
|
||||
numPartitions,
|
||||
9)
|
||||
// another mapper
|
||||
shuffleClient.mapPartitionMapperEnd(shuffleId, mapId + 1, attemptId, numMappers, mapId + 1)
|
||||
shuffleClient.mapPartitionMapperEnd(
|
||||
shuffleId,
|
||||
mapId + 1,
|
||||
attemptId,
|
||||
numMappers,
|
||||
numPartitions,
|
||||
mapId + 1)
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,301 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.client.commit
|
||||
|
||||
import org.scalatest.matchers.must.Matchers.{be, include}
|
||||
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
|
||||
|
||||
import org.apache.celeborn.CelebornFunSuite
|
||||
import org.apache.celeborn.common.CommitMetadata
|
||||
|
||||
class PartitionValidatorTest extends CelebornFunSuite {

  // Re-created inside each test so validator state never leaks between cases.
  var validator: PartitionCompletenessValidator = _

  // Empty metadata reused wherever checksum/byte matching is not the behavior under test.
  var mockCommitMetadata: CommitMetadata = new CommitMetadata()

  test("AQEPartitionCompletenessValidator should validate a new sub-partition correctly when there are no overlapping ranges") {
    validator = new PartitionCompletenessValidator
    val (isValid, message) =
      validator.validateSubPartition(1, 0, 10, mockCommitMetadata, mockCommitMetadata, 20, false)

    isValid shouldBe (true)
    message shouldBe ("Partition is valid but still waiting for more data")
  }

  test("AQEPartitionCompletenessValidator should fail validation for overlapping map ranges") {
    validator = new PartitionCompletenessValidator
    // Register [0, 10) first.
    validator.validateSubPartition(1, 0, 10, mockCommitMetadata, mockCommitMetadata, 20, false)
    // [5, 15) overlaps the registered range and must be rejected.
    val (isValid, message) =
      validator.validateSubPartition(1, 5, 15, mockCommitMetadata, mockCommitMetadata, 20, false)

    isValid shouldBe (false)
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test(
    "AQEPartitionCompletenessValidator should pass validation for adjacent non-overlapping map ranges") {
    // NOTE(review): renamed from "should fail validation ... case 2" — the case
    // asserts isValid == true, so the original title contradicted the assertion.
    validator = new PartitionCompletenessValidator
    validator.validateSubPartition(1, 0, 1, mockCommitMetadata, mockCommitMetadata, 20, false)
    validator.validateSubPartition(1, 2, 3, mockCommitMetadata, mockCommitMetadata, 20, false)
    // [1, 2) exactly fills the gap between [0, 1) and [2, 3): adjacent, not overlapping.
    val (isValid, _) =
      validator.validateSubPartition(1, 1, 2, mockCommitMetadata, mockCommitMetadata, 20, false)

    isValid shouldBe (true)
  }

  test("AQEPartitionCompletenessValidator should fail validation for overlapping map ranges - edge cases") {
    validator = new PartitionCompletenessValidator
    validator.validateSubPartition(1, 0, 10, mockCommitMetadata, mockCommitMetadata, 20, false)
    // [0, 5) shares its start index with the registered [0, 10) range.
    val (isValid, message) =
      validator.validateSubPartition(1, 0, 5, mockCommitMetadata, mockCommitMetadata, 20, false)

    isValid shouldBe (false)
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test("AQEPartitionCompletenessValidator should fail validation for one map range subsuming another map range") {
    validator = new PartitionCompletenessValidator
    validator.validateSubPartition(1, 5, 10, mockCommitMetadata, mockCommitMetadata, 20, false)
    // [0, 15) fully contains the registered [5, 10) range.
    val (isValid, message) =
      validator.validateSubPartition(1, 0, 15, mockCommitMetadata, mockCommitMetadata, 20, false)

    isValid shouldBe (false)
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test("AQEPartitionCompletenessValidator should fail validation for one map range subsuming another map range - 2") {
    validator = new PartitionCompletenessValidator
    validator.validateSubPartition(1, 0, 15, mockCommitMetadata, mockCommitMetadata, 20, false)
    // [5, 10) is fully contained inside the registered [0, 15) range.
    val (isValid, message) =
      validator.validateSubPartition(1, 5, 10, mockCommitMetadata, mockCommitMetadata, 20, false)

    isValid shouldBe (false)
    message should include("Encountered overlapping map range for partitionId: 1")
  }

  test(
    "AQEPartitionCompletenessValidator should fail validation if commit metadata does not match") {
    val expectedCommitMetadata = new CommitMetadata()
    // Renamed from "actuaCommitMetadata" (typo).
    val actualCommitMetadata = new CommitMetadata()
    validator = new PartitionCompletenessValidator
    validator.validateSubPartition(
      1,
      0,
      10,
      actualCommitMetadata,
      expectedCommitMetadata,
      20,
      false)

    // Re-report the same sub-partition with different metadata: must be rejected.
    val (isValid, message) =
      validator.validateSubPartition(
        1,
        0,
        10,
        new CommitMetadata(3, 3),
        expectedCommitMetadata,
        20,
        false)

    isValid should be(false)
    message should include("Commit Metadata for partition: 1 not matching for sub-partition")
  }

  test("pass validation if written counts are correct after all updates") {
    validator = new PartitionCompletenessValidator
    val expectedCommitMetadata = new CommitMetadata(0, 10)
    validator.validateSubPartition(
      1,
      0,
      10,
      new CommitMetadata(0, 5),
      expectedCommitMetadata,
      20,
      false)
    // The second half [10, 20) completes the mapper range and the byte total (5 + 5 = 10).
    val (isValid, message) =
      validator.validateSubPartition(
        1,
        10,
        20,
        new CommitMetadata(0, 5),
        expectedCommitMetadata,
        20,
        false)

    isValid should be(true)
    message should be("Partition is complete")
  }

  test("handle multiple partitions correctly") {
    validator = new PartitionCompletenessValidator
    // Interleave two partitions to confirm their state is tracked independently.
    val expectedCommitMetadataForPartition1 = new CommitMetadata(0, 10)
    val expectedCommitMetadataForPartition2 = new CommitMetadata(0, 2)
    validator.validateSubPartition(
      1,
      0,
      10,
      new CommitMetadata(0, 5),
      expectedCommitMetadataForPartition1,
      20,
      false)
    validator.validateSubPartition(
      2,
      0,
      5,
      new CommitMetadata(0, 2),
      expectedCommitMetadataForPartition2,
      10,
      false)

    val (isValid1, message1) = validator.validateSubPartition(
      1,
      0,
      10,
      new CommitMetadata(0, 5),
      expectedCommitMetadataForPartition1,
      20,
      false)
    val (isValid2, message2) = validator.validateSubPartition(
      2,
      0,
      5,
      new CommitMetadata(0, 2),
      expectedCommitMetadataForPartition2,
      10,
      false)

    isValid1 should be(true)
    isValid2 should be(true)
    message1 should be("Partition is valid but still waiting for more data")
    message2 should be("Partition is valid but still waiting for more data")

    val (isValid3, message3) = validator.validateSubPartition(
      1,
      10,
      20,
      new CommitMetadata(0, 5),
      expectedCommitMetadataForPartition1,
      20,
      false)
    val (isValid4, message4) = validator.validateSubPartition(
      2,
      5,
      10,
      new CommitMetadata(),
      expectedCommitMetadataForPartition2,
      10,
      false)
    isValid3 should be(true)
    isValid4 should be(true)
    message3 should be("Partition is complete")
    message4 should be("Partition is complete")
  }
}
|
||||
@ -0,0 +1,251 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.client.commit
|
||||
|
||||
import org.scalatest.matchers.must.Matchers.{be, include}
|
||||
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
|
||||
|
||||
import org.apache.celeborn.CelebornFunSuite
|
||||
import org.apache.celeborn.common.CommitMetadata
|
||||
|
||||
class SkewHandlingWithoutMapRangeValidatorTest extends CelebornFunSuite {

  // Validator under test; re-created in each test case for isolation.
  // In the skew-without-map-range-split encoding, "startMapIndex" carries the total
  // number of sub-partitions and "endMapIndex" carries the sub-partition index.
  private var validator: SkewHandlingWithoutMapRangeValidator = _

  test("testProcessSubPartitionFirstTime") {
    // Setup
    validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val startMapIndex = 10
    val endMapIndex = 0
    val metadata = new CommitMetadata
    val expectedMapperCount = 20

    // Execute
    val (success, message) = validator.processSubPartition(
      partitionId,
      startMapIndex,
      endMapIndex,
      metadata,
      expectedMapperCount)

    success shouldBe true
    message shouldBe ""
    // Only 1 of 10 sub-partitions reported so far.
    validator.isPartitionComplete(partitionId) shouldBe false

    // Verify current metadata
    val currentMetadata = validator.currentCommitMetadata(partitionId)
    currentMetadata shouldNot be(null)
    currentMetadata shouldBe metadata
  }

  test("testProcessMultipleSubPartitions") {
    // Setup
    validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val startMapIndex = 10
    val metadata1 = new CommitMetadata(5, 1000)
    val metadata2 = new CommitMetadata(3, 500)

    // Process first sub-partition
    validator.processSubPartition(partitionId, startMapIndex, 0, metadata1, startMapIndex)

    // Process second sub-partition
    val (success, message) = validator.processSubPartition(
      partitionId,
      startMapIndex,
      1,
      metadata2,
      startMapIndex)

    // Verify
    success shouldBe true
    message shouldBe ""
    validator.isPartitionComplete(partitionId) shouldBe false

    // Verify current metadata: the validator's running value must equal the
    // merge of both sub-partition metadatas.
    val currentMetadata = validator.currentCommitMetadata(partitionId)
    currentMetadata shouldNot be(null)
    metadata1.addCommitData(metadata2)
    currentMetadata shouldBe metadata1

  }

  test("testPartitionComplete") {
    // Setup - processing all sub-partitions from 0 to 9 for a total of 10
    validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val startMapIndex = 10

    // Process all sub-partitions
    for (i <- 0 until startMapIndex) {
      val metadata = new CommitMetadata()
      validator.processSubPartition(partitionId, startMapIndex, i, metadata, startMapIndex)
    }

    validator.isPartitionComplete(partitionId) shouldBe true
  }

  test("testInvalidStartEndMapIndex") {
    // Setup - invalid case where startMapIndex <= endMapIndex
    // (the sub-partition index must be strictly below the sub-partition count)
    validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val startMapIndex = 5
    val endMapIndex = 5 // Equal, which should fail
    val metadata = new CommitMetadata()

    // This should throw IllegalArgumentException
    intercept[IllegalArgumentException] {
      validator.processSubPartition(partitionId, startMapIndex, endMapIndex, metadata, 20)
    }
  }

  test("testMismatchSubPartitionCount") {
    // Setup
    validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val metadata = new CommitMetadata()

    // First call with startMapIndex = 10
    validator.processSubPartition(partitionId, 10, 0, metadata, 10)

    // Second call with different startMapIndex, which should fail
    intercept[IllegalStateException] {
      validator.processSubPartition(partitionId, 12, 1, metadata, 12)
    }
  }

  test("testDuplicateSubPartitionWithSameMetadata") {
    // Setup
    validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val startMapIndex = 10
    val endMapIndex = 5
    val metadata = new CommitMetadata(3, 100)

    // Process sub-partition first time
    validator.processSubPartition(partitionId, startMapIndex, endMapIndex, metadata, startMapIndex)

    // Process same sub-partition again with identical metadata: treated as a
    // benign duplicate report, not an error.
    val (success, message) = validator.processSubPartition(
      partitionId,
      startMapIndex,
      endMapIndex,
      metadata,
      startMapIndex)
    success shouldBe true
    message shouldBe ""
    validator.isPartitionComplete(partitionId) shouldBe false
  }

  test("testDuplicateSubPartitionWithDifferentMetadata") {
    // Setup
    validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val startMapIndex = 10
    val endMapIndex = 5
    val metadata1 = new CommitMetadata(3, 100)

    val metadata2 = new CommitMetadata(4, 200)

    // Process sub-partition first time
    validator.processSubPartition(partitionId, startMapIndex, endMapIndex, metadata1, startMapIndex)

    // Process same sub-partition again with different metadata: must be rejected.
    val (success, message) = validator.processSubPartition(
      partitionId,
      startMapIndex,
      endMapIndex,
      metadata2,
      startMapIndex)
    success shouldBe false
    message should include("Mismatch in metadata")
    validator.isPartitionComplete(partitionId) shouldBe false
  }

  test("testProcessTooManySubPartitions") {
    // Setup
    validator = new SkewHandlingWithoutMapRangeValidator
    val partitionId = 1
    val startMapIndex = 10

    // Process all sub-partitions
    for (i <- 0 until startMapIndex) {
      val metadata = new CommitMetadata()
      validator.processSubPartition(partitionId, startMapIndex, i, metadata, startMapIndex)
    }

    // Try to process one more sub-partition, which should exceed the total count
    val extraMetadata = new CommitMetadata()
    // This should throw IllegalArgumentException (the sub-partition index equals
    // the count, which fails the index < count precondition)
    intercept[IllegalArgumentException] {
      validator.processSubPartition(
        partitionId,
        startMapIndex,
        startMapIndex,
        extraMetadata,
        startMapIndex)
    }
  }

  test("multiple complete partitions") {
    // Setup - we'll use 3 different partitions
    validator = new SkewHandlingWithoutMapRangeValidator
    val partition1 = 1
    val partition2 = 2
    val partition3 = 3

    // Different sizes for each partition
    val subPartitions1 = 5
    val subPartitions2 = 8
    val subPartitions3 = 3

    // Process all sub-partitions for partition1
    val expectedMetadata1 = new CommitMetadata()
    for (i <- 0 until subPartitions1) {
      val metadata = new CommitMetadata(i, 100 + i)
      validator.processSubPartition(partition1, subPartitions1, i, metadata, subPartitions1)
      expectedMetadata1.addCommitData(metadata)
    }

    // Process all sub-partitions for partition2
    val expectedMetadata2 = new CommitMetadata()
    for (i <- 0 until subPartitions2) {
      val metadata = new CommitMetadata(i + 1, 200 + i)
      validator.processSubPartition(partition2, subPartitions2, i, metadata, subPartitions2)
      expectedMetadata2.addCommitData(metadata)
    }

    // Process only some sub-partitions for partition3
    val expectedMetadata3 = new CommitMetadata()
    for (i <- 0 until subPartitions3 - 1) { // Deliberately leave one out
      val metadata = new CommitMetadata(i + 2, 300 + i)
      validator.processSubPartition(partition3, subPartitions3, i, metadata, subPartitions3)
      expectedMetadata3.addCommitData(metadata)
    }

    // Verify completion status for each partition
    validator.isPartitionComplete(partition1) shouldBe true
    validator.isPartitionComplete(partition2) shouldBe true
    validator.isPartitionComplete(partition3) shouldBe false

    validator.currentCommitMetadata(partition1) shouldBe expectedMetadata1
    validator.currentCommitMetadata(partition2) shouldBe expectedMetadata2
    validator.currentCommitMetadata(partition3) shouldBe expectedMetadata3
  }
}
|
||||
@ -0,0 +1,91 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.common;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.zip.CRC32;
|
||||
|
||||
public class CelebornCRC32 {

  // Running combined checksum; atomic so concurrent writers can fold in their
  // partial checksums without external locking.
  private final AtomicInteger current;

  CelebornCRC32(int i) {
    this.current = new AtomicInteger(i);
  }

  CelebornCRC32() {
    this(0);
  }

  /** Returns the CRC32 of the whole array, truncated to an int. */
  static int compute(byte[] bytes) {
    CRC32 hashFunction = new CRC32();
    hashFunction.update(bytes);
    return (int) hashFunction.getValue();
  }

  /** Returns the CRC32 of {@code length} bytes starting at {@code offset}, truncated to an int. */
  static int compute(byte[] bytes, int offset, int length) {
    CRC32 hashFunction = new CRC32();
    hashFunction.update(bytes, offset, length);
    return (int) hashFunction.getValue();
  }

  /**
   * Combines two checksums by summing each of the four byte lanes independently (mod 256). The
   * operation is commutative and associative, so partial checksums may be merged in any order.
   */
  static int combine(int first, int second) {
    return (((byte) second + (byte) first) & 0xFF)
        | ((((byte) (second >> 8) + (byte) (first >> 8)) & 0xFF) << 8)
        | ((((byte) (second >> 16) + (byte) (first >> 16)) & 0xFF) << 16)
        | (((byte) (second >> 24) + (byte) (first >> 24)) << 24);
  }

  /** Atomically folds {@code checksum} into the running value. */
  public void addChecksum(int checksum) {
    // updateAndGet retries the CAS internally; combine() is side-effect free,
    // so re-evaluation under contention is safe. Replaces the hand-rolled
    // compareAndSet loop.
    current.updateAndGet(prev -> combine(checksum, prev));
  }

  /** Computes the checksum of the given range and folds it into the running value. */
  void addData(byte[] bytes, int offset, int length) {
    addChecksum(compute(bytes, offset, length));
  }

  int get() {
    return current.get();
  }

  @Override
  public String toString() {
    return "CelebornCRC32{" + "current=" + current + '}';
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    CelebornCRC32 that = (CelebornCRC32) o;
    return current.get() == that.current.get();
  }

  @Override
  public int hashCode() {
    // Hash by value, consistent with equals(). The previous Objects.hashCode(current)
    // used AtomicInteger's identity hash code, so equal instances hashed differently —
    // an equals/hashCode contract violation.
    return current.get();
  }
}
|
||||
@ -0,0 +1,82 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.common;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
public class CommitMetadata {
|
||||
|
||||
private final AtomicLong bytes;
|
||||
private final CelebornCRC32 crc;
|
||||
|
||||
public CommitMetadata() {
|
||||
this.bytes = new AtomicLong();
|
||||
this.crc = new CelebornCRC32();
|
||||
}
|
||||
|
||||
public CommitMetadata(int checksum, long numBytes) {
|
||||
this.bytes = new AtomicLong(numBytes);
|
||||
this.crc = new CelebornCRC32(checksum);
|
||||
}
|
||||
|
||||
public void addDataWithOffsetAndLength(byte[] rawDataBuf, int offset, int length) {
|
||||
this.bytes.addAndGet(length);
|
||||
this.crc.addData(rawDataBuf, offset, length);
|
||||
}
|
||||
|
||||
public void addCommitData(CommitMetadata commitMetadata) {
|
||||
addCommitData(commitMetadata.getChecksum(), commitMetadata.getBytes());
|
||||
}
|
||||
|
||||
public void addCommitData(int checksum, long numBytes) {
|
||||
this.bytes.addAndGet(numBytes);
|
||||
this.crc.addChecksum(checksum);
|
||||
}
|
||||
|
||||
public int getChecksum() {
|
||||
return crc.get();
|
||||
}
|
||||
|
||||
public long getBytes() {
|
||||
return bytes.get();
|
||||
}
|
||||
|
||||
public static boolean checkCommitMetadata(CommitMetadata expected, CommitMetadata actual) {
|
||||
boolean bytesMatch = expected.getBytes() == actual.getBytes();
|
||||
boolean checksumsMatch = expected.getChecksum() == actual.getChecksum();
|
||||
return bytesMatch && checksumsMatch;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "CommitMetadata{" + "bytes=" + bytes.get() + ", crc=" + crc + '}';
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
CommitMetadata that = (CommitMetadata) o;
|
||||
return bytes.get() == that.bytes.get() && Objects.equals(crc, that.crc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(bytes, crc);
|
||||
}
|
||||
}
|
||||
@ -92,7 +92,8 @@ public enum StatusCode {
|
||||
SEGMENT_START_FAIL_REPLICA(52),
|
||||
SEGMENT_START_FAIL_PRIMARY(53),
|
||||
NO_SPLIT(54),
|
||||
WORKER_UNRESPONSIVE(55);
|
||||
WORKER_UNRESPONSIVE(55),
|
||||
READ_REDUCER_PARTITION_END_FAILED(56);
|
||||
|
||||
private final byte value;
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicReference;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
import org.apache.celeborn.common.CelebornConf;
|
||||
import org.apache.celeborn.common.CommitMetadata;
|
||||
import org.apache.celeborn.common.protocol.PartitionLocation;
|
||||
import org.apache.celeborn.common.util.JavaUtils;
|
||||
|
||||
@ -33,6 +34,9 @@ public class PushState {
|
||||
private final int pushBufferMaxSize;
|
||||
public AtomicReference<IOException> exception = new AtomicReference<>();
|
||||
private final InFlightRequestTracker inFlightRequestTracker;
|
||||
// partition id -> CommitMetadata
|
||||
private final ConcurrentHashMap<Integer, CommitMetadata> commitMetadataMap =
|
||||
new ConcurrentHashMap<>();
|
||||
|
||||
private final Map<String, LocationPushFailedBatches> failedBatchMap;
|
||||
|
||||
@ -102,4 +106,34 @@ public class PushState {
|
||||
  /** Returns the live map of failed push batches keyed by location identifier; not a copy. */
  public Map<String, LocationPushFailedBatches> getFailedBatches() {
    return this.failedBatchMap;
  }
|
||||
|
||||
public int[] getCRC32PerPartition(boolean shuffleIntegrityCheckEnabled, int numPartitions) {
|
||||
if (!shuffleIntegrityCheckEnabled) {
|
||||
return new int[0];
|
||||
}
|
||||
|
||||
int[] crc32PerPartition = new int[numPartitions];
|
||||
for (Map.Entry<Integer, CommitMetadata> entry : commitMetadataMap.entrySet()) {
|
||||
crc32PerPartition[entry.getKey()] = entry.getValue().getChecksum();
|
||||
}
|
||||
return crc32PerPartition;
|
||||
}
|
||||
|
||||
public long[] getBytesWrittenPerPartition(
|
||||
boolean shuffleIntegrityCheckEnabled, int numPartitions) {
|
||||
if (!shuffleIntegrityCheckEnabled) {
|
||||
return new long[0];
|
||||
}
|
||||
long[] bytesWrittenPerPartition = new long[numPartitions];
|
||||
for (Map.Entry<Integer, CommitMetadata> entry : commitMetadataMap.entrySet()) {
|
||||
bytesWrittenPerPartition[entry.getKey()] = entry.getValue().getBytes();
|
||||
}
|
||||
return bytesWrittenPerPartition;
|
||||
}
|
||||
|
||||
public void addDataWithOffsetAndLength(int partitionId, byte[] data, int offset, int length) {
|
||||
CommitMetadata commitMetadata =
|
||||
commitMetadataMap.computeIfAbsent(partitionId, id -> new CommitMetadata());
|
||||
commitMetadata.addDataWithOffsetAndLength(data, offset, length);
|
||||
}
|
||||
}
|
||||
|
||||
@ -114,6 +114,8 @@ enum MessageType {
|
||||
PUSH_MERGED_DATA_SPLIT_PARTITION_INFO = 91;
|
||||
GET_STAGE_END = 92;
|
||||
GET_STAGE_END_RESPONSE = 93;
|
||||
READ_REDUCER_PARTITION_END = 94;
|
||||
READ_REDUCER_PARTITION_END_RESPONSE = 95;
|
||||
}
|
||||
|
||||
enum StreamType {
|
||||
@ -360,6 +362,9 @@ message PbMapperEnd {
|
||||
int32 numMappers = 4;
|
||||
int32 partitionId = 5;
|
||||
map<string, PbLocationPushFailedBatches> pushFailureBatches= 6;
|
||||
int32 numPartitions = 7;
|
||||
repeated int32 crc32PerPartition = 8;
|
||||
repeated int64 bytesWrittenPerPartition = 9;
|
||||
}
|
||||
|
||||
message PbLocationPushFailedBatches {
|
||||
@ -920,3 +925,17 @@ message PbGetStageEndResponse {
|
||||
message PbChunkOffsets {
|
||||
repeated int64 chunkOffset = 1;
|
||||
}
|
||||
|
||||
// Reader-side end-of-partition report: a CelebornInputStream sends this to the
// driver after it finishes reading a (sub-)partition, so the driver can aggregate
// CRC32 and byte counts and compare them against what the writers committed.
message PbReadReducerPartitionEnd {
  int32 shuffleId = 1;
  int32 partitionId = 2;
  // NOTE(review): looks like a typo for "startMapIndex"; renaming would change
  // the generated accessors (getStartMaxIndex) already referenced by the Scala
  // deserialization code, so it is only flagged here — fix both together.
  int32 startMaxIndex = 3;
  int32 endMapIndex = 4;
  int32 crc32 = 5;
  // NOTE(review): this value is produced by the reader, so it presumably holds
  // bytes *read*; confirm whether "bytesWritten" is the intended field name.
  int64 bytesWritten = 6;
}
|
||||
|
||||
// Driver's acknowledgement for PbReadReducerPartitionEnd.
message PbReadReducerPartitionEndResponse {
  // Numeric status code of the driver-side check.
  int32 status = 1;
  // Human-readable detail accompanying a non-success status.
  string errorMsg = 2;
}
|
||||
|
||||
@ -944,6 +944,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
|
||||
get(CLIENT_EXCLUDE_PEER_WORKER_ON_FAILURE_ENABLED)
|
||||
def clientMrMaxPushData: Long = get(CLIENT_MR_PUSH_DATA_MAX)
|
||||
def clientApplicationUUIDSuffixEnabled: Boolean = get(CLIENT_APPLICATION_UUID_SUFFIX_ENABLED)
|
||||
def clientShuffleIntegrityCheckEnabled: Boolean =
|
||||
get(CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED)
|
||||
|
||||
def appUniqueIdWithUUIDSuffix(appId: String): String = {
|
||||
if (clientApplicationUUIDSuffixEnabled) {
|
||||
@ -5358,6 +5360,14 @@ object CelebornConf extends Logging {
|
||||
.bytesConf(ByteUnit.BYTE)
|
||||
.createWithDefaultString("512k")
|
||||
|
||||
val CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED: ConfigEntry[Boolean] =
|
||||
buildConf("celeborn.client.shuffle.integrityCheck.enabled")
|
||||
.categories("client")
|
||||
.version("0.6.1")
|
||||
.doc("When `true`, enables end-to-end integrity checks for Spark workloads.")
|
||||
.booleanConf
|
||||
.createWithDefault(false)
|
||||
|
||||
val SPARK_SHUFFLE_WRITER_MODE: ConfigEntry[String] =
|
||||
buildConf("celeborn.client.spark.shuffle.writer")
|
||||
.withAlternative("celeborn.shuffle.writer")
|
||||
|
||||
@ -19,9 +19,11 @@ package org.apache.celeborn.common.protocol.message
|
||||
|
||||
import java.util
|
||||
import java.util.{Collections, UUID}
|
||||
import java.util.concurrent.atomic.AtomicIntegerArray
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import com.google.common.base.Preconditions.checkState
|
||||
import com.google.protobuf.ByteString
|
||||
import org.roaringbitmap.RoaringBitmap
|
||||
|
||||
@ -274,11 +276,25 @@ object ControlMessages extends Logging {
|
||||
attemptId: Int,
|
||||
numMappers: Int,
|
||||
partitionId: Int,
|
||||
failedBatchSet: util.Map[String, LocationPushFailedBatches])
|
||||
failedBatchSet: util.Map[String, LocationPushFailedBatches],
|
||||
numPartitions: Int,
|
||||
crc32PerPartition: Array[Int],
|
||||
bytesWrittenPerPartition: Array[Long])
|
||||
extends MasterMessage
|
||||
|
||||
// Reader-side report sent to the driver when a CelebornInputStream finishes
// reading a (sub-)partition; carries the map-index range covered plus the
// accumulated CRC32 and byte count for integrity aggregation.
// NOTE(review): `bytesWritten` is reported by the reader, so it presumably
// holds bytes read — confirm the intended name (mirrors the proto field).
case class ReadReducerPartitionEnd(
    shuffleId: Int,
    partitionId: Int,
    startMapIndex: Int,
    endMapIndex: Int,
    crc32: Int,
    bytesWritten: Long)
  extends MasterMessage
|
||||
|
||||
case class MapperEndResponse(status: StatusCode) extends MasterMessage
|
||||
|
||||
case class ReadReducerPartitionEndResponse(status: StatusCode) extends MasterMessage
|
||||
|
||||
case class GetReducerFileGroup(
|
||||
shuffleId: Int,
|
||||
isSegmentGranularityVisible: Boolean,
|
||||
@ -606,6 +622,12 @@ object ControlMessages extends Logging {
|
||||
case pb: PbReviseLostShufflesResponse =>
|
||||
new TransportMessage(MessageType.REVISE_LOST_SHUFFLES_RESPONSE, pb.toByteArray)
|
||||
|
||||
case pb: PbReadReducerPartitionEnd =>
|
||||
new TransportMessage(MessageType.READ_REDUCER_PARTITION_END, pb.toByteArray)
|
||||
|
||||
case pb: PbReadReducerPartitionEndResponse =>
|
||||
new TransportMessage(MessageType.READ_REDUCER_PARTITION_END_RESPONSE, pb.toByteArray)
|
||||
|
||||
case pb: PbReportBarrierStageAttemptFailure =>
|
||||
new TransportMessage(MessageType.REPORT_BARRIER_STAGE_ATTEMPT_FAILURE, pb.toByteArray)
|
||||
|
||||
@ -739,7 +761,16 @@ object ControlMessages extends Logging {
|
||||
case pb: PbChangeLocationResponse =>
|
||||
new TransportMessage(MessageType.CHANGE_LOCATION_RESPONSE, pb.toByteArray)
|
||||
|
||||
case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId, pushFailedBatch) =>
|
||||
case MapperEnd(
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
numMappers,
|
||||
partitionId,
|
||||
pushFailedBatch,
|
||||
numPartitions,
|
||||
crc32PerPartition,
|
||||
bytesWrittenPerPartition) =>
|
||||
val pushFailedMap = pushFailedBatch.asScala.map { case (k, v) =>
|
||||
val resultValue = PbSerDeUtils.toPbLocationPushFailedBatches(v)
|
||||
(k, resultValue)
|
||||
@ -751,6 +782,10 @@ object ControlMessages extends Logging {
|
||||
.setNumMappers(numMappers)
|
||||
.setPartitionId(partitionId)
|
||||
.putAllPushFailureBatches(pushFailedMap)
|
||||
.setNumPartitions(numPartitions)
|
||||
.addAllCrc32PerPartition(crc32PerPartition.map(Integer.valueOf).toSeq.asJava)
|
||||
.addAllBytesWrittenPerPartition(bytesWrittenPerPartition.map(
|
||||
java.lang.Long.valueOf).toSeq.asJava)
|
||||
.build().toByteArray
|
||||
new TransportMessage(MessageType.MAPPER_END, payload)
|
||||
|
||||
@ -1176,6 +1211,15 @@ object ControlMessages extends Logging {
|
||||
|
||||
case MAPPER_END_VALUE =>
|
||||
val pbMapperEnd = PbMapperEnd.parseFrom(message.getPayload)
|
||||
val partitionCount = pbMapperEnd.getCrc32PerPartitionCount
|
||||
checkState(partitionCount == pbMapperEnd.getBytesWrittenPerPartitionCount)
|
||||
val crc32Array = new Array[Int](partitionCount)
|
||||
val bytesWrittenPerPartitionArray =
|
||||
new Array[Long](pbMapperEnd.getBytesWrittenPerPartitionCount)
|
||||
for (i <- 0 until partitionCount) {
|
||||
crc32Array(i) = pbMapperEnd.getCrc32PerPartition(i)
|
||||
bytesWrittenPerPartitionArray(i) = pbMapperEnd.getBytesWrittenPerPartition(i)
|
||||
}
|
||||
MapperEnd(
|
||||
pbMapperEnd.getShuffleId,
|
||||
pbMapperEnd.getMapId,
|
||||
@ -1185,7 +1229,23 @@ object ControlMessages extends Logging {
|
||||
pbMapperEnd.getPushFailureBatchesMap.asScala.map {
|
||||
case (partitionId, pushFailedBatchSet) =>
|
||||
(partitionId, PbSerDeUtils.fromPbLocationPushFailedBatches(pushFailedBatchSet))
|
||||
}.toMap.asJava)
|
||||
}.toMap.asJava,
|
||||
pbMapperEnd.getNumPartitions,
|
||||
crc32Array,
|
||||
bytesWrittenPerPartitionArray)
|
||||
|
||||
case READ_REDUCER_PARTITION_END_VALUE =>
|
||||
val pbReadReducerPartitionEnd = PbReadReducerPartitionEnd.parseFrom(message.getPayload)
|
||||
ReadReducerPartitionEnd(
|
||||
pbReadReducerPartitionEnd.getShuffleId,
|
||||
pbReadReducerPartitionEnd.getPartitionId,
|
||||
pbReadReducerPartitionEnd.getStartMaxIndex,
|
||||
pbReadReducerPartitionEnd.getEndMapIndex,
|
||||
pbReadReducerPartitionEnd.getCrc32,
|
||||
pbReadReducerPartitionEnd.getBytesWritten)
|
||||
|
||||
case READ_REDUCER_PARTITION_END_RESPONSE_VALUE =>
|
||||
PbReadReducerPartitionEndResponse.parseFrom(message.getPayload)
|
||||
|
||||
case MAPPER_END_RESPONSE_VALUE =>
|
||||
val pbMapperEndResponse = PbMapperEndResponse.parseFrom(message.getPayload)
|
||||
|
||||
@ -0,0 +1,71 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.common;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertNotEquals;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
// Test data is generated from https://crccalc.com/ using "CRC-32/ISO-HDLC".
|
||||
public class CelebornCRC32Test {
|
||||
|
||||
@Test
|
||||
public void testCompute() {
|
||||
byte[] data = "test".getBytes();
|
||||
int checksum = CelebornCRC32.compute(data);
|
||||
assertEquals(3632233996L, checksum & 0xFFFFFFFFL);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputeWithOffset() {
|
||||
byte[] data = "testdata".getBytes();
|
||||
int checksum = CelebornCRC32.compute(data, 4, 4);
|
||||
assertEquals(2918445923L, checksum & 0xFFFFFFFFL);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCombine() {
|
||||
int first = 123456789;
|
||||
int second = 987654321;
|
||||
int combined = CelebornCRC32.combine(first, second);
|
||||
assertNotEquals(first, combined);
|
||||
assertNotEquals(second, combined);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAddChecksum() {
|
||||
CelebornCRC32 crc = new CelebornCRC32();
|
||||
crc.addChecksum(123456789);
|
||||
assertEquals(123456789, crc.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAddDataWithOffset() {
|
||||
CelebornCRC32 crc = new CelebornCRC32();
|
||||
byte[] data = "testdata".getBytes();
|
||||
crc.addData(data, 4, 4);
|
||||
assertEquals(2918445923L, crc.get() & 0xFFFFFFFFL);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testToString() {
|
||||
CelebornCRC32 crc = new CelebornCRC32(123456789);
|
||||
assertEquals("CelebornCRC32{current=123456789}", crc.toString());
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,72 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.common;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class CommitMetadataTest {
|
||||
|
||||
@Test
|
||||
public void testAddDataWithOffsetAndLength() {
|
||||
CommitMetadata metadata = new CommitMetadata();
|
||||
byte[] data = "testdata".getBytes();
|
||||
metadata.addDataWithOffsetAndLength(data, 4, 4);
|
||||
assertEquals(4, metadata.getBytes());
|
||||
assertEquals(CelebornCRC32.compute(data, 4, 4), metadata.getChecksum());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAddCommitData() {
|
||||
CommitMetadata metadata1 = new CommitMetadata();
|
||||
byte[] data1 = "test".getBytes();
|
||||
metadata1.addDataWithOffsetAndLength(data1, 0, data1.length);
|
||||
|
||||
CommitMetadata metadata2 = new CommitMetadata();
|
||||
byte[] data2 = "data".getBytes();
|
||||
metadata2.addDataWithOffsetAndLength(data2, 0, data2.length);
|
||||
|
||||
metadata1.addCommitData(metadata2);
|
||||
|
||||
// Verify that the metadata1 now contains the combined data and checksum of data1 and data2
|
||||
assertEquals(data1.length + data2.length, metadata1.getBytes());
|
||||
assertEquals(
|
||||
CelebornCRC32.combine(CelebornCRC32.compute(data1), CelebornCRC32.compute(data2)),
|
||||
metadata1.getChecksum());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCheckCommitMetadata() {
|
||||
CommitMetadata expected = new CommitMetadata(CelebornCRC32.compute("testdata".getBytes()), 8);
|
||||
CommitMetadata actualMatching =
|
||||
new CommitMetadata(CelebornCRC32.compute("testdata".getBytes()), 8);
|
||||
CommitMetadata actualNonMatchingBytesOnly =
|
||||
new CommitMetadata(CelebornCRC32.compute("testdata".getBytes()), 12);
|
||||
CommitMetadata actualNonMatchingChecksumOnly =
|
||||
new CommitMetadata(CelebornCRC32.compute("foo".getBytes()), 8);
|
||||
CommitMetadata actualNonMatching =
|
||||
new CommitMetadata(CelebornCRC32.compute("bar".getBytes()), 16);
|
||||
|
||||
Assert.assertTrue(CommitMetadata.checkCommitMetadata(expected, actualMatching));
|
||||
Assert.assertFalse(CommitMetadata.checkCommitMetadata(expected, actualNonMatchingBytesOnly));
|
||||
Assert.assertFalse(CommitMetadata.checkCommitMetadata(expected, actualNonMatchingChecksumOnly));
|
||||
Assert.assertFalse(CommitMetadata.checkCommitMetadata(expected, actualNonMatching));
|
||||
}
|
||||
}
|
||||
@ -20,6 +20,9 @@ package org.apache.celeborn.common.util
|
||||
import java.util
|
||||
import java.util.Collections
|
||||
|
||||
import org.scalatest.matchers.must.Matchers.contain
|
||||
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
|
||||
|
||||
import org.apache.celeborn.CelebornFunSuite
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.client.{MasterEndpointResolver, StaticMasterEndpointResolver}
|
||||
@ -145,10 +148,19 @@ class UtilsSuite extends CelebornFunSuite {
|
||||
}
|
||||
|
||||
test("MapperEnd class convert with pb") {
|
||||
val mapperEnd = MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap())
|
||||
val mapperEnd =
|
||||
MapperEnd(1, 1, 1, 2, 1, Collections.emptyMap(), 1, Array.emptyIntArray, Array.emptyLongArray)
|
||||
val mapperEndTrans =
|
||||
Utils.fromTransportMessage(Utils.toTransportMessage(mapperEnd)).asInstanceOf[MapperEnd]
|
||||
assert(mapperEnd == mapperEndTrans)
|
||||
assert(mapperEnd.shuffleId == mapperEndTrans.shuffleId)
|
||||
assert(mapperEnd.mapId == mapperEndTrans.mapId)
|
||||
assert(mapperEnd.attemptId == mapperEndTrans.attemptId)
|
||||
assert(mapperEnd.numMappers == mapperEndTrans.numMappers)
|
||||
assert(mapperEnd.partitionId == mapperEndTrans.partitionId)
|
||||
assert(mapperEnd.failedBatchSet == mapperEndTrans.failedBatchSet)
|
||||
assert(mapperEnd.numPartitions == mapperEndTrans.numPartitions)
|
||||
mapperEnd.crc32PerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.crc32PerPartition
|
||||
mapperEnd.bytesWrittenPerPartition.array should contain theSameElementsInOrderAs mapperEndTrans.bytesWrittenPerPartition
|
||||
}
|
||||
|
||||
test("validate HDFS compatible fs path") {
|
||||
|
||||
@ -105,6 +105,7 @@ license: |
|
||||
| celeborn.client.shuffle.dynamicResourceEnabled | false | false | When enabled, the ChangePartitionManager will obtain candidate workers from the availableWorkers pool during heartbeats when worker resource change. | 0.6.0 | |
|
||||
| celeborn.client.shuffle.dynamicResourceFactor | 0.5 | false | The ChangePartitionManager will check whether (unavailable workers / shuffle allocated workers) is more than the factor before obtaining candidate workers from the requestSlots RPC response when `celeborn.client.shuffle.dynamicResourceEnabled` set true | 0.6.0 | |
|
||||
| celeborn.client.shuffle.expired.checkInterval | 60s | false | Interval for client to check expired shuffles. | 0.3.0 | celeborn.shuffle.expired.checkInterval |
|
||||
| celeborn.client.shuffle.integrityCheck.enabled | false | false | When `true`, enables end-to-end integrity checks for Spark workloads. | 0.6.1 | |
|
||||
| celeborn.client.shuffle.manager.port | 0 | false | Port used by the LifecycleManager on the Driver. | 0.3.0 | celeborn.shuffle.manager.port |
|
||||
| celeborn.client.shuffle.partition.type | REDUCE | false | Type of shuffle's partition. | 0.3.0 | celeborn.shuffle.partition.type |
|
||||
| celeborn.client.shuffle.partitionSplit.mode | SOFT | false | soft: the shuffle file size might be larger than split threshold. hard: the shuffle file size will be limited to split threshold. | 0.3.0 | celeborn.shuffle.partitionSplit.mode |
|
||||
|
||||
@ -70,8 +70,9 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
|
||||
res.workerResource,
|
||||
updateEpoch = false)
|
||||
|
||||
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
|
||||
0 until 10 foreach { partitionId =>
|
||||
val numPartitions = 10
|
||||
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions)
|
||||
0 until numPartitions foreach { partitionId =>
|
||||
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
|
||||
}
|
||||
|
||||
@ -126,8 +127,9 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
|
||||
res.workerResource,
|
||||
updateEpoch = false)
|
||||
|
||||
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
|
||||
0 until 10 foreach { partitionId =>
|
||||
val numPartitions = 10
|
||||
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions)
|
||||
0 until numPartitions foreach { partitionId =>
|
||||
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
|
||||
}
|
||||
|
||||
@ -196,8 +198,9 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
|
||||
res.workerResource,
|
||||
updateEpoch = false)
|
||||
|
||||
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
|
||||
0 until 1000 foreach { partitionId =>
|
||||
val numPartitions = 1000
|
||||
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions)
|
||||
0 until numPartitions foreach { partitionId =>
|
||||
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
|
||||
}
|
||||
|
||||
@ -256,13 +259,14 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
|
||||
res.workerResource,
|
||||
updateEpoch = false)
|
||||
|
||||
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
|
||||
val numPartitions = 3
|
||||
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false, numPartitions)
|
||||
|
||||
val buffer = "hello world".getBytes(StandardCharsets.UTF_8)
|
||||
|
||||
var bufferLength = -1
|
||||
|
||||
0 until 3 foreach { partitionId =>
|
||||
0 until numPartitions foreach { partitionId =>
|
||||
bufferLength =
|
||||
shuffleClient.pushData(shuffleId, 0, 0, partitionId, buffer, 0, buffer.length, 1, 3)
|
||||
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
|
||||
|
||||
@ -157,7 +157,7 @@ class LifecycleManagerReserveSlotsSuite extends AnyFunSuite
|
||||
|
||||
// push merged data, we expect that partition(0) will be split, while partition(1) will not be split
|
||||
shuffleClient1.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID)
|
||||
shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM)
|
||||
shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM)
|
||||
// partition(1) will not be split
|
||||
assert(partitionLocationMap1.get(partitions(1)).getEpoch == 0)
|
||||
|
||||
|
||||
@ -0,0 +1,158 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.tests.spark
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.shuffle.celeborn.{SparkUtils, TestCelebornShuffleManager}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.slf4j.{Logger, LoggerFactory}
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.protocol.{CompressionCodec, ShuffleMode}
|
||||
|
||||
/**
 * End-to-end integration tests for the shuffle integrity-check feature: a hook
 * flips a single bit in one shuffle data file on disk, then the job outcome is
 * checked with the integrity check disabled/enabled and with/without stage rerun.
 */
class CelebornIntegrityCheckSuite extends AnyFunSuite
  with SparkTestBase
  with BeforeAndAfterEach {

  private val logger = LoggerFactory.getLogger(classOf[CelebornIntegrityCheckSuite])

  override def beforeEach(): Unit = {
    // Each test builds its own SparkSession; reset the shared shuffle client first.
    ShuffleClient.reset()
  }

  override def afterEach(): Unit = {
    System.gc()
  }

  test("celeborn spark integration test - corrupted data, no integrity check - app succeeds but data is different") {
    if (Spark3OrNewer) {
      val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
      val sparkSession = SparkSession.builder()
        .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
        .config("spark.sql.shuffle.partitions", 2)
        .config(
          "spark.shuffle.manager",
          "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
        .config(
          s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}",
          CompressionCodec.NONE.toString)
        .getOrCreate()

      // Introduce data corruption: a single bit flip in one partition location file.
      val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
      val hook = new ShuffleReaderGetHookForCorruptedData(celebornConf, workerDirs)
      TestCelebornShuffleManager.registerReaderGetHook(hook)

      val value = Range(1, 10000).mkString(",")
      val tuples = sparkSession.sparkContext.parallelize(1 to 1000, 2)
        .map { i => (i, value) }.groupByKey(16).collect()

      // Without integrity checks the job still completes with the full row count...
      assert(tuples.length == 1000)

      try {
        for (elem <- tuples) {
          assert(elem._2.mkString(",").equals(value))
        }
      } catch {
        case e: Throwable =>
          // ...but the corrupted partition surfaces as a value mismatch. The
          // previous code evaluated this `contains` and discarded the boolean,
          // silently swallowing every throwable — assert it instead so only the
          // expected mismatch is tolerated.
          assert(e.getMessage.contains("elem._2.mkString(\",\").equals(value) was false"))
      } finally {
        sparkSession.stop()
      }
    }
  }

  test("celeborn spark integration test - corrupted data, integrity checks enabled, no stage rerun - app fails") {
    if (Spark3OrNewer) {
      val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
      val sparkSession = SparkSession.builder()
        .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
        .config("spark.sql.shuffle.partitions", 2)
        .config(
          s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}",
          CompressionCodec.NONE.toString)
        .config(s"spark.${CelebornConf.CLIENT_STAGE_RERUN_ENABLED.key}", "false")
        .config("spark.celeborn.client.shuffle.integrityCheck.enabled", "true")
        .config(
          "spark.shuffle.manager",
          "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
        .getOrCreate()

      // Introduce data corruption: a single bit flip in one partition location file.
      val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
      val hook = new ShuffleReaderGetHookForCorruptedData(celebornConf, workerDirs)
      TestCelebornShuffleManager.registerReaderGetHook(hook)

      try {
        val value = Range(1, 10000).mkString(",")
        val tuples = sparkSession.sparkContext.parallelize(1 to 1000, 2)
          .map { i => (i, value) }.groupByKey(16).collect()
        // Note: this `fail` is itself caught below; the message asserts will then
        // report the unexpected success via a non-matching message.
        fail("App should abort prior to this step and throw an exception")
      } catch {
        // With checks on and no rerun, the job must abort with a metadata mismatch.
        case e: Throwable => {
          logger.error("Expected exception, logging the full exception", e)
          assert(e.getMessage.contains("Job aborted"))
          assert(e.getMessage.contains("CommitMetadata mismatch"))
        }
      } finally {
        sparkSession.stop()
      }
    }
  }

  test("celeborn spark integration test - corrupted data, integrity checks enabled, stage rerun enabled - app succeeds") {
    if (Spark3OrNewer) {
      val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
      val sparkSession = SparkSession.builder()
        .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
        .config("spark.sql.shuffle.partitions", 2)
        .config(
          s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}",
          CompressionCodec.NONE.toString)
        .config(s"spark.${CelebornConf.CLIENT_STAGE_RERUN_ENABLED.key}", "true")
        .config("spark.celeborn.client.shuffle.integrityCheck.enabled", "true")
        .config(
          "spark.shuffle.manager",
          "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
        .getOrCreate()

      // Introduce data corruption: a single bit flip in one partition location file.
      val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
      val hook = new ShuffleReaderGetHookForCorruptedData(celebornConf, workerDirs)
      TestCelebornShuffleManager.registerReaderGetHook(hook)

      try {
        val value = Range(1, 10000).mkString(",")
        val tuples = sparkSession.sparkContext.parallelize(1 to 1000, 2)
          .map { i => (i, value) }.groupByKey(16).collect()

        // With stage rerun enabled the corrupted stage is retried, so the job
        // must succeed with fully correct data.
        assert(tuples.length == 1000)
        for (elem <- tuples) {
          assert(elem._2.mkString(",").equals(value))
        }
      } finally {
        sparkSession.stop()
      }
    }
  }
}
|
||||
@ -0,0 +1,252 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.tests.spark
|
||||
|
||||
import java.io.{File, RandomAccessFile}
|
||||
import java.util.Random
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
import scala.util.control.Breaks.{break, breakable}
|
||||
|
||||
import org.apache.spark.TaskContext
|
||||
import org.apache.spark.shuffle.ShuffleHandle
|
||||
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkCommonUtils, SparkUtils}
|
||||
import org.slf4j.LoggerFactory
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.unsafe.Platform
|
||||
|
||||
class ShuffleReaderGetHookForCorruptedData(
|
||||
conf: CelebornConf,
|
||||
workerDirs: Seq[String],
|
||||
shuffleIdToBeModified: Seq[Int] = Seq(),
|
||||
triggerStageId: Option[Int] = None)
|
||||
extends ShuffleManagerHook {
|
||||
|
||||
private val logger = LoggerFactory.getLogger(classOf[ShuffleReaderGetHookForCorruptedData])
|
||||
var executed: AtomicBoolean = new AtomicBoolean(false)
|
||||
val lock = new Object
|
||||
var corruptedCount = 0
|
||||
val random = new Random()
|
||||
|
||||
private def modifyDataFileWithSingleBitFlip(appUniqueId: String, celebornShuffleId: Int): Unit = {
|
||||
if (corruptedCount > 0) {
|
||||
return
|
||||
}
|
||||
|
||||
val minFileSize = 16 // Minimum file size to be considered for corruption
|
||||
|
||||
// Find all potential data files across all worker directories
|
||||
val dataFiles =
|
||||
workerDirs.flatMap(dir => {
|
||||
val shuffleDir =
|
||||
new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
|
||||
if (shuffleDir.exists()) {
|
||||
shuffleDir.listFiles().filter(_.isFile).filter(_.length() >= minFileSize)
|
||||
} else {
|
||||
Array.empty[File]
|
||||
}
|
||||
})
|
||||
|
||||
if (dataFiles.isEmpty) {
|
||||
logger.error(
|
||||
s"No suitable data files found for appUniqueId=$appUniqueId, shuffleId=$celebornShuffleId")
|
||||
return
|
||||
}
|
||||
|
||||
// Sort files by size (descending) to prioritize larger files that are more likely to have data
|
||||
val sortedFiles = dataFiles.sortBy(-_.length())
|
||||
logger.info(s"Found ${sortedFiles.length} data files for corruption testing.")
|
||||
|
||||
// Try to corrupt files one by one until successful
|
||||
breakable {
|
||||
for (file <- sortedFiles) {
|
||||
if (tryCorruptFile(file)) {
|
||||
corruptedCount = 1
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we couldn't corrupt any file through the normal process, use the fallback method
|
||||
if (corruptedCount == 0 && sortedFiles.nonEmpty) {
|
||||
logger.warn(
|
||||
"Could not find a valid data section in any file. Using safer fallback corruption method.")
|
||||
val fileToCorrupt = sortedFiles.head // Take the largest file for fallback
|
||||
if (fallbackCorruption(fileToCorrupt)) {
|
||||
corruptedCount = 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to corrupt a specific file by finding and corrupting a valid data section.
|
||||
* @return true if corruption was successful, false otherwise
|
||||
*/
|
||||
private def tryCorruptFile(file: File): Boolean = {
  val rng = new Random()
  logger.info(s"Attempting to corrupt file: ${file.getPath()}, size: ${file.length()} bytes")
  val raf = new RandomAccessFile(file, "rw")
  try {
    val totalBytes = raf.length()
    // Each record starts with a 16-byte header: mapId, attemptId, batchId, dataSize.
    val headerSize = 16
    var recordStart: Long = 0

    while (recordStart + headerSize <= totalBytes) {
      raf.seek(recordStart)
      val header = new Array[Byte](headerSize)
      raf.readFully(header)

      // Decode the four int fields of the record header.
      val mapId: Int = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET)
      val attemptId: Int = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET + 4)
      val batchId: Int = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET + 8)
      val dataSize: Int = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET + 12)

      logger.info(s"Found header at position $recordStart: mapId=$mapId, attemptId=$attemptId, " +
        s"batchId=$batchId, dataSize=$dataSize")

      // Only corrupt a record whose payload is positive and lies fully inside the file.
      if (dataSize > 0 && recordStart + headerSize + dataSize <= totalBytes) {
        val payloadStart = recordStart + headerSize

        // Pick a random byte inside the payload to corrupt.
        val targetPos = payloadStart + rng.nextInt(dataSize)

        raf.seek(targetPos)
        val before = raf.readByte()

        // Flip one of bits 0-6 only, leaving the sign bit untouched.
        val mask = 1 << rng.nextInt(7)
        val after = (before ^ mask).toByte

        raf.seek(targetPos)
        raf.writeByte(after)

        logger.info(s"Successfully corrupted byte in file ${file.getName()} at position " +
          s"$targetPos: ${before & 0xFF} -> ${after & 0xFF} (flipped bit $mask)")

        return true // Corruption successful
      }

      // Advance past this record: header plus payload (payload may be empty).
      recordStart += headerSize + Math.max(0, dataSize)
    }

    false // No valid data section found
  } catch {
    case e: Exception =>
      logger.error(s"Error while attempting to corrupt file ${file.getPath()}", e)
      false
  } finally {
    raf.close()
  }
}
|
||||
|
||||
/**
 * Fallback corruption path used when no valid data section can be located.
 * Flips a single low-order bit of one byte chosen from the middle region of
 * the file, staying clear of the leading headers and the trailing bytes.
 */
private def fallbackCorruption(file: File): Boolean = {
  val raf = new RandomAccessFile(file, "rw")
  try {
    val totalBytes = raf.length()

    // Stay clear of the first bytes (headers) and the last 16 bytes.
    val lowerBound = Math.min(64, totalBytes / 4)
    val upperBound = Math.max(lowerBound + 1, totalBytes - 16)

    // Bail out when the file is too small to offer a safe target range.
    if (upperBound <= lowerBound) {
      logger.error(s"File ${file.getName()} too small for safe corruption: $totalBytes bytes")
      return false
    }

    // Note: `random` is the enclosing class's Random instance.
    val targetPos = lowerBound + random.nextInt((upperBound - lowerBound).toInt)

    raf.seek(targetPos)
    val before = raf.readByte()

    // Bits 0-6 only, so the sign bit is left alone.
    val mask = 1 << random.nextInt(7)
    val after = (before ^ mask).toByte

    raf.seek(targetPos)
    raf.writeByte(after)

    logger.info(s"Used fallback corruption approach on file ${file.getName()}. " +
      s"Corrupted byte at position $targetPos: ${before & 0xFF} -> ${after & 0xFF}")

    true
  } catch {
    case e: Exception =>
      logger.error(s"Error during fallback corruption of file ${file.getPath()}", e)
      false
  } finally {
    raf.close()
  }
}
|
||||
|
||||
/**
 * Corrupts shuffle data files exactly once for the matching stage.
 *
 * Resolves the Celeborn shuffle id for this task and flips a single bit in
 * the data file(s) of either the current shuffle or the configured
 * `shuffleIdToBeModified` set, guarded so the corruption runs at most once.
 *
 * @throws RuntimeException if the handle is not a CelebornShuffleHandle
 */
override def exec(
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext): Unit = {
  // Fast path: skip lock acquisition once the corruption has already run.
  if (executed.get()) {
    return
  }
  lock.synchronized {
    // Re-check under the lock: another thread may have completed the
    // corruption between the unguarded check above and acquiring the lock;
    // without this, a second pass would corrupt an additional byte.
    if (executed.get()) {
      return
    }
    handle match {
      case h: CelebornShuffleHandle[_, _, _] => {
        val appUniqueId = h.appUniqueId
        // The Celeborn shuffle id can differ from the Spark shuffle id, so
        // resolve it through the shuffle client.
        val shuffleClient = ShuffleClient.get(
          h.appUniqueId,
          h.lifecycleManagerHost,
          h.lifecycleManagerPort,
          conf,
          h.userIdentifier,
          h.extension)
        val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
        val appShuffleIdentifier =
          SparkCommonUtils.encodeAppShuffleIdentifier(handle.shuffleId, context)
        // Identifier encodes "<appShuffleId>-<stageId>-<attempt>".
        val Array(_, stageId, _) = appShuffleIdentifier.split('-')
        // Corrupt only when no trigger stage is configured, or we are in it.
        if (triggerStageId.isEmpty || triggerStageId.get == stageId.toInt) {
          if (shuffleIdToBeModified.isEmpty) {
            modifyDataFileWithSingleBitFlip(appUniqueId, celebornShuffleId)
          } else {
            shuffleIdToBeModified.foreach { shuffleId =>
              modifyDataFileWithSingleBitFlip(appUniqueId, shuffleId)
            }
          }
          executed.set(true)
        }
      }
      // Fixed message: the supported handle type is CelebornShuffleHandle,
      // not the stale "RssShuffleHandle" name.
      case x => throw new RuntimeException(s"unexpected, only support CelebornShuffleHandle here," +
          s" but get ${x.getClass.getCanonicalName}")
    }
  }
}
|
||||
@ -49,70 +49,74 @@ class SkewJoinSuite extends AnyFunSuite
|
||||
|
||||
CompressionCodec.values.foreach { codec =>
|
||||
Seq(false, true).foreach { enabled =>
|
||||
test(
|
||||
s"celeborn spark integration test - skew join - with $codec - with client skew $enabled") {
|
||||
val sparkConf = new SparkConf().setAppName("celeborn-demo")
|
||||
.setMaster("local[2]")
|
||||
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
|
||||
.set("spark.sql.adaptive.skewJoin.enabled", "true")
|
||||
.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB")
|
||||
.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB")
|
||||
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||
.set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1")
|
||||
.set(SQLConf.PARQUET_COMPRESSION.key, "gzip")
|
||||
.set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name)
|
||||
.set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true")
|
||||
.set(
|
||||
s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}",
|
||||
s"$enabled")
|
||||
Seq(false, true).foreach { integrityChecksEnabled =>
|
||||
test(
|
||||
s"celeborn spark integration test - skew join - with $codec - with client skew $enabled - with integrity checks $integrityChecksEnabled") {
|
||||
val sparkConf = new SparkConf().setAppName("celeborn-demo")
|
||||
.setMaster("local[2]")
|
||||
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
|
||||
.set("spark.sql.adaptive.skewJoin.enabled", "true")
|
||||
.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "16MB")
|
||||
.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "12MB")
|
||||
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||
.set("spark.sql.adaptive.autoBroadcastJoinThreshold", "-1")
|
||||
.set(SQLConf.PARQUET_COMPRESSION.key, "gzip")
|
||||
.set(s"spark.${CelebornConf.SHUFFLE_COMPRESSION_CODEC.key}", codec.name)
|
||||
.set(s"spark.${CelebornConf.SHUFFLE_RANGE_READ_FILTER_ENABLED.key}", "true")
|
||||
.set(
|
||||
s"spark.${CelebornConf.CLIENT_ADAPTIVE_OPTIMIZE_SKEWED_PARTITION_READ_ENABLED.key}",
|
||||
s"$enabled")
|
||||
.set(
|
||||
s"spark.${CelebornConf.CLIENT_SHUFFLE_INTEGRITY_CHECK_ENABLED.key}",
|
||||
s"$integrityChecksEnabled")
|
||||
|
||||
enableCeleborn(sparkConf)
|
||||
enableCeleborn(sparkConf)
|
||||
|
||||
val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
|
||||
if (sparkSession.version.startsWith("3")) {
|
||||
import sparkSession.implicits._
|
||||
val df = sparkSession.sparkContext.parallelize(1 to 120000, 8)
|
||||
.map(i => {
|
||||
val random = new Random()
|
||||
val oriKey = random.nextInt(64)
|
||||
val key = if (oriKey < 32) 1 else oriKey
|
||||
val fas = random.nextInt(1200000)
|
||||
val fa = Range(fas, fas + 100).mkString(",")
|
||||
val fbs = random.nextInt(1200000)
|
||||
val fb = Range(fbs, fbs + 100).mkString(",")
|
||||
val fcs = random.nextInt(1200000)
|
||||
val fc = Range(fcs, fcs + 100).mkString(",")
|
||||
val fds = random.nextInt(1200000)
|
||||
val fd = Range(fds, fds + 100).mkString(",")
|
||||
val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
|
||||
if (sparkSession.version.startsWith("3")) {
|
||||
import sparkSession.implicits._
|
||||
val df = sparkSession.sparkContext.parallelize(1 to 120000, 8)
|
||||
.map(i => {
|
||||
val random = new Random()
|
||||
val oriKey = random.nextInt(64)
|
||||
val key = if (oriKey < 32) 1 else oriKey
|
||||
val fas = random.nextInt(1200000)
|
||||
val fa = Range(fas, fas + 100).mkString(",")
|
||||
val fbs = random.nextInt(1200000)
|
||||
val fb = Range(fbs, fbs + 100).mkString(",")
|
||||
val fcs = random.nextInt(1200000)
|
||||
val fc = Range(fcs, fcs + 100).mkString(",")
|
||||
val fds = random.nextInt(1200000)
|
||||
val fd = Range(fds, fds + 100).mkString(",")
|
||||
|
||||
(key, fa, fb, fc, fd)
|
||||
})
|
||||
.toDF("fa", "f1", "f2", "f3", "f4")
|
||||
df.createOrReplaceTempView("view1")
|
||||
val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8)
|
||||
.map(i => {
|
||||
val random = new Random()
|
||||
val oriKey = random.nextInt(64)
|
||||
val key = if (oriKey < 32) 1 else oriKey
|
||||
val fas = random.nextInt(1200000)
|
||||
val fa = Range(fas, fas + 100).mkString(",")
|
||||
val fbs = random.nextInt(1200000)
|
||||
val fb = Range(fbs, fbs + 100).mkString(",")
|
||||
val fcs = random.nextInt(1200000)
|
||||
val fc = Range(fcs, fcs + 100).mkString(",")
|
||||
val fds = random.nextInt(1200000)
|
||||
val fd = Range(fds, fds + 100).mkString(",")
|
||||
(key, fa, fb, fc, fd)
|
||||
})
|
||||
.toDF("fb", "f6", "f7", "f8", "f9")
|
||||
df2.createOrReplaceTempView("view2")
|
||||
sparkSession.sql("drop table if exists fres")
|
||||
sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ")
|
||||
sparkSession.sql("drop table fres")
|
||||
sparkSession.stop()
|
||||
(key, fa, fb, fc, fd)
|
||||
})
|
||||
.toDF("fa", "f1", "f2", "f3", "f4")
|
||||
df.createOrReplaceTempView("view1")
|
||||
val df2 = sparkSession.sparkContext.parallelize(1 to 8, 8)
|
||||
.map(i => {
|
||||
val random = new Random()
|
||||
val oriKey = random.nextInt(64)
|
||||
val key = if (oriKey < 32) 1 else oriKey
|
||||
val fas = random.nextInt(1200000)
|
||||
val fa = Range(fas, fas + 100).mkString(",")
|
||||
val fbs = random.nextInt(1200000)
|
||||
val fb = Range(fbs, fbs + 100).mkString(",")
|
||||
val fcs = random.nextInt(1200000)
|
||||
val fc = Range(fcs, fcs + 100).mkString(",")
|
||||
val fds = random.nextInt(1200000)
|
||||
val fd = Range(fds, fds + 100).mkString(",")
|
||||
(key, fa, fb, fc, fd)
|
||||
})
|
||||
.toDF("fb", "f6", "f7", "f8", "f9")
|
||||
df2.createOrReplaceTempView("view2")
|
||||
sparkSession.sql("drop table if exists fres")
|
||||
sparkSession.sql("create table fres using parquet as select * from view1 a inner join view2 b on a.fa=b.fb where a.fa=1 ")
|
||||
sparkSession.sql("drop table fres")
|
||||
sparkSession.stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -109,7 +109,7 @@ trait JavaReadCppWriteTestBase extends AnyFunSuite
|
||||
}
|
||||
}
|
||||
shuffleClient.pushMergedData(shuffleId, mapId, attemptId)
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, attemptId, numMappers)
|
||||
shuffleClient.mapperEnd(shuffleId, mapId, attemptId, numMappers, numPartitions)
|
||||
}
|
||||
|
||||
// Launch cpp reader to read data, calculate result and write to specific result file.
|
||||
|
||||
@ -116,7 +116,7 @@ class LocalReadByChunkOffsetsTest extends AnyFunSuite
|
||||
shuffleClient.pushMergedData(1, 0, 0)
|
||||
Thread.sleep(1000)
|
||||
|
||||
shuffleClient.mapperEnd(1, 0, 0, 1)
|
||||
shuffleClient.mapperEnd(1, 0, 0, 1, 0)
|
||||
|
||||
val metricsCallback = new MetricsCallback {
|
||||
override def incBytesRead(bytesWritten: Long): Unit = {}
|
||||
|
||||
@ -153,7 +153,7 @@ class PushMergedDataSplitSuite extends AnyFunSuite
|
||||
|
||||
// push merged data, we expect that partition(0) will be split, while partition(1) will not be split
|
||||
shuffleClient.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID)
|
||||
shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM)
|
||||
shuffleClient.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM, PARTITION_NUM)
|
||||
assert(
|
||||
partitionLocationMap.get(partitions(1)).getEpoch == 0
|
||||
) // means partition(1) will not be split
|
||||
|
||||
@ -98,7 +98,7 @@ trait ReadWriteTestBase extends AnyFunSuite
|
||||
shuffleClient.pushMergedData(1, 0, 0)
|
||||
Thread.sleep(1000)
|
||||
|
||||
shuffleClient.mapperEnd(1, 0, 0, 1)
|
||||
shuffleClient.mapperEnd(1, 0, 0, 1, 1)
|
||||
|
||||
val metricsCallback = new MetricsCallback {
|
||||
override def incBytesRead(bytesWritten: Long): Unit = {}
|
||||
|
||||
@ -92,7 +92,7 @@ class ReadWriteTestWithFailures extends AnyFunSuite
|
||||
shuffleClient.pushMergedData(1, 0, 0)
|
||||
Thread.sleep(1000)
|
||||
|
||||
shuffleClient.mapperEnd(1, 0, 0, 1)
|
||||
shuffleClient.mapperEnd(1, 0, 0, 1, 1)
|
||||
|
||||
var duplicateBytesRead = new AtomicLong(0)
|
||||
val metricsCallback = new MetricsCallback {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user