[CELEBORN-1921] Broadcast large GetReducerFileGroupResponse to prevent Spark driver network exhausted

### What changes were proposed in this pull request?

For a Spark Celeborn application, if the GetReducerFileGroupResponse is larger than a configured threshold, the Spark driver broadcasts the GetReducerFileGroupResponse to the executors. This prevents the driver from becoming the bottleneck of sending out multiple copies of the GetReducerFileGroupResponse (one per executor).

### Why are the changes needed?
To prevent the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor).

### Does this PR introduce _any_ user-facing change?
No, the feature is not enabled by default.

### How was this patch tested?

UT.

Cluster testing with `spark.celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled=true`.

The broadcast response size should always be about 1 KB.
![image](https://github.com/user-attachments/assets/d5d1b751-762d-43c8-8a84-0674630a5638)
![image](https://github.com/user-attachments/assets/4841a29e-5d11-4932-9fa5-f6e78b7bc521)
The application succeeded.
![image](https://github.com/user-attachments/assets/9b570f70-1433-4457-90ae-b8292e5476ba)

Closes #3158 from turboFei/broadcast_rgf.

Authored-by: Wang, Fei <fwang12@ebay.com>
Signed-off-by: Wang, Fei <fwang12@ebay.com>
This commit is contained in:
Wang, Fei 2025-04-01 08:29:21 -07:00
parent 1e30f159b9
commit 5e12b7d607
23 changed files with 690 additions and 21 deletions

View File

@ -223,6 +223,7 @@ Apache Spark
./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java
./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
./common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DBIterator.java
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/LevelDB.java

View File

@ -73,6 +73,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
<include>commons-io:commons-io</include>
</includes>
</artifactSet>
<filters>

View File

@ -109,6 +109,15 @@ public class SparkShuffleManager implements ShuffleManager {
lifecycleManager.registerShuffleTrackerCallback(
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
}
if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
(shuffleId, getReducerFileGroupResponse) ->
SparkUtils.serializeGetReducerFileGroupResponse(
shuffleId, getReducerFileGroupResponse));
lifecycleManager.registerInvalidatedBroadcastCallback(
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
}
}
}
}

View File

@ -17,7 +17,10 @@
package org.apache.spark.shuffle.celeborn;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashSet;
@ -25,6 +28,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
@ -37,7 +41,12 @@ import org.apache.spark.BarrierTaskContext;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkContext$;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkEnv$;
import org.apache.spark.TaskContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
@ -54,7 +63,10 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.KeyLock;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynFields;
@ -346,4 +358,121 @@ public class SparkUtils {
sparkContext.addSparkListener(listener);
}
}
/**
* A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the
* broadcast belonging to the shuffle id at a time.
*/
private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock<>();
@VisibleForTesting
public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
@VisibleForTesting
public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>>
getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
/**
 * Broadcasts a {@link GetReducerFileGroupResponse} through Spark and returns the compressed,
 * Java-serialized broadcast handle. Executors deserialize this small handle and fetch the
 * actual (potentially large) payload via Spark's broadcast mechanism, instead of pulling the
 * full response from the driver directly.
 *
 * <p>The result is cached per shuffle id; repeated calls for the same shuffle return the
 * cached bytes without creating a new broadcast.
 *
 * @param shuffleId shuffle the response belongs to
 * @param response the response to broadcast
 * @return the serialized broadcast handle, or {@code null} when there is no active
 *     SparkContext or serialization fails
 */
public static byte[] serializeGetReducerFileGroupResponse(
Integer shuffleId, GetReducerFileGroupResponse response) {
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
if (sparkContext == null) {
logger.error("Can not get active SparkContext.");
return null;
}
// Per-shuffle lock: only one thread may create/read the broadcast for a shuffle id at a time.
return shuffleBroadcastLock.withLock(
shuffleId,
() -> {
// Reuse a previously created broadcast for this shuffle, if any.
Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
getReducerFileGroupResponseBroadcasts.get(shuffleId);
if (cachedSerializeGetReducerFileGroupResponse != null) {
return cachedSerializeGetReducerFileGroupResponse._2;
}
try {
logger.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId);
// Wrap the response as a TransportMessage so it can be broadcast as-is.
TransportMessage transportMessage =
(TransportMessage) Utils.toTransportMessage(response);
Broadcast<TransportMessage> broadcast =
sparkContext.broadcast(
transportMessage,
scala.reflect.ClassManifestFactory.fromClass(TransportMessage.class));
CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
// one
// This implementation doesn't reallocate the whole memory block but allocates
// additional buffers. This way no buffers need to be garbage collected and
// the contents don't have to be copied to the new buffer.
org.apache.commons.io.output.ByteArrayOutputStream out =
new org.apache.commons.io.output.ByteArrayOutputStream();
// Only the broadcast *handle* is serialized here (small); the payload itself is
// distributed by Spark's broadcast machinery when executors call broadcast.value().
try (ObjectOutputStream oos =
new ObjectOutputStream(codec.compressedOutputStream(out))) {
oos.writeObject(broadcast);
}
byte[] _serializeResult = out.toByteArray();
// Cache both the broadcast (needed later for destroy()) and the serialized bytes.
getReducerFileGroupResponseBroadcasts.put(
shuffleId, Tuple2.apply(broadcast, _serializeResult));
getReducerFileGroupResponseBroadcastNum.incrementAndGet();
return _serializeResult;
} catch (Throwable e) {
logger.error(
"Failed to serialize GetReducerFileGroupResponse for shuffle: {}", shuffleId, e);
return null;
}
});
}
/**
 * Deserializes the broadcast handle produced by {@code serializeGetReducerFileGroupResponse}
 * and fetches the broadcast {@link GetReducerFileGroupResponse} value via Spark.
 *
 * @param shuffleId shuffle the broadcast belongs to
 * @param bytes the compressed, Java-serialized broadcast handle
 * @return the deserialized response, or {@code null} when SparkEnv is unavailable or
 *     deserialization fails
 */
@SuppressWarnings("unchecked")
public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
    Integer shuffleId, byte[] bytes) {
  SparkEnv sparkEnv = SparkEnv$.MODULE$.get();
  if (sparkEnv == null) {
    logger.error("Can not get SparkEnv.");
    return null;
  }
  // Per-shuffle lock: only one thread accesses the broadcast for a shuffle id at a time.
  return shuffleBroadcastLock.withLock(
      shuffleId,
      () -> {
        GetReducerFileGroupResponse response = null;
        logger.info(
            "Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
        try {
          CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkEnv.conf());
          try (ObjectInputStream objIn =
              new ObjectInputStream(
                  codec.compressedInputStream(new ByteArrayInputStream(bytes)))) {
            Broadcast<TransportMessage> broadcast =
                (Broadcast<TransportMessage>) objIn.readObject();
            // broadcast.value() pulls the actual payload through Spark's broadcast mechanism.
            response =
                (GetReducerFileGroupResponse) Utils.fromTransportMessage(broadcast.value());
          }
        } catch (Throwable e) {
          // Parameterized logging, consistent with the other log statements in this class.
          logger.error(
              "Failed to deserialize GetReducerFileGroupResponse for shuffle: {}", shuffleId, e);
        }
        return response;
      });
}
/**
 * Removes the cached broadcast for the given shuffle id and destroys the underlying Spark
 * broadcast variable. Cleanup is best-effort: failures are logged and swallowed so shuffle
 * removal is never blocked by broadcast cleanup.
 *
 * @param shuffleId shuffle whose cached broadcast should be invalidated
 */
public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
shuffleBroadcastLock.withLock(
shuffleId,
() -> {
try {
// Remove the cache entry first, then destroy the broadcast it held (if any).
Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
getReducerFileGroupResponseBroadcasts.remove(shuffleId);
if (cachedSerializeGetReducerFileGroupResponse != null) {
cachedSerializeGetReducerFileGroupResponse._1().destroy();
}
} catch (Throwable e) {
logger.error(
"Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
+ shuffleId,
e);
}
return null;
});
}
}

View File

@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
import java.io.IOException
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import java.util.function.BiFunction
import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
import org.apache.spark.internal.Logging
@ -33,6 +34,7 @@ import org.apache.celeborn.client.read.CelebornInputStream
import org.apache.celeborn.client.read.MetricsCallback
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
class CelebornShuffleReader[K, C](
@ -254,4 +256,13 @@ class CelebornShuffleReader[K, C](
object CelebornShuffleReader {
// Shared pool for stream creation; presumably initialized by the reader — TODO confirm.
var streamCreatorPool: ThreadPoolExecutor = null
// Register the deserializer for GetReducerFileGroupResponse broadcast.
// Runs once, when this companion object is first loaded, so executors can turn the
// serialized broadcast handle back into a GetReducerFileGroupResponse.
ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new BiFunction[
Integer,
Array[Byte],
GetReducerFileGroupResponse] {
override def apply(shuffleId: Integer, broadcast: Array[Byte]): GetReducerFileGroupResponse = {
SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
}
})
}

View File

@ -77,6 +77,7 @@
<include>io.netty:*</include>
<include>org.apache.commons:commons-lang3</include>
<include>org.roaringbitmap:RoaringBitmap</include>
<include>commons-io:commons-io</include>
</includes>
</artifactSet>
<filters>

View File

@ -156,6 +156,15 @@ public class SparkShuffleManager implements ShuffleManager {
SparkUtils::isCelebornSkewShuffleOrChildShuffle);
}
}
if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
(shuffleId, getReducerFileGroupResponse) ->
SparkUtils.serializeGetReducerFileGroupResponse(
shuffleId, getReducerFileGroupResponse));
lifecycleManager.registerInvalidatedBroadcastCallback(
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
}
}
}
}

View File

@ -17,11 +17,15 @@
package org.apache.spark.shuffle.celeborn;
import java.io.ByteArrayInputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
@ -35,7 +39,12 @@ import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkContext$;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkEnv$;
import org.apache.spark.TaskContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
@ -57,7 +66,11 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.KeyLock;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynConstructors;
import org.apache.celeborn.reflect.DynFields;
import org.apache.celeborn.reflect.DynMethods;
@ -476,4 +489,121 @@ public class SparkUtils {
return false;
}
}
/**
* A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the
* broadcast belonging to the shuffle id at a time.
*/
private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock<>();
@VisibleForTesting
public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
@VisibleForTesting
public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>>
getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
/**
 * Broadcasts a {@link GetReducerFileGroupResponse} through Spark and returns the compressed,
 * Java-serialized broadcast handle. Executors deserialize this small handle and fetch the
 * actual (potentially large) payload via Spark's broadcast mechanism, instead of pulling the
 * full response from the driver directly.
 *
 * <p>The result is cached per shuffle id; repeated calls for the same shuffle return the
 * cached bytes without creating a new broadcast.
 *
 * @param shuffleId shuffle the response belongs to
 * @param response the response to broadcast
 * @return the serialized broadcast handle, or {@code null} when there is no active
 *     SparkContext or serialization fails
 */
public static byte[] serializeGetReducerFileGroupResponse(
Integer shuffleId, GetReducerFileGroupResponse response) {
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
if (sparkContext == null) {
LOG.error("Can not get active SparkContext.");
return null;
}
// Per-shuffle lock: only one thread may create/read the broadcast for a shuffle id at a time.
return shuffleBroadcastLock.withLock(
shuffleId,
() -> {
// Reuse a previously created broadcast for this shuffle, if any.
Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
getReducerFileGroupResponseBroadcasts.get(shuffleId);
if (cachedSerializeGetReducerFileGroupResponse != null) {
return cachedSerializeGetReducerFileGroupResponse._2;
}
try {
LOG.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId);
// Wrap the response as a TransportMessage so it can be broadcast as-is.
TransportMessage transportMessage =
(TransportMessage) Utils.toTransportMessage(response);
Broadcast<TransportMessage> broadcast =
sparkContext.broadcast(
transportMessage,
scala.reflect.ClassManifestFactory.fromClass(TransportMessage.class));
CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
// one
// This implementation doesn't reallocate the whole memory block but allocates
// additional buffers. This way no buffers need to be garbage collected and
// the contents don't have to be copied to the new buffer.
org.apache.commons.io.output.ByteArrayOutputStream out =
new org.apache.commons.io.output.ByteArrayOutputStream();
// Only the broadcast *handle* is serialized here (small); the payload itself is
// distributed by Spark's broadcast machinery when executors call broadcast.value().
try (ObjectOutputStream oos =
new ObjectOutputStream(codec.compressedOutputStream(out))) {
oos.writeObject(broadcast);
}
byte[] _serializeResult = out.toByteArray();
// Cache both the broadcast (needed later for destroy()) and the serialized bytes.
getReducerFileGroupResponseBroadcasts.put(
shuffleId, Tuple2.apply(broadcast, _serializeResult));
getReducerFileGroupResponseBroadcastNum.incrementAndGet();
return _serializeResult;
} catch (Throwable e) {
LOG.error(
"Failed to serialize GetReducerFileGroupResponse for shuffle: {}", shuffleId, e);
return null;
}
});
}
/**
 * Deserializes the broadcast handle produced by {@code serializeGetReducerFileGroupResponse}
 * and fetches the broadcast {@link GetReducerFileGroupResponse} value via Spark.
 *
 * @param shuffleId shuffle the broadcast belongs to
 * @param bytes the compressed, Java-serialized broadcast handle
 * @return the deserialized response, or {@code null} when SparkEnv is unavailable or
 *     deserialization fails
 */
@SuppressWarnings("unchecked")
public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
    Integer shuffleId, byte[] bytes) {
  SparkEnv sparkEnv = SparkEnv$.MODULE$.get();
  if (sparkEnv == null) {
    LOG.error("Can not get SparkEnv.");
    return null;
  }
  // Per-shuffle lock: only one thread accesses the broadcast for a shuffle id at a time.
  return shuffleBroadcastLock.withLock(
      shuffleId,
      () -> {
        GetReducerFileGroupResponse response = null;
        LOG.info(
            "Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
        try {
          CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkEnv.conf());
          try (ObjectInputStream objIn =
              new ObjectInputStream(
                  codec.compressedInputStream(new ByteArrayInputStream(bytes)))) {
            Broadcast<TransportMessage> broadcast =
                (Broadcast<TransportMessage>) objIn.readObject();
            // broadcast.value() pulls the actual payload through Spark's broadcast mechanism.
            response =
                (GetReducerFileGroupResponse) Utils.fromTransportMessage(broadcast.value());
          }
        } catch (Throwable e) {
          // Parameterized logging, consistent with the other log statements in this class.
          LOG.error(
              "Failed to deserialize GetReducerFileGroupResponse for shuffle: {}", shuffleId, e);
        }
        return response;
      });
}
/**
 * Removes the cached broadcast for the given shuffle id and destroys the underlying Spark
 * broadcast variable. Cleanup is best-effort: failures are logged and swallowed so shuffle
 * removal is never blocked by broadcast cleanup.
 *
 * @param shuffleId shuffle whose cached broadcast should be invalidated
 */
public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
shuffleBroadcastLock.withLock(
shuffleId,
() -> {
try {
// Remove the cache entry first, then destroy the broadcast it held (if any).
Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
getReducerFileGroupResponseBroadcasts.remove(shuffleId);
if (cachedSerializeGetReducerFileGroupResponse != null) {
cachedSerializeGetReducerFileGroupResponse._1().destroy();
}
} catch (Throwable e) {
LOG.error(
"Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
+ shuffleId,
e);
}
return null;
});
}
}

View File

@ -21,6 +21,7 @@ import java.io.IOException
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet}
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import java.util.function.BiFunction
import scala.collection.JavaConverters._
@ -43,7 +44,8 @@ import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRet
import org.apache.celeborn.common.network.client.TransportClient
import org.apache.celeborn.common.network.protocol.TransportMessage
import org.apache.celeborn.common.protocol._
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils}
class CelebornShuffleReader[K, C](
@ -465,4 +467,13 @@ class CelebornShuffleReader[K, C](
object CelebornShuffleReader {
// Shared pool for stream creation; presumably initialized by the reader — TODO confirm.
var streamCreatorPool: ThreadPoolExecutor = null
// Register the deserializer for GetReducerFileGroupResponse broadcast.
// Runs once, when this companion object is first loaded, so executors can turn the
// serialized broadcast handle back into a GetReducerFileGroupResponse.
ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new BiFunction[
Integer,
Array[Byte],
GetReducerFileGroupResponse] {
override def apply(shuffleId: Integer, broadcast: Array[Byte]): GetReducerFileGroupResponse = {
SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
}
})
}

View File

@ -20,9 +20,11 @@ package org.apache.celeborn.client;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.BiFunction;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.fs.FileSystem;
@ -38,6 +40,7 @@ import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.common.protocol.StorageInfo;
import org.apache.celeborn.common.protocol.message.ControlMessages;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.util.CelebornHadoopUtils;
import org.apache.celeborn.common.util.ExceptionMaker;
@ -56,6 +59,10 @@ public abstract class ShuffleClient {
private static LongAdder totalReadCounter = new LongAdder();
private static LongAdder localShuffleReadCounter = new LongAdder();
private static volatile Optional<
BiFunction<Integer, byte[], ControlMessages.GetReducerFileGroupResponse>>
deserializeReducerFileGroupResponseFunction = Optional.empty();
// for testing
public static void reset() {
_instance = null;
@ -297,4 +304,21 @@ public abstract class ShuffleClient {
public abstract TransportClientFactory getDataClientFactory();
public abstract void excludeFailedFetchLocation(String hostAndFetchPort, Exception e);
/**
 * Registers the function used to deserialize a broadcast GetReducerFileGroupResponse.
 * Only the first registration takes effect; subsequent calls are ignored.
 *
 * <p>Synchronized: the original check-then-act on the volatile field allowed two concurrent
 * first registrations to race, with the later one overwriting the earlier.
 *
 * @param function deserializer taking (shuffleId, serialized broadcast bytes); may be null,
 *     in which case the slot stays empty
 */
public static synchronized void registerDeserializeReducerFileGroupResponseFunction(
    BiFunction<Integer, byte[], ControlMessages.GetReducerFileGroupResponse> function) {
  if (!deserializeReducerFileGroupResponseFunction.isPresent()) {
    deserializeReducerFileGroupResponseFunction = Optional.ofNullable(function);
  }
}
/**
 * Deserializes a broadcast GetReducerFileGroupResponse using the registered deserializer.
 *
 * @param shuffleId shuffle the broadcast belongs to
 * @param bytes the serialized broadcast handle
 * @return the deserialized response, or {@code null} when no deserializer is registered
 */
public static ControlMessages.GetReducerFileGroupResponse deserializeReducerFileGroupResponse(
    int shuffleId, byte[] bytes) {
  if (deserializeReducerFileGroupResponseFunction.isPresent()) {
    return deserializeReducerFileGroupResponseFunction.get().apply(shuffleId, bytes);
  }
  // Should never happen: the Spark plugin registers the function at class-load time.
  logger.warn("DeserializeReducerFileGroupResponseFunction is not registered.");
  return null;
}
}

View File

@ -1814,6 +1814,14 @@ public class ShuffleClientImpl extends ShuffleClient {
ClassTag$.MODULE$.apply(GetReducerFileGroupResponse.class));
switch (response.status()) {
case SUCCESS:
if (response.broadcast() != null && response.broadcast().length > 0) {
response =
ShuffleClient.deserializeReducerFileGroupResponse(shuffleId, response.broadcast());
if (response == null) {
throw new CelebornIOException(
"Failed to get GetReducerFileGroupResponse broadcast for shuffle: " + shuffleId);
}
}
logger.info(
"Shuffle {} request reducer file group success using {} ms, result partition size {}.",
shuffleId,

View File

@ -292,7 +292,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
lifecycleManager.shuffleAllocatedWorkers,
committedPartitionInfo,
lifecycleManager.workerStatusTracker,
lifecycleManager.rpcSharedThreadPool)
lifecycleManager.rpcSharedThreadPool,
lifecycleManager)
case PartitionType.MAP => new MapPartitionCommitHandler(
appUniqueId,
conf,

View File

@ -1677,6 +1677,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
} else {
batchRemoveShuffleIds += shuffleId
}
invalidatedBroadcastGetReducerFileGroupResponse(shuffleId)
}
}
if (batchRemoveShuffleIds.nonEmpty) {
@ -1848,6 +1849,22 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
cancelShuffleCallback = Some(callback)
}
// Optional callback installed by the Spark integration: broadcasts a large
// GetReducerFileGroupResponse and returns the serialized broadcast bytes (or null).
@volatile private var broadcastGetReducerFileGroupResponseCallback
: Option[java.util.function.BiFunction[Integer, GetReducerFileGroupResponse, Array[Byte]]] =
None
/** Registers the callback that broadcasts a GetReducerFileGroupResponse for a shuffle id. */
def registerBroadcastGetReducerFileGroupResponseCallback(call: java.util.function.BiFunction[
Integer,
GetReducerFileGroupResponse,
Array[Byte]]): Unit = {
broadcastGetReducerFileGroupResponseCallback = Some(call)
}
// Optional callback invoked when a shuffle is removed so the cached broadcast
// can be destroyed on the Spark side.
@volatile private var invalidatedBroadcastCallback: Option[Consumer[Integer]] =
None
/** Registers the callback that invalidates a shuffle's cached broadcast. */
def registerInvalidatedBroadcastCallback(call: Consumer[Integer]): Unit = {
invalidatedBroadcastCallback = Some(call)
}
def invalidateLatestMaxLocsCache(shuffleId: Int): Unit = {
registerShuffleResponseRpcCache.invalidate(shuffleId)
}
@ -1889,4 +1906,19 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
case _ =>
}
/**
 * Asks the registered callback (if any) to broadcast the response.
 * Returns the serialized broadcast bytes, or None when no callback is registered
 * or the callback returned null.
 */
def broadcastGetReducerFileGroupResponse(
    shuffleId: Int,
    response: GetReducerFileGroupResponse): Option[Array[Byte]] = {
  broadcastGetReducerFileGroupResponseCallback.flatMap { callback =>
    Option(callback.apply(shuffleId, response))
  }
}
/** Notifies the registered callback (if any) that the shuffle's broadcast is no longer valid. */
private def invalidatedBroadcastGetReducerFileGroupResponse(shuffleId: Int): Unit = {
  invalidatedBroadcastCallback.foreach(_.accept(shuffleId))
}
}

View File

@ -28,7 +28,7 @@ import scala.collection.mutable
import com.google.common.cache.{Cache, CacheBuilder}
import com.google.common.collect.Sets
import org.apache.celeborn.client.{ClientUtils, ShuffleCommittedInfo, WorkerStatusTracker}
import org.apache.celeborn.client.{ClientUtils, LifecycleManager, ShuffleCommittedInfo, WorkerStatusTracker}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
import org.apache.celeborn.common.CelebornConf
@ -55,7 +55,8 @@ class ReducePartitionCommitHandler(
shuffleAllocatedWorkers: ShuffleAllocatedWorkers,
committedPartitionInfo: CommittedPartitionInfo,
workerStatusTracker: WorkerStatusTracker,
sharedRpcPool: ThreadPoolExecutor)
sharedRpcPool: ThreadPoolExecutor,
lifecycleManager: LifecycleManager)
extends CommitHandler(
appUniqueId,
conf,
@ -78,6 +79,10 @@ class ReducePartitionCommitHandler(
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime
private val getReducerFileGroupResponseBroadcastEnabled = conf.getReducerFileGroupBroadcastEnabled
private val getReducerFileGroupResponseBroadcastMiniSize =
conf.getReducerFileGroupBroadcastMiniSize
// noinspection UnstableApiUsage
private val getReducerFileGroupRpcCache: Cache[Int, ByteBuffer] = CacheBuilder.newBuilder()
.concurrencyLevel(rpcCacheConcurrencyLevel)
@ -320,10 +325,17 @@ class ReducePartitionCommitHandler(
} else {
// LocalNettyRpcCallContext is for the UTs
if (context.isInstanceOf[LocalNettyRpcCallContext]) {
context.reply(GetReducerFileGroupResponse(
var response = GetReducerFileGroupResponse(
StatusCode.SUCCESS,
reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()),
getMapperAttempts(shuffleId)))
getMapperAttempts(shuffleId))
// only check whether broadcast enabled for the UTs
if (getReducerFileGroupResponseBroadcastEnabled) {
response = broadcastGetReducerFileGroup(shuffleId, response)
}
context.reply(response)
} else {
val cachedMsg = getReducerFileGroupRpcCache.get(
shuffleId,
@ -337,7 +349,27 @@ class ReducePartitionCommitHandler(
shufflePushFailedBatches.getOrDefault(
shuffleId,
new util.HashMap[String, util.Set[PushFailedBatch]]()))
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
val serializedMsg =
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
if (getReducerFileGroupResponseBroadcastEnabled &&
serializedMsg.capacity() >= getReducerFileGroupResponseBroadcastMiniSize) {
val broadcastMsg = broadcastGetReducerFileGroup(shuffleId, returnedMsg)
if (broadcastMsg != returnedMsg) {
val serializedBroadcastMsg =
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(broadcastMsg)
logInfo(s"Shuffle $shuffleId GetReducerFileGroupResponse size" +
s" ${serializedMsg.capacity()} reached the broadcast threshold" +
s" $getReducerFileGroupResponseBroadcastMiniSize," +
s" the broadcast response size is ${serializedBroadcastMsg.capacity()}.")
serializedBroadcastMsg
} else {
serializedMsg
}
} else {
serializedMsg
}
}
})
context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
@ -345,6 +377,16 @@ class ReducePartitionCommitHandler(
}
}
// Tries to broadcast the response through the LifecycleManager callback. On success,
// returns a slim response carrying only the status and the serialized broadcast bytes
// (the fileGroup/attempts fields fall back to their empty defaults); otherwise returns
// the original, full response unchanged.
private def broadcastGetReducerFileGroup(
shuffleId: Int,
response: GetReducerFileGroupResponse): GetReducerFileGroupResponse = {
lifecycleManager.broadcastGetReducerFileGroupResponse(shuffleId, response) match {
case Some(broadcastBytes) if broadcastBytes.nonEmpty =>
GetReducerFileGroupResponse(response.status, broadcast = broadcastBytes)
case _ => response
}
}
override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit = {
// Quick return for ended stage, avoid occupy sync lock.
if (isStageEnd(shuffleId)) {

View File

@ -427,7 +427,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
Collections.emptyMap());
Collections.emptyMap(),
new byte[0]);
});
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@ -439,7 +440,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
Collections.emptyMap());
Collections.emptyMap(),
new byte[0]);
});
shuffleClient =
@ -482,7 +484,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
Collections.emptyMap());
Collections.emptyMap(),
new byte[0]);
});
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@ -493,7 +496,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
Collections.emptyMap());
Collections.emptyMap(),
new byte[0]);
});
shuffleClient =
@ -514,7 +518,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
Collections.emptyMap());
Collections.emptyMap(),
new byte[0]);
});
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@ -525,7 +530,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
Collections.emptyMap());
Collections.emptyMap(),
new byte[0]);
});
shuffleClient =
@ -546,7 +552,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
Collections.emptyMap());
Collections.emptyMap(),
new byte[0]);
});
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
@ -557,7 +564,8 @@ public class ShuffleClientSuiteJ {
locations,
new int[0],
Collections.emptySet(),
Collections.emptyMap());
Collections.emptyMap(),
new byte[0]);
});
shuffleClient =

View File

@ -390,6 +390,8 @@ message PbGetReducerFileGroupResponse {
repeated int32 partitionIds = 4;
map<string, PbPushFailedBatchSet> pushFailedBatches = 5;
bytes broadcast = 6;
}
message PbGetShuffleId {

View File

@ -1057,6 +1057,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
get(CLIENT_PUSH_DYNAMIC_WRITE_MODE_ENABLED)
def dynamicWriteModePartitionNumThreshold =
get(CLIENT_PUSH_DYNAMIC_WRITE_MODE_PARTITION_NUM_THRESHOLD)
def getReducerFileGroupBroadcastEnabled =
get(CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED)
def getReducerFileGroupBroadcastMiniSize =
get(CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE)
def shufflePartitionType: PartitionType = PartitionType.valueOf(get(SHUFFLE_PARTITION_TYPE))
def shuffleRangeReadFilterEnabled: Boolean = get(SHUFFLE_RANGE_READ_FILTER_ENABLED)
def shuffleForceFallbackEnabled: Boolean = get(SPARK_SHUFFLE_FORCE_FALLBACK_ENABLED)
@ -5213,6 +5217,27 @@ object CelebornConf extends Logging {
.intConf
.createWithDefault(2000)
val CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED =
buildConf("celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled")
.categories("client")
.doc(
"Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. " +
"If the response size is large and Spark executor number is large, the Spark driver network " +
"may be exhausted because each executor will pull the response from the driver. With broadcasting " +
"GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple " +
"copies of the GetReducerFileGroupResponse (one per executor).")
.version("0.6.0")
.booleanConf
.createWithDefault(false)
val CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE =
buildConf("celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize")
.categories("client")
.doc("The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors.")
.version("0.6.0")
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("512k")
val SPARK_SHUFFLE_WRITER_MODE: ConfigEntry[String] =
buildConf("celeborn.client.spark.shuffle.writer")
.withAlternative("celeborn.shuffle.writer")

View File

@ -22,6 +22,7 @@ import java.util.{Collections, UUID}
import scala.collection.JavaConverters._
import com.google.protobuf.ByteString
import org.roaringbitmap.RoaringBitmap
import org.apache.celeborn.common.identity.UserIdentifier
@ -285,10 +286,11 @@ object ControlMessages extends Logging {
// Path can't be serialized
case class GetReducerFileGroupResponse(
status: StatusCode,
fileGroup: util.Map[Integer, util.Set[PartitionLocation]],
attempts: Array[Int],
fileGroup: util.Map[Integer, util.Set[PartitionLocation]] = Collections.emptyMap(),
attempts: Array[Int] = Array.emptyIntArray,
partitionIds: util.Set[Integer] = Collections.emptySet[Integer](),
pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap())
pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap(),
broadcast: Array[Byte] = Array.emptyByteArray)
extends MasterMessage
object WorkerExclude {
@ -752,7 +754,13 @@ object ControlMessages extends Logging {
.build().toByteArray
new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload)
case GetReducerFileGroupResponse(status, fileGroup, attempts, partitionIds, failedBatches) =>
case GetReducerFileGroupResponse(
status,
fileGroup,
attempts,
partitionIds,
failedBatches,
broadcast) =>
val builder = PbGetReducerFileGroupResponse
.newBuilder()
.setStatus(status.getValue)
@ -770,6 +778,7 @@ object ControlMessages extends Logging {
case (uniqueId, pushFailedBatchSet) =>
(uniqueId, PbSerDeUtils.toPbPushFailedBatchSet(pushFailedBatchSet))
}.asJava)
builder.setBroadcast(ByteString.copyFrom(broadcast))
val payload = builder.build().toByteArray
new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, payload)
@ -1198,12 +1207,14 @@ object ControlMessages extends Logging {
case (uniqueId, pushFailedBatchSet) =>
(uniqueId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
}.toMap.asJava
val broadcast = pbGetReducerFileGroupResponse.getBroadcast.toByteArray
GetReducerFileGroupResponse(
Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus),
fileGroup,
attempts,
partitionIds,
pushFailedBatches)
pushFailedBatches,
broadcast)
case GET_SHUFFLE_ID_VALUE =>
message.getParsedPayload()

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common.util
import java.util.concurrent.{Callable, ConcurrentHashMap}
/**
* This class is copied from Apache Spark.
* A special locking mechanism to provide locking with a given key. By providing the same key
* (identity is tested using the `equals` method), we ensure there is only one `func` running at
* the same time.
*
* @tparam K the type of key to identify a lock. This type must implement `equals` and `hashCode`
* correctly as it will be the key type of an internal Map.
*/
class KeyLock[K] {

  // Presence of a key in this map means "locked". The mapped Object is the
  // per-acquisition sentinel that waiters synchronize and wait on.
  private val lockMap = new ConcurrentHashMap[K, AnyRef]()

  // Loop until our own sentinel is installed for `key`. If another thread's
  // sentinel is present, wait on it until releaseLock removes it from the map,
  // then retry the putIfAbsent race from the top.
  private def acquireLock(key: K): Unit = {
    while (true) {
      val lock = lockMap.putIfAbsent(key, new Object)
      if (lock == null) return
      lock.synchronized {
        // Guard against spurious wakeups: only stop waiting once the map no
        // longer holds the sentinel we went to sleep on.
        while (lockMap.get(key) eq lock) {
          lock.wait()
        }
      }
    }
  }

  // Remove our sentinel and wake every thread blocked in acquireLock on it.
  // Must only be called by the thread that successfully acquired `key`.
  private def releaseLock(key: K): Unit = {
    val lock = lockMap.remove(key)
    lock.synchronized {
      lock.notifyAll()
    }
  }

  /**
   * Run `func` under a lock identified by the given key. Multiple calls with the same key
   * (identity is tested using the `equals` method) will be locked properly to ensure there is only
   * one `func` running at the same time.
   */
  def withLock[T](key: K)(func: Callable[T]): T = {
    if (key == null) {
      throw new NullPointerException("key must not be null")
    }
    acquireLock(key)
    try {
      func.call()
    } finally {
      // Always release, even if func throws, so other keys' callers can proceed.
      releaseLock(key)
    }
  }
}

View File

@ -122,6 +122,8 @@ license: |
| celeborn.client.spark.shuffle.fallback.numPartitionsThreshold | 2147483647 | false | Celeborn will only accept shuffle of partition number lower than this configuration value. This configuration only takes effect when `celeborn.client.spark.shuffle.fallback.policy` is `AUTO`. | 0.5.0 | celeborn.shuffle.forceFallback.numPartitionsThreshold,celeborn.client.spark.shuffle.forceFallback.numPartitionsThreshold |
| celeborn.client.spark.shuffle.fallback.policy | AUTO | false | Celeborn supports the following kind of fallback policies. 1. ALWAYS: always use spark built-in shuffle implementation; 2. AUTO: prefer to use celeborn shuffle implementation, and fallback to use spark built-in shuffle implementation based on certain factors, e.g. availability of enough workers and quota, shuffle partition number; 3. NEVER: always use celeborn shuffle implementation, and fail fast when it is concluded that fallback is required based on factors above. | 0.5.0 | |
| celeborn.client.spark.shuffle.forceFallback.enabled | false | false | Always use spark built-in shuffle implementation. This configuration is deprecated, consider configuring `celeborn.client.spark.shuffle.fallback.policy` instead. | 0.3.0 | celeborn.shuffle.forceFallback.enabled |
| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled | false | false | Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. If the response size is large and Spark executor number is large, the Spark driver network may be exhausted because each executor will pull the response from the driver. With broadcasting GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor). | 0.6.0 | |
| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize | 512k | false | The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors. | 0.6.0 | |
| celeborn.client.spark.shuffle.writer | HASH | false | Celeborn supports the following kind of shuffle writers. 1. hash: hash-based shuffle writer works fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer works fine when memory pressure is high or shuffle partition count is huge. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is false. | 0.3.0 | celeborn.shuffle.writer |
| celeborn.client.spark.stageRerun.enabled | true | false | Whether to enable stage rerun. If true, client throws FetchFailedException instead of CelebornIOException. | 0.4.0 | celeborn.client.spark.fetch.throwsFetchFailure |
| celeborn.identity.provider | org.apache.celeborn.common.identity.DefaultIdentityProvider | false | IdentityProvider class name. Default class is `org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will be obtained by UserGroupInformation.getUserName; org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.6.0 | celeborn.quota.identity.provider |

View File

@ -18,11 +18,13 @@
package org.apache.celeborn.tests.spark
import org.apache.spark.SparkConf
import org.apache.spark.shuffle.celeborn.SparkUtils
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.protocol.ShuffleMode
class CelebornHashSuite extends AnyFunSuite
@ -64,4 +66,43 @@ class CelebornHashSuite extends AnyFunSuite
celebornSparkSession.stop()
}
test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") {
  // Reset shared static bookkeeping so the broadcast-count assertion below only
  // observes broadcasts created by this test.
  SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
  SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
  val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
    .set(
      s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED.key}",
      "true")
    .set(
      // A zero threshold forces every GetReducerFileGroupResponse onto the broadcast path.
      s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE.key}",
      "0")
  try {
    // Baseline results with vanilla Spark shuffle.
    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    val combineResult = combine(sparkSession)
    val groupbyResult = groupBy(sparkSession)
    val repartitionResult = repartition(sparkSession)
    val sqlResult = runsql(sparkSession)
    Thread.sleep(3000L)
    sparkSession.stop()

    // Same workload on Celeborn HASH shuffle with broadcast enabled must match the baseline.
    val celebornSparkSession = SparkSession.builder()
      .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
      .getOrCreate()
    val celebornCombineResult = combine(celebornSparkSession)
    val celebornGroupbyResult = groupBy(celebornSparkSession)
    val celebornRepartitionResult = repartition(celebornSparkSession)
    val celebornSqlResult = runsql(celebornSparkSession)

    assert(combineResult.equals(celebornCombineResult))
    assert(groupbyResult.equals(celebornGroupbyResult))
    assert(repartitionResult.equals(celebornRepartitionResult))
    assert(sqlResult.equals(celebornSqlResult))
    // At least one response must actually have gone through the broadcast path.
    assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)
    celebornSparkSession.stop()
  } finally {
    // Clean up the shared static state even when an assertion fails, so later
    // tests in the JVM start from a clean slate.
    SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
    SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
  }
}
}

View File

@ -18,6 +18,7 @@
package org.apache.celeborn.tests.spark
import org.apache.spark.SparkConf
import org.apache.spark.shuffle.celeborn.SparkUtils
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
@ -66,4 +67,45 @@ class CelebornSortSuite extends AnyFunSuite
celebornSparkSession.stop()
}
test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") {
  // Reset shared static bookkeeping so the broadcast-count assertion below only
  // observes broadcasts created by this test.
  SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
  SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
  val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
    .set(s"spark.${CelebornConf.CLIENT_PUSH_SORT_RANDOMIZE_PARTITION_ENABLED.key}", "false")
    .set(
      s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED.key}",
      "true")
    .set(
      // A zero threshold forces every GetReducerFileGroupResponse onto the broadcast path.
      s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE.key}",
      "0")
  try {
    // Baseline results with vanilla Spark shuffle.
    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    val combineResult = combine(sparkSession)
    val groupbyResult = groupBy(sparkSession)
    val repartitionResult = repartition(sparkSession)
    val sqlResult = runsql(sparkSession)
    Thread.sleep(3000L)
    sparkSession.stop()

    // Same workload on Celeborn SORT shuffle with broadcast enabled must match the baseline.
    val celebornSparkSession = SparkSession.builder()
      .config(updateSparkConf(sparkConf, ShuffleMode.SORT))
      .getOrCreate()
    val celebornCombineResult = combine(celebornSparkSession)
    val celebornGroupbyResult = groupBy(celebornSparkSession)
    val celebornRepartitionResult = repartition(celebornSparkSession)
    val celebornSqlResult = runsql(celebornSparkSession)

    assert(combineResult.equals(celebornCombineResult))
    assert(groupbyResult.equals(celebornGroupbyResult))
    assert(repartitionResult.equals(celebornRepartitionResult))
    assert(sqlResult.equals(celebornSqlResult))
    // At least one response must actually have gone through the broadcast path.
    assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)
    celebornSparkSession.stop()
  } finally {
    // Clean up the shared static state even when an assertion fails, so later
    // tests in the JVM start from a clean slate.
    SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
    SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
  }
}
}

View File

@ -29,7 +29,9 @@ import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.protocol.ShuffleMode
import org.apache.celeborn.common.protocol.{PartitionLocation, ShuffleMode}
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.tests.spark.SparkTestBase
class SparkUtilsSuite extends AnyFunSuite
@ -157,4 +159,61 @@ class SparkUtilsSuite extends AnyFunSuite
sparkSession.stop()
}
}
// Round-trips a GetReducerFileGroupResponse through SparkUtils' broadcast-backed
// serialization and checks that (a) a valid Broadcast is registered, (b) all
// response fields survive the round trip, and (c) invalidation destroys the
// broadcast and empties the cache.
test("serialize/deserialize GetReducerFileGroupResponse with broadcast") {
  val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
  val sparkSession = SparkSession.builder()
    .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
    .config("spark.sql.shuffle.partitions", 2)
    .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
    .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
    .config(
      "spark.shuffle.manager",
      "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
    .getOrCreate()
  try {
    // MAX_VALUE makes the cache key distinctive and avoids clashing with real shuffle ids.
    val shuffleId = Integer.MAX_VALUE
    val getReducerFileGroupResponse = GetReducerFileGroupResponse(
      StatusCode.SUCCESS,
      Map(Integer.valueOf(shuffleId) -> Set(new PartitionLocation(
        0,
        1,
        "localhost",
        1,
        2,
        3,
        4,
        PartitionLocation.Mode.REPLICA)).asJava).asJava,
      Array(1),
      Set(Integer.valueOf(shuffleId)).asJava)
    val serializedBytes =
      SparkUtils.serializeGetReducerFileGroupResponse(shuffleId, getReducerFileGroupResponse)
    assert(serializedBytes != null && serializedBytes.length > 0)
    // Serialization must have registered exactly one cached broadcast entry.
    val broadcast = SparkUtils.getReducerFileGroupResponseBroadcasts.values().asScala.head._1
    assert(broadcast.isValid)
    val deserializedGetReducerFileGroupResponse =
      SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, serializedBytes)
    // Field-by-field round-trip equality checks.
    assert(deserializedGetReducerFileGroupResponse.status == getReducerFileGroupResponse.status)
    assert(
      deserializedGetReducerFileGroupResponse.fileGroup == getReducerFileGroupResponse.fileGroup)
    assert(java.util.Arrays.equals(
      deserializedGetReducerFileGroupResponse.attempts,
      getReducerFileGroupResponse.attempts))
    assert(deserializedGetReducerFileGroupResponse.partitionIds == getReducerFileGroupResponse.partitionIds)
    assert(
      deserializedGetReducerFileGroupResponse.pushFailedBatches == getReducerFileGroupResponse.pushFailedBatches)
    assert(!SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
    // Invalidation must drop the cache entry and destroy the underlying broadcast.
    SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId)
    assert(SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
    assert(!broadcast.isValid)
  } finally {
    // Stop the session and reset shared static state for subsequent tests.
    sparkSession.stop()
    SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
    SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
  }
}
}