[CELEBORN-1921] Broadcast large GetReducerFileGroupResponse to prevent Spark driver network exhausted
### What changes were proposed in this pull request? For spark celeborn application, if the GetReducerFileGroupResponse is larger than the threshold, Spark driver would broadcast the GetReducerFileGroupResponse to the executors, it prevents the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor). ### Why are the changes needed? To prevent the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor). ### Does this PR introduce _any_ user-facing change? No, the feature is not enabled by defaults. ### How was this patch tested? UT. Cluster testing with `spark.celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled=true`. The broadcast response size should be always about 1kb.   Application succeed.  Closes #3158 from turboFei/broadcast_rgf. Authored-by: Wang, Fei <fwang12@ebay.com> Signed-off-by: Wang, Fei <fwang12@ebay.com>
This commit is contained in:
parent
1e30f159b9
commit
5e12b7d607
1
LICENSE
1
LICENSE
@ -223,6 +223,7 @@ Apache Spark
|
||||
./common/src/test/java/org/apache/celeborn/common/network/protocol/EncryptedMessageWithHeaderSuiteJ.java
|
||||
./common/src/test/java/org/apache/celeborn/common/network/ssl/ReloadingX509TrustManagerSuiteJ.java
|
||||
./common/src/test/java/org/apache/celeborn/common/network/ssl/SslSampleConfigs.java
|
||||
./common/src/main/scala/org/apache/celeborn/common/util/KeyLock.scala
|
||||
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DB.java
|
||||
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/DBIterator.java
|
||||
./worker/src/main/java/org/apache/celeborn/service/deploy/worker/shuffledb/LevelDB.java
|
||||
|
||||
@ -73,6 +73,7 @@
|
||||
<include>io.netty:*</include>
|
||||
<include>org.apache.commons:commons-lang3</include>
|
||||
<include>org.roaringbitmap:RoaringBitmap</include>
|
||||
<include>commons-io:commons-io</include>
|
||||
</includes>
|
||||
</artifactSet>
|
||||
<filters>
|
||||
|
||||
@ -109,6 +109,15 @@ public class SparkShuffleManager implements ShuffleManager {
|
||||
lifecycleManager.registerShuffleTrackerCallback(
|
||||
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
|
||||
}
|
||||
|
||||
if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
|
||||
lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
|
||||
(shuffleId, getReducerFileGroupResponse) ->
|
||||
SparkUtils.serializeGetReducerFileGroupResponse(
|
||||
shuffleId, getReducerFileGroupResponse));
|
||||
lifecycleManager.registerInvalidatedBroadcastCallback(
|
||||
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,7 +17,10 @@
|
||||
|
||||
package org.apache.spark.shuffle.celeborn;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.HashSet;
|
||||
@ -25,6 +28,7 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.LongAdder;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -37,7 +41,12 @@ import org.apache.spark.BarrierTaskContext;
|
||||
import org.apache.spark.SparkConf;
|
||||
import org.apache.spark.SparkContext;
|
||||
import org.apache.spark.SparkContext$;
|
||||
import org.apache.spark.SparkEnv;
|
||||
import org.apache.spark.SparkEnv$;
|
||||
import org.apache.spark.TaskContext;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.apache.spark.io.CompressionCodec;
|
||||
import org.apache.spark.io.CompressionCodec$;
|
||||
import org.apache.spark.scheduler.DAGScheduler;
|
||||
import org.apache.spark.scheduler.MapStatus;
|
||||
import org.apache.spark.scheduler.MapStatus$;
|
||||
@ -54,7 +63,10 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient;
|
||||
import org.apache.celeborn.common.CelebornConf;
|
||||
import org.apache.celeborn.common.network.protocol.TransportMessage;
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
|
||||
import org.apache.celeborn.common.util.JavaUtils;
|
||||
import org.apache.celeborn.common.util.KeyLock;
|
||||
import org.apache.celeborn.common.util.Utils;
|
||||
import org.apache.celeborn.reflect.DynFields;
|
||||
|
||||
@ -346,4 +358,121 @@ public class SparkUtils {
|
||||
sparkContext.addSparkListener(listener);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the
|
||||
* broadcast belonging to the shuffle id at a time.
|
||||
*/
|
||||
private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock<>();
|
||||
|
||||
@VisibleForTesting
|
||||
public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
|
||||
|
||||
@VisibleForTesting
|
||||
public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>>
|
||||
getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
|
||||
|
||||
public static byte[] serializeGetReducerFileGroupResponse(
|
||||
Integer shuffleId, GetReducerFileGroupResponse response) {
|
||||
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
|
||||
if (sparkContext == null) {
|
||||
logger.error("Can not get active SparkContext.");
|
||||
return null;
|
||||
}
|
||||
|
||||
return shuffleBroadcastLock.withLock(
|
||||
shuffleId,
|
||||
() -> {
|
||||
Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
|
||||
getReducerFileGroupResponseBroadcasts.get(shuffleId);
|
||||
if (cachedSerializeGetReducerFileGroupResponse != null) {
|
||||
return cachedSerializeGetReducerFileGroupResponse._2;
|
||||
}
|
||||
|
||||
try {
|
||||
logger.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId);
|
||||
TransportMessage transportMessage =
|
||||
(TransportMessage) Utils.toTransportMessage(response);
|
||||
Broadcast<TransportMessage> broadcast =
|
||||
sparkContext.broadcast(
|
||||
transportMessage,
|
||||
scala.reflect.ClassManifestFactory.fromClass(TransportMessage.class));
|
||||
|
||||
CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
|
||||
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
|
||||
// one
|
||||
// This implementation doesn't reallocate the whole memory block but allocates
|
||||
// additional buffers. This way no buffers need to be garbage collected and
|
||||
// the contents don't have to be copied to the new buffer.
|
||||
org.apache.commons.io.output.ByteArrayOutputStream out =
|
||||
new org.apache.commons.io.output.ByteArrayOutputStream();
|
||||
try (ObjectOutputStream oos =
|
||||
new ObjectOutputStream(codec.compressedOutputStream(out))) {
|
||||
oos.writeObject(broadcast);
|
||||
}
|
||||
byte[] _serializeResult = out.toByteArray();
|
||||
getReducerFileGroupResponseBroadcasts.put(
|
||||
shuffleId, Tuple2.apply(broadcast, _serializeResult));
|
||||
getReducerFileGroupResponseBroadcastNum.incrementAndGet();
|
||||
return _serializeResult;
|
||||
} catch (Throwable e) {
|
||||
logger.error(
|
||||
"Failed to serialize GetReducerFileGroupResponse for shuffle: {}", shuffleId, e);
|
||||
return null;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
|
||||
Integer shuffleId, byte[] bytes) {
|
||||
SparkEnv sparkEnv = SparkEnv$.MODULE$.get();
|
||||
if (sparkEnv == null) {
|
||||
logger.error("Can not get SparkEnv.");
|
||||
return null;
|
||||
}
|
||||
|
||||
return shuffleBroadcastLock.withLock(
|
||||
shuffleId,
|
||||
() -> {
|
||||
GetReducerFileGroupResponse response = null;
|
||||
logger.info(
|
||||
"Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
|
||||
|
||||
try {
|
||||
CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkEnv.conf());
|
||||
try (ObjectInputStream objIn =
|
||||
new ObjectInputStream(
|
||||
codec.compressedInputStream(new ByteArrayInputStream(bytes)))) {
|
||||
Broadcast<TransportMessage> broadcast =
|
||||
(Broadcast<TransportMessage>) objIn.readObject();
|
||||
response =
|
||||
(GetReducerFileGroupResponse) Utils.fromTransportMessage(broadcast.value());
|
||||
}
|
||||
} catch (Throwable e) {
|
||||
logger.error(
|
||||
"Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
|
||||
}
|
||||
return response;
|
||||
});
|
||||
}
|
||||
|
||||
public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
|
||||
shuffleBroadcastLock.withLock(
|
||||
shuffleId,
|
||||
() -> {
|
||||
try {
|
||||
Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
|
||||
getReducerFileGroupResponseBroadcasts.remove(shuffleId);
|
||||
if (cachedSerializeGetReducerFileGroupResponse != null) {
|
||||
cachedSerializeGetReducerFileGroupResponse._1().destroy();
|
||||
}
|
||||
} catch (Throwable e) {
|
||||
logger.error(
|
||||
"Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
|
||||
+ shuffleId,
|
||||
e);
|
||||
}
|
||||
return null;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
|
||||
import java.io.IOException
|
||||
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
import java.util.function.BiFunction
|
||||
|
||||
import org.apache.spark.{Aggregator, InterruptibleIterator, TaskContext}
|
||||
import org.apache.spark.internal.Logging
|
||||
@ -33,6 +34,7 @@ import org.apache.celeborn.client.read.CelebornInputStream
|
||||
import org.apache.celeborn.client.read.MetricsCallback
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
|
||||
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils}
|
||||
|
||||
class CelebornShuffleReader[K, C](
|
||||
@ -254,4 +256,13 @@ class CelebornShuffleReader[K, C](
|
||||
|
||||
object CelebornShuffleReader {
|
||||
var streamCreatorPool: ThreadPoolExecutor = null
|
||||
// Register the deserializer for GetReducerFileGroupResponse broadcast
|
||||
ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new BiFunction[
|
||||
Integer,
|
||||
Array[Byte],
|
||||
GetReducerFileGroupResponse] {
|
||||
override def apply(shuffleId: Integer, broadcast: Array[Byte]): GetReducerFileGroupResponse = {
|
||||
SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -77,6 +77,7 @@
|
||||
<include>io.netty:*</include>
|
||||
<include>org.apache.commons:commons-lang3</include>
|
||||
<include>org.roaringbitmap:RoaringBitmap</include>
|
||||
<include>commons-io:commons-io</include>
|
||||
</includes>
|
||||
</artifactSet>
|
||||
<filters>
|
||||
|
||||
@ -156,6 +156,15 @@ public class SparkShuffleManager implements ShuffleManager {
|
||||
SparkUtils::isCelebornSkewShuffleOrChildShuffle);
|
||||
}
|
||||
}
|
||||
|
||||
if (celebornConf.getReducerFileGroupBroadcastEnabled()) {
|
||||
lifecycleManager.registerBroadcastGetReducerFileGroupResponseCallback(
|
||||
(shuffleId, getReducerFileGroupResponse) ->
|
||||
SparkUtils.serializeGetReducerFileGroupResponse(
|
||||
shuffleId, getReducerFileGroupResponse));
|
||||
lifecycleManager.registerInvalidatedBroadcastCallback(
|
||||
shuffleId -> SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,11 +17,15 @@
|
||||
|
||||
package org.apache.spark.shuffle.celeborn;
|
||||
|
||||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ObjectInputStream;
|
||||
import java.io.ObjectOutputStream;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.LongAdder;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -35,7 +39,12 @@ import org.apache.spark.MapOutputTrackerMaster;
|
||||
import org.apache.spark.SparkConf;
|
||||
import org.apache.spark.SparkContext;
|
||||
import org.apache.spark.SparkContext$;
|
||||
import org.apache.spark.SparkEnv;
|
||||
import org.apache.spark.SparkEnv$;
|
||||
import org.apache.spark.TaskContext;
|
||||
import org.apache.spark.broadcast.Broadcast;
|
||||
import org.apache.spark.io.CompressionCodec;
|
||||
import org.apache.spark.io.CompressionCodec$;
|
||||
import org.apache.spark.scheduler.DAGScheduler;
|
||||
import org.apache.spark.scheduler.MapStatus;
|
||||
import org.apache.spark.scheduler.MapStatus$;
|
||||
@ -57,7 +66,11 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient;
|
||||
import org.apache.celeborn.common.CelebornConf;
|
||||
import org.apache.celeborn.common.network.protocol.TransportMessage;
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse;
|
||||
import org.apache.celeborn.common.util.JavaUtils;
|
||||
import org.apache.celeborn.common.util.KeyLock;
|
||||
import org.apache.celeborn.common.util.Utils;
|
||||
import org.apache.celeborn.reflect.DynConstructors;
|
||||
import org.apache.celeborn.reflect.DynFields;
|
||||
import org.apache.celeborn.reflect.DynMethods;
|
||||
@ -476,4 +489,121 @@ public class SparkUtils {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread accessing the
|
||||
* broadcast belonging to the shuffle id at a time.
|
||||
*/
|
||||
private static final KeyLock<Integer> shuffleBroadcastLock = new KeyLock<>();
|
||||
|
||||
@VisibleForTesting
|
||||
public static AtomicInteger getReducerFileGroupResponseBroadcastNum = new AtomicInteger();
|
||||
|
||||
@VisibleForTesting
|
||||
public static Map<Integer, Tuple2<Broadcast<TransportMessage>, byte[]>>
|
||||
getReducerFileGroupResponseBroadcasts = JavaUtils.newConcurrentHashMap();
|
||||
|
||||
public static byte[] serializeGetReducerFileGroupResponse(
|
||||
Integer shuffleId, GetReducerFileGroupResponse response) {
|
||||
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
|
||||
if (sparkContext == null) {
|
||||
LOG.error("Can not get active SparkContext.");
|
||||
return null;
|
||||
}
|
||||
|
||||
return shuffleBroadcastLock.withLock(
|
||||
shuffleId,
|
||||
() -> {
|
||||
Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
|
||||
getReducerFileGroupResponseBroadcasts.get(shuffleId);
|
||||
if (cachedSerializeGetReducerFileGroupResponse != null) {
|
||||
return cachedSerializeGetReducerFileGroupResponse._2;
|
||||
}
|
||||
|
||||
try {
|
||||
LOG.info("Broadcasting GetReducerFileGroupResponse for shuffle: {}", shuffleId);
|
||||
TransportMessage transportMessage =
|
||||
(TransportMessage) Utils.toTransportMessage(response);
|
||||
Broadcast<TransportMessage> broadcast =
|
||||
sparkContext.broadcast(
|
||||
transportMessage,
|
||||
scala.reflect.ClassManifestFactory.fromClass(TransportMessage.class));
|
||||
|
||||
CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkContext.conf());
|
||||
// Using `org.apache.commons.io.output.ByteArrayOutputStream` instead of the standard
|
||||
// one
|
||||
// This implementation doesn't reallocate the whole memory block but allocates
|
||||
// additional buffers. This way no buffers need to be garbage collected and
|
||||
// the contents don't have to be copied to the new buffer.
|
||||
org.apache.commons.io.output.ByteArrayOutputStream out =
|
||||
new org.apache.commons.io.output.ByteArrayOutputStream();
|
||||
try (ObjectOutputStream oos =
|
||||
new ObjectOutputStream(codec.compressedOutputStream(out))) {
|
||||
oos.writeObject(broadcast);
|
||||
}
|
||||
byte[] _serializeResult = out.toByteArray();
|
||||
getReducerFileGroupResponseBroadcasts.put(
|
||||
shuffleId, Tuple2.apply(broadcast, _serializeResult));
|
||||
getReducerFileGroupResponseBroadcastNum.incrementAndGet();
|
||||
return _serializeResult;
|
||||
} catch (Throwable e) {
|
||||
LOG.error(
|
||||
"Failed to serialize GetReducerFileGroupResponse for shuffle: {}", shuffleId, e);
|
||||
return null;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public static GetReducerFileGroupResponse deserializeGetReducerFileGroupResponse(
|
||||
Integer shuffleId, byte[] bytes) {
|
||||
SparkEnv sparkEnv = SparkEnv$.MODULE$.get();
|
||||
if (sparkEnv == null) {
|
||||
LOG.error("Can not get SparkEnv.");
|
||||
return null;
|
||||
}
|
||||
|
||||
return shuffleBroadcastLock.withLock(
|
||||
shuffleId,
|
||||
() -> {
|
||||
GetReducerFileGroupResponse response = null;
|
||||
LOG.info(
|
||||
"Deserializing GetReducerFileGroupResponse broadcast for shuffle: {}", shuffleId);
|
||||
|
||||
try {
|
||||
CompressionCodec codec = CompressionCodec$.MODULE$.createCodec(sparkEnv.conf());
|
||||
try (ObjectInputStream objIn =
|
||||
new ObjectInputStream(
|
||||
codec.compressedInputStream(new ByteArrayInputStream(bytes)))) {
|
||||
Broadcast<TransportMessage> broadcast =
|
||||
(Broadcast<TransportMessage>) objIn.readObject();
|
||||
response =
|
||||
(GetReducerFileGroupResponse) Utils.fromTransportMessage(broadcast.value());
|
||||
}
|
||||
} catch (Throwable e) {
|
||||
LOG.error(
|
||||
"Failed to deserialize GetReducerFileGroupResponse for shuffle: " + shuffleId, e);
|
||||
}
|
||||
return response;
|
||||
});
|
||||
}
|
||||
|
||||
public static void invalidateSerializedGetReducerFileGroupResponse(Integer shuffleId) {
|
||||
shuffleBroadcastLock.withLock(
|
||||
shuffleId,
|
||||
() -> {
|
||||
try {
|
||||
Tuple2<Broadcast<TransportMessage>, byte[]> cachedSerializeGetReducerFileGroupResponse =
|
||||
getReducerFileGroupResponseBroadcasts.remove(shuffleId);
|
||||
if (cachedSerializeGetReducerFileGroupResponse != null) {
|
||||
cachedSerializeGetReducerFileGroupResponse._1().destroy();
|
||||
}
|
||||
} catch (Throwable e) {
|
||||
LOG.error(
|
||||
"Failed to invalidate serialized GetReducerFileGroupResponse for shuffle: "
|
||||
+ shuffleId,
|
||||
e);
|
||||
}
|
||||
return null;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import java.io.IOException
|
||||
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, Map => JMap, Set => JSet}
|
||||
import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeoutException, TimeUnit}
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
import java.util.function.BiFunction
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
@ -43,7 +44,8 @@ import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRet
|
||||
import org.apache.celeborn.common.network.client.TransportClient
|
||||
import org.apache.celeborn.common.network.protocol.TransportMessage
|
||||
import org.apache.celeborn.common.protocol._
|
||||
import org.apache.celeborn.common.protocol.message.StatusCode
|
||||
import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
|
||||
import org.apache.celeborn.common.util.{JavaUtils, ThreadUtils, Utils}
|
||||
|
||||
class CelebornShuffleReader[K, C](
|
||||
@ -465,4 +467,13 @@ class CelebornShuffleReader[K, C](
|
||||
|
||||
object CelebornShuffleReader {
|
||||
var streamCreatorPool: ThreadPoolExecutor = null
|
||||
// Register the deserializer for GetReducerFileGroupResponse broadcast
|
||||
ShuffleClient.registerDeserializeReducerFileGroupResponseFunction(new BiFunction[
|
||||
Integer,
|
||||
Array[Byte],
|
||||
GetReducerFileGroupResponse] {
|
||||
override def apply(shuffleId: Integer, broadcast: Array[Byte]): GetReducerFileGroupResponse = {
|
||||
SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, broadcast)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -20,9 +20,11 @@ package org.apache.celeborn.client;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.LongAdder;
|
||||
import java.util.function.BiFunction;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
@ -38,6 +40,7 @@ import org.apache.celeborn.common.network.client.TransportClientFactory;
|
||||
import org.apache.celeborn.common.protocol.PartitionLocation;
|
||||
import org.apache.celeborn.common.protocol.PbStreamHandler;
|
||||
import org.apache.celeborn.common.protocol.StorageInfo;
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages;
|
||||
import org.apache.celeborn.common.rpc.RpcEndpointRef;
|
||||
import org.apache.celeborn.common.util.CelebornHadoopUtils;
|
||||
import org.apache.celeborn.common.util.ExceptionMaker;
|
||||
@ -56,6 +59,10 @@ public abstract class ShuffleClient {
|
||||
private static LongAdder totalReadCounter = new LongAdder();
|
||||
private static LongAdder localShuffleReadCounter = new LongAdder();
|
||||
|
||||
private static volatile Optional<
|
||||
BiFunction<Integer, byte[], ControlMessages.GetReducerFileGroupResponse>>
|
||||
deserializeReducerFileGroupResponseFunction = Optional.empty();
|
||||
|
||||
// for testing
|
||||
public static void reset() {
|
||||
_instance = null;
|
||||
@ -297,4 +304,21 @@ public abstract class ShuffleClient {
|
||||
public abstract TransportClientFactory getDataClientFactory();
|
||||
|
||||
public abstract void excludeFailedFetchLocation(String hostAndFetchPort, Exception e);
|
||||
|
||||
public static void registerDeserializeReducerFileGroupResponseFunction(
|
||||
BiFunction<Integer, byte[], ControlMessages.GetReducerFileGroupResponse> function) {
|
||||
if (!deserializeReducerFileGroupResponseFunction.isPresent()) {
|
||||
deserializeReducerFileGroupResponseFunction = Optional.ofNullable(function);
|
||||
}
|
||||
}
|
||||
|
||||
public static ControlMessages.GetReducerFileGroupResponse deserializeReducerFileGroupResponse(
|
||||
int shuffleId, byte[] bytes) {
|
||||
if (!deserializeReducerFileGroupResponseFunction.isPresent()) {
|
||||
// Should never happen
|
||||
logger.warn("DeserializeReducerFileGroupResponseFunction is not registered.");
|
||||
return null;
|
||||
}
|
||||
return deserializeReducerFileGroupResponseFunction.get().apply(shuffleId, bytes);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1814,6 +1814,14 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
ClassTag$.MODULE$.apply(GetReducerFileGroupResponse.class));
|
||||
switch (response.status()) {
|
||||
case SUCCESS:
|
||||
if (response.broadcast() != null && response.broadcast().length > 0) {
|
||||
response =
|
||||
ShuffleClient.deserializeReducerFileGroupResponse(shuffleId, response.broadcast());
|
||||
if (response == null) {
|
||||
throw new CelebornIOException(
|
||||
"Failed to get GetReducerFileGroupResponse broadcast for shuffle: " + shuffleId);
|
||||
}
|
||||
}
|
||||
logger.info(
|
||||
"Shuffle {} request reducer file group success using {} ms, result partition size {}.",
|
||||
shuffleId,
|
||||
|
||||
@ -292,7 +292,8 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
|
||||
lifecycleManager.shuffleAllocatedWorkers,
|
||||
committedPartitionInfo,
|
||||
lifecycleManager.workerStatusTracker,
|
||||
lifecycleManager.rpcSharedThreadPool)
|
||||
lifecycleManager.rpcSharedThreadPool,
|
||||
lifecycleManager)
|
||||
case PartitionType.MAP => new MapPartitionCommitHandler(
|
||||
appUniqueId,
|
||||
conf,
|
||||
|
||||
@ -1677,6 +1677,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
} else {
|
||||
batchRemoveShuffleIds += shuffleId
|
||||
}
|
||||
invalidatedBroadcastGetReducerFileGroupResponse(shuffleId)
|
||||
}
|
||||
}
|
||||
if (batchRemoveShuffleIds.nonEmpty) {
|
||||
@ -1848,6 +1849,22 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
cancelShuffleCallback = Some(callback)
|
||||
}
|
||||
|
||||
@volatile private var broadcastGetReducerFileGroupResponseCallback
|
||||
: Option[java.util.function.BiFunction[Integer, GetReducerFileGroupResponse, Array[Byte]]] =
|
||||
None
|
||||
def registerBroadcastGetReducerFileGroupResponseCallback(call: java.util.function.BiFunction[
|
||||
Integer,
|
||||
GetReducerFileGroupResponse,
|
||||
Array[Byte]]): Unit = {
|
||||
broadcastGetReducerFileGroupResponseCallback = Some(call)
|
||||
}
|
||||
|
||||
@volatile private var invalidatedBroadcastCallback: Option[Consumer[Integer]] =
|
||||
None
|
||||
def registerInvalidatedBroadcastCallback(call: Consumer[Integer]): Unit = {
|
||||
invalidatedBroadcastCallback = Some(call)
|
||||
}
|
||||
|
||||
def invalidateLatestMaxLocsCache(shuffleId: Int): Unit = {
|
||||
registerShuffleResponseRpcCache.invalidate(shuffleId)
|
||||
}
|
||||
@ -1889,4 +1906,19 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
case _ =>
|
||||
}
|
||||
|
||||
def broadcastGetReducerFileGroupResponse(
|
||||
shuffleId: Int,
|
||||
response: GetReducerFileGroupResponse): Option[Array[Byte]] = {
|
||||
broadcastGetReducerFileGroupResponseCallback match {
|
||||
case Some(c) => Option(c.apply(shuffleId, response))
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
private def invalidatedBroadcastGetReducerFileGroupResponse(shuffleId: Int): Unit = {
|
||||
invalidatedBroadcastCallback match {
|
||||
case Some(c) => c.accept(shuffleId)
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,7 +28,7 @@ import scala.collection.mutable
|
||||
import com.google.common.cache.{Cache, CacheBuilder}
|
||||
import com.google.common.collect.Sets
|
||||
|
||||
import org.apache.celeborn.client.{ClientUtils, ShuffleCommittedInfo, WorkerStatusTracker}
|
||||
import org.apache.celeborn.client.{ClientUtils, LifecycleManager, ShuffleCommittedInfo, WorkerStatusTracker}
|
||||
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
|
||||
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
@ -55,7 +55,8 @@ class ReducePartitionCommitHandler(
|
||||
shuffleAllocatedWorkers: ShuffleAllocatedWorkers,
|
||||
committedPartitionInfo: CommittedPartitionInfo,
|
||||
workerStatusTracker: WorkerStatusTracker,
|
||||
sharedRpcPool: ThreadPoolExecutor)
|
||||
sharedRpcPool: ThreadPoolExecutor,
|
||||
lifecycleManager: LifecycleManager)
|
||||
extends CommitHandler(
|
||||
appUniqueId,
|
||||
conf,
|
||||
@ -78,6 +79,10 @@ class ReducePartitionCommitHandler(
|
||||
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
|
||||
private val rpcCacheExpireTime = conf.clientRpcCacheExpireTime
|
||||
|
||||
private val getReducerFileGroupResponseBroadcastEnabled = conf.getReducerFileGroupBroadcastEnabled
|
||||
private val getReducerFileGroupResponseBroadcastMiniSize =
|
||||
conf.getReducerFileGroupBroadcastMiniSize
|
||||
|
||||
// noinspection UnstableApiUsage
|
||||
private val getReducerFileGroupRpcCache: Cache[Int, ByteBuffer] = CacheBuilder.newBuilder()
|
||||
.concurrencyLevel(rpcCacheConcurrencyLevel)
|
||||
@ -320,10 +325,17 @@ class ReducePartitionCommitHandler(
|
||||
} else {
|
||||
// LocalNettyRpcCallContext is for the UTs
|
||||
if (context.isInstanceOf[LocalNettyRpcCallContext]) {
|
||||
context.reply(GetReducerFileGroupResponse(
|
||||
var response = GetReducerFileGroupResponse(
|
||||
StatusCode.SUCCESS,
|
||||
reducerFileGroupsMap.getOrDefault(shuffleId, JavaUtils.newConcurrentHashMap()),
|
||||
getMapperAttempts(shuffleId)))
|
||||
getMapperAttempts(shuffleId))
|
||||
|
||||
// only check whether broadcast enabled for the UTs
|
||||
if (getReducerFileGroupResponseBroadcastEnabled) {
|
||||
response = broadcastGetReducerFileGroup(shuffleId, response)
|
||||
}
|
||||
|
||||
context.reply(response)
|
||||
} else {
|
||||
val cachedMsg = getReducerFileGroupRpcCache.get(
|
||||
shuffleId,
|
||||
@ -337,7 +349,27 @@ class ReducePartitionCommitHandler(
|
||||
shufflePushFailedBatches.getOrDefault(
|
||||
shuffleId,
|
||||
new util.HashMap[String, util.Set[PushFailedBatch]]()))
|
||||
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
|
||||
|
||||
val serializedMsg =
|
||||
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
|
||||
|
||||
if (getReducerFileGroupResponseBroadcastEnabled &&
|
||||
serializedMsg.capacity() >= getReducerFileGroupResponseBroadcastMiniSize) {
|
||||
val broadcastMsg = broadcastGetReducerFileGroup(shuffleId, returnedMsg)
|
||||
if (broadcastMsg != returnedMsg) {
|
||||
val serializedBroadcastMsg =
|
||||
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(broadcastMsg)
|
||||
logInfo(s"Shuffle $shuffleId GetReducerFileGroupResponse size" +
|
||||
s" ${serializedMsg.capacity()} reached the broadcast threshold" +
|
||||
s" $getReducerFileGroupResponseBroadcastMiniSize," +
|
||||
s" the broadcast response size is ${serializedBroadcastMsg.capacity()}.")
|
||||
serializedBroadcastMsg
|
||||
} else {
|
||||
serializedMsg
|
||||
}
|
||||
} else {
|
||||
serializedMsg
|
||||
}
|
||||
}
|
||||
})
|
||||
context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
|
||||
@ -345,6 +377,16 @@ class ReducePartitionCommitHandler(
|
||||
}
|
||||
}
|
||||
|
||||
private def broadcastGetReducerFileGroup(
|
||||
shuffleId: Int,
|
||||
response: GetReducerFileGroupResponse): GetReducerFileGroupResponse = {
|
||||
lifecycleManager.broadcastGetReducerFileGroupResponse(shuffleId, response) match {
|
||||
case Some(broadcastBytes) if broadcastBytes.nonEmpty =>
|
||||
GetReducerFileGroupResponse(response.status, broadcast = broadcastBytes)
|
||||
case _ => response
|
||||
}
|
||||
}
|
||||
|
||||
override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit = {
|
||||
// Quick return for ended stage, avoid occupy sync lock.
|
||||
if (isStageEnd(shuffleId)) {
|
||||
|
||||
@ -427,7 +427,8 @@ public class ShuffleClientSuiteJ {
|
||||
locations,
|
||||
new int[0],
|
||||
Collections.emptySet(),
|
||||
Collections.emptyMap());
|
||||
Collections.emptyMap(),
|
||||
new byte[0]);
|
||||
});
|
||||
|
||||
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
|
||||
@ -439,7 +440,8 @@ public class ShuffleClientSuiteJ {
|
||||
locations,
|
||||
new int[0],
|
||||
Collections.emptySet(),
|
||||
Collections.emptyMap());
|
||||
Collections.emptyMap(),
|
||||
new byte[0]);
|
||||
});
|
||||
|
||||
shuffleClient =
|
||||
@ -482,7 +484,8 @@ public class ShuffleClientSuiteJ {
|
||||
locations,
|
||||
new int[0],
|
||||
Collections.emptySet(),
|
||||
Collections.emptyMap());
|
||||
Collections.emptyMap(),
|
||||
new byte[0]);
|
||||
});
|
||||
|
||||
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
|
||||
@ -493,7 +496,8 @@ public class ShuffleClientSuiteJ {
|
||||
locations,
|
||||
new int[0],
|
||||
Collections.emptySet(),
|
||||
Collections.emptyMap());
|
||||
Collections.emptyMap(),
|
||||
new byte[0]);
|
||||
});
|
||||
|
||||
shuffleClient =
|
||||
@ -514,7 +518,8 @@ public class ShuffleClientSuiteJ {
|
||||
locations,
|
||||
new int[0],
|
||||
Collections.emptySet(),
|
||||
Collections.emptyMap());
|
||||
Collections.emptyMap(),
|
||||
new byte[0]);
|
||||
});
|
||||
|
||||
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
|
||||
@ -525,7 +530,8 @@ public class ShuffleClientSuiteJ {
|
||||
locations,
|
||||
new int[0],
|
||||
Collections.emptySet(),
|
||||
Collections.emptyMap());
|
||||
Collections.emptyMap(),
|
||||
new byte[0]);
|
||||
});
|
||||
|
||||
shuffleClient =
|
||||
@ -546,7 +552,8 @@ public class ShuffleClientSuiteJ {
|
||||
locations,
|
||||
new int[0],
|
||||
Collections.emptySet(),
|
||||
Collections.emptyMap());
|
||||
Collections.emptyMap(),
|
||||
new byte[0]);
|
||||
});
|
||||
|
||||
when(endpointRef.askSync(any(), any(), any(Integer.class), any(Long.class), any()))
|
||||
@ -557,7 +564,8 @@ public class ShuffleClientSuiteJ {
|
||||
locations,
|
||||
new int[0],
|
||||
Collections.emptySet(),
|
||||
Collections.emptyMap());
|
||||
Collections.emptyMap(),
|
||||
new byte[0]);
|
||||
});
|
||||
|
||||
shuffleClient =
|
||||
|
||||
@ -390,6 +390,8 @@ message PbGetReducerFileGroupResponse {
|
||||
repeated int32 partitionIds = 4;
|
||||
|
||||
map<string, PbPushFailedBatchSet> pushFailedBatches = 5;
|
||||
|
||||
bytes broadcast = 6;
|
||||
}
|
||||
|
||||
message PbGetShuffleId {
|
||||
|
||||
@ -1057,6 +1057,10 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
|
||||
get(CLIENT_PUSH_DYNAMIC_WRITE_MODE_ENABLED)
|
||||
def dynamicWriteModePartitionNumThreshold =
|
||||
get(CLIENT_PUSH_DYNAMIC_WRITE_MODE_PARTITION_NUM_THRESHOLD)
|
||||
// Whether to deliver GetReducerFileGroupResponse to Spark executors via broadcast.
// Explicit result type added for consistency with sibling accessors
// (e.g. `shuffleRangeReadFilterEnabled: Boolean`).
def getReducerFileGroupBroadcastEnabled: Boolean =
  get(CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED)
|
||||
// Serialized-size threshold (bytes) above which GetReducerFileGroupResponse is broadcast.
// Explicit result type added for consistency with sibling accessors; the entry is a
// bytesConf, which yields a Long.
def getReducerFileGroupBroadcastMiniSize: Long =
  get(CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE)
|
||||
def shufflePartitionType: PartitionType = PartitionType.valueOf(get(SHUFFLE_PARTITION_TYPE))
|
||||
def shuffleRangeReadFilterEnabled: Boolean = get(SHUFFLE_RANGE_READ_FILTER_ENABLED)
|
||||
def shuffleForceFallbackEnabled: Boolean = get(SPARK_SHUFFLE_FORCE_FALLBACK_ENABLED)
|
||||
@ -5213,6 +5217,27 @@ object CelebornConf extends Logging {
|
||||
.intConf
|
||||
.createWithDefault(2000)
|
||||
|
||||
// Opt-in switch for sending GetReducerFileGroupResponse via Spark's broadcast
// mechanism instead of one driver RPC reply per executor. Explicit ConfigEntry
// type added for consistency with entries like SPARK_SHUFFLE_WRITER_MODE.
val CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED: ConfigEntry[Boolean] =
  buildConf("celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled")
    .categories("client")
    .doc(
      "Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. " +
        "If the response size is large and Spark executor number is large, the Spark driver network " +
        "may be exhausted because each executor will pull the response from the driver. With broadcasting " +
        "GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple " +
        "copies of the GetReducerFileGroupResponse (one per executor).")
    .version("0.6.0")
    .booleanConf
    .createWithDefault(false)
|
||||
|
||||
// Size threshold at which the GetReducerFileGroupResponse switches to the
// broadcast delivery path. Explicit ConfigEntry type added for consistency
// with entries like SPARK_SHUFFLE_WRITER_MODE.
val CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE: ConfigEntry[Long] =
  buildConf("celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize")
    .categories("client")
    .doc("The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors.")
    .version("0.6.0")
    .bytesConf(ByteUnit.BYTE)
    .createWithDefaultString("512k")
|
||||
|
||||
val SPARK_SHUFFLE_WRITER_MODE: ConfigEntry[String] =
|
||||
buildConf("celeborn.client.spark.shuffle.writer")
|
||||
.withAlternative("celeborn.shuffle.writer")
|
||||
|
||||
@ -22,6 +22,7 @@ import java.util.{Collections, UUID}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import com.google.protobuf.ByteString
|
||||
import org.roaringbitmap.RoaringBitmap
|
||||
|
||||
import org.apache.celeborn.common.identity.UserIdentifier
|
||||
@ -285,10 +286,11 @@ object ControlMessages extends Logging {
|
||||
// Path can't be serialized
|
||||
case class GetReducerFileGroupResponse(
|
||||
status: StatusCode,
|
||||
fileGroup: util.Map[Integer, util.Set[PartitionLocation]],
|
||||
attempts: Array[Int],
|
||||
fileGroup: util.Map[Integer, util.Set[PartitionLocation]] = Collections.emptyMap(),
|
||||
attempts: Array[Int] = Array.emptyIntArray,
|
||||
partitionIds: util.Set[Integer] = Collections.emptySet[Integer](),
|
||||
pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap())
|
||||
pushFailedBatches: util.Map[String, util.Set[PushFailedBatch]] = Collections.emptyMap(),
|
||||
broadcast: Array[Byte] = Array.emptyByteArray)
|
||||
extends MasterMessage
|
||||
|
||||
object WorkerExclude {
|
||||
@ -752,7 +754,13 @@ object ControlMessages extends Logging {
|
||||
.build().toByteArray
|
||||
new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP, payload)
|
||||
|
||||
case GetReducerFileGroupResponse(status, fileGroup, attempts, partitionIds, failedBatches) =>
|
||||
case GetReducerFileGroupResponse(
|
||||
status,
|
||||
fileGroup,
|
||||
attempts,
|
||||
partitionIds,
|
||||
failedBatches,
|
||||
broadcast) =>
|
||||
val builder = PbGetReducerFileGroupResponse
|
||||
.newBuilder()
|
||||
.setStatus(status.getValue)
|
||||
@ -770,6 +778,7 @@ object ControlMessages extends Logging {
|
||||
case (uniqueId, pushFailedBatchSet) =>
|
||||
(uniqueId, PbSerDeUtils.toPbPushFailedBatchSet(pushFailedBatchSet))
|
||||
}.asJava)
|
||||
builder.setBroadcast(ByteString.copyFrom(broadcast))
|
||||
val payload = builder.build().toByteArray
|
||||
new TransportMessage(MessageType.GET_REDUCER_FILE_GROUP_RESPONSE, payload)
|
||||
|
||||
@ -1198,12 +1207,14 @@ object ControlMessages extends Logging {
|
||||
case (uniqueId, pushFailedBatchSet) =>
|
||||
(uniqueId, PbSerDeUtils.fromPbPushFailedBatchSet(pushFailedBatchSet))
|
||||
}.toMap.asJava
|
||||
val broadcast = pbGetReducerFileGroupResponse.getBroadcast.toByteArray
|
||||
GetReducerFileGroupResponse(
|
||||
Utils.toStatusCode(pbGetReducerFileGroupResponse.getStatus),
|
||||
fileGroup,
|
||||
attempts,
|
||||
partitionIds,
|
||||
pushFailedBatches)
|
||||
pushFailedBatches,
|
||||
broadcast)
|
||||
|
||||
case GET_SHUFFLE_ID_VALUE =>
|
||||
message.getParsedPayload()
|
||||
|
||||
@ -0,0 +1,70 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.common.util
|
||||
|
||||
import java.util.concurrent.{Callable, ConcurrentHashMap}
|
||||
|
||||
/**
|
||||
* This class is copied from Apache Spark.
|
||||
* A special locking mechanism to provide locking with a given key. By providing the same key
|
||||
* (identity is tested using the `equals` method), we ensure there is only one `func` running at
|
||||
* the same time.
|
||||
*
|
||||
* @tparam K the type of key to identify a lock. This type must implement `equals` and `hashCode`
|
||||
* correctly as it will be the key type of an internal Map.
|
||||
*/
|
||||
/**
 * This class is copied from Apache Spark.
 * A special locking mechanism to provide locking with a given key. By providing the same key
 * (identity is tested using the `equals` method), we ensure there is only one `func` running at
 * the same time.
 *
 * @tparam K the type of key to identify a lock. This type must implement `equals` and `hashCode`
 *           correctly as it will be the key type of an internal Map.
 */
class KeyLock[K] {

  // Presence of a key in this map means "locked"; the mapped object is the
  // monitor that waiting threads park on.
  private val lockMap = new ConcurrentHashMap[K, AnyRef]()

  // Spin until this thread wins the putIfAbsent race for `key`. A loser waits
  // on the current holder's monitor until that entry is removed, then retries.
  private def acquireLock(key: K): Unit = {
    var acquired = false
    while (!acquired) {
      val holder = lockMap.putIfAbsent(key, new Object)
      if (holder == null) {
        acquired = true
      } else {
        holder.synchronized {
          while (lockMap.get(key) eq holder) {
            holder.wait()
          }
        }
      }
    }
  }

  // Remove the entry for `key` and wake every thread parked on its monitor.
  private def releaseLock(key: K): Unit = {
    val removed = lockMap.remove(key)
    removed.synchronized {
      removed.notifyAll()
    }
  }

  /**
   * Run `func` under a lock identified by the given key. Multiple calls with the same key
   * (identity is tested using the `equals` method) will be locked properly to ensure there is only
   * one `func` running at the same time.
   */
  def withLock[T](key: K)(func: Callable[T]): T = {
    if (key == null) {
      throw new NullPointerException("key must not be null")
    }
    acquireLock(key)
    try {
      func.call()
    } finally {
      releaseLock(key)
    }
  }
}
|
||||
@ -122,6 +122,8 @@ license: |
|
||||
| celeborn.client.spark.shuffle.fallback.numPartitionsThreshold | 2147483647 | false | Celeborn will only accept shuffle of partition number lower than this configuration value. This configuration only takes effect when `celeborn.client.spark.shuffle.fallback.policy` is `AUTO`. | 0.5.0 | celeborn.shuffle.forceFallback.numPartitionsThreshold,celeborn.client.spark.shuffle.forceFallback.numPartitionsThreshold |
|
||||
| celeborn.client.spark.shuffle.fallback.policy | AUTO | false | Celeborn supports the following kind of fallback policies. 1. ALWAYS: always use spark built-in shuffle implementation; 2. AUTO: prefer to use celeborn shuffle implementation, and fallback to use spark built-in shuffle implementation based on certain factors, e.g. availability of enough workers and quota, shuffle partition number; 3. NEVER: always use celeborn shuffle implementation, and fail fast when it is concluded that fallback is required based on factors above. | 0.5.0 | |
|
||||
| celeborn.client.spark.shuffle.forceFallback.enabled | false | false | Always use spark built-in shuffle implementation. This configuration is deprecated, consider configuring `celeborn.client.spark.shuffle.fallback.policy` instead. | 0.3.0 | celeborn.shuffle.forceFallback.enabled |
|
||||
| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.enabled | false | false | Whether to leverage Spark broadcast mechanism to send the GetReducerFileGroupResponse. If the response size is large and Spark executor number is large, the Spark driver network may be exhausted because each executor will pull the response from the driver. With broadcasting GetReducerFileGroupResponse, it prevents the driver from being the bottleneck in sending out multiple copies of the GetReducerFileGroupResponse (one per executor). | 0.6.0 | |
|
||||
| celeborn.client.spark.shuffle.getReducerFileGroup.broadcast.miniSize | 512k | false | The size at which we use Broadcast to send the GetReducerFileGroupResponse to the executors. | 0.6.0 | |
|
||||
| celeborn.client.spark.shuffle.writer | HASH | false | Celeborn supports the following kind of shuffle writers. 1. hash: hash-based shuffle writer works fine when shuffle partition count is normal; 2. sort: sort-based shuffle writer works fine when memory pressure is high or shuffle partition count is huge. This configuration only takes effect when celeborn.client.spark.push.dynamicWriteMode.enabled is false. | 0.3.0 | celeborn.shuffle.writer |
|
||||
| celeborn.client.spark.stageRerun.enabled | true | false | Whether to enable stage rerun. If true, client throws FetchFailedException instead of CelebornIOException. | 0.4.0 | celeborn.client.spark.fetch.throwsFetchFailure |
|
||||
| celeborn.identity.provider | org.apache.celeborn.common.identity.DefaultIdentityProvider | false | IdentityProvider class name. Default class is `org.apache.celeborn.common.identity.DefaultIdentityProvider`. Optional values: org.apache.celeborn.common.identity.HadoopBasedIdentityProvider user name will be obtained by UserGroupInformation.getUserName; org.apache.celeborn.common.identity.DefaultIdentityProvider user name and tenant id are default values or user-specific values. | 0.6.0 | celeborn.quota.identity.provider |
|
||||
|
||||
@ -18,11 +18,13 @@
|
||||
package org.apache.celeborn.tests.spark
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.shuffle.celeborn.SparkUtils
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.protocol.ShuffleMode
|
||||
|
||||
class CelebornHashSuite extends AnyFunSuite
|
||||
@ -64,4 +66,43 @@ class CelebornHashSuite extends AnyFunSuite
|
||||
|
||||
celebornSparkSession.stop()
|
||||
}
|
||||
|
||||
test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") {
  // Reset broadcast bookkeeping so assertions below only observe this test's run.
  SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
  SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
  // Enable broadcasting with a 0-byte threshold so every response takes the broadcast path.
  val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
    .set(
      s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED.key}",
      "true")
    .set(
      s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE.key}",
      "0")
  // First run the workloads on vanilla Spark shuffle to produce expected results.
  val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
  val combineResult = combine(sparkSession)
  val groupbyResult = groupBy(sparkSession)
  val repartitionResult = repartition(sparkSession)
  val sqlResult = runsql(sparkSession)

  Thread.sleep(3000L)
  sparkSession.stop()

  // Re-run the same workloads on Celeborn (HASH shuffle mode) and compare.
  val celebornSparkSession = SparkSession.builder()
    .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
    .getOrCreate()
  val celebornCombineResult = combine(celebornSparkSession)
  val celebornGroupbyResult = groupBy(celebornSparkSession)
  val celebornRepartitionResult = repartition(celebornSparkSession)
  val celebornSqlResult = runsql(celebornSparkSession)

  // Removed a duplicated combineResult assertion (copy-paste artifact).
  assert(combineResult.equals(celebornCombineResult))
  assert(groupbyResult.equals(celebornGroupbyResult))
  assert(repartitionResult.equals(celebornRepartitionResult))
  assert(sqlResult.equals(celebornSqlResult))
  // At least one GetReducerFileGroupResponse must have gone through the broadcast path.
  assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)

  celebornSparkSession.stop()
  // Leave shared state clean for subsequent tests.
  SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
  SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
}
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.celeborn.tests.spark
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.shuffle.celeborn.SparkUtils
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
@ -66,4 +67,45 @@ class CelebornSortSuite extends AnyFunSuite
|
||||
|
||||
celebornSparkSession.stop()
|
||||
}
|
||||
|
||||
test("celeborn spark integration test - GetReducerFileGroupResponse broadcast") {
  // Reset broadcast bookkeeping so assertions below only observe this test's run.
  SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
  SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
  // Disable partition randomization for deterministic results; enable broadcasting
  // with a 0-byte threshold so every response takes the broadcast path.
  val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
    .set(s"spark.${CelebornConf.CLIENT_PUSH_SORT_RANDOMIZE_PARTITION_ENABLED.key}", "false")
    .set(
      s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_ENABLED.key}",
      "true")
    .set(
      s"spark.${CelebornConf.CLIENT_SHUFFLE_GET_REDUCER_FILE_GROUP_BROADCAST_MINI_SIZE.key}",
      "0")

  // First run the workloads on vanilla Spark shuffle to produce expected results.
  val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
  val combineResult = combine(sparkSession)
  val groupbyResult = groupBy(sparkSession)
  val repartitionResult = repartition(sparkSession)
  val sqlResult = runsql(sparkSession)

  Thread.sleep(3000L)
  sparkSession.stop()

  // Re-run the same workloads on Celeborn (SORT shuffle mode) and compare.
  val celebornSparkSession = SparkSession.builder()
    .config(updateSparkConf(sparkConf, ShuffleMode.SORT))
    .getOrCreate()
  val celebornCombineResult = combine(celebornSparkSession)
  val celebornGroupbyResult = groupBy(celebornSparkSession)
  val celebornRepartitionResult = repartition(celebornSparkSession)
  val celebornSqlResult = runsql(celebornSparkSession)

  // Removed a duplicated combineResult assertion (copy-paste artifact).
  assert(combineResult.equals(celebornCombineResult))
  assert(groupbyResult.equals(celebornGroupbyResult))
  assert(repartitionResult.equals(celebornRepartitionResult))
  assert(sqlResult.equals(celebornSqlResult))
  // At least one GetReducerFileGroupResponse must have gone through the broadcast path.
  assert(SparkUtils.getReducerFileGroupResponseBroadcastNum.get() > 0)

  celebornSparkSession.stop()
  // Leave shared state clean for subsequent tests.
  SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
  SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
}
|
||||
}
|
||||
|
||||
@ -29,7 +29,9 @@ import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient
|
||||
import org.apache.celeborn.common.protocol.ShuffleMode
|
||||
import org.apache.celeborn.common.protocol.{PartitionLocation, ShuffleMode}
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
|
||||
import org.apache.celeborn.common.protocol.message.StatusCode
|
||||
import org.apache.celeborn.tests.spark.SparkTestBase
|
||||
|
||||
class SparkUtilsSuite extends AnyFunSuite
|
||||
@ -157,4 +159,61 @@ class SparkUtilsSuite extends AnyFunSuite
|
||||
sparkSession.stop()
|
||||
}
|
||||
}
|
||||
|
||||
test("serialize/deserialize GetReducerFileGroupResponse with broadcast") {
  val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
  val sparkSession = SparkSession.builder()
    .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
    .config("spark.sql.shuffle.partitions", 2)
    .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
    .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
    .config(
      "spark.shuffle.manager",
      "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
    .getOrCreate()

  try {
    val shuffleId = Integer.MAX_VALUE
    // A minimal response carrying one replica partition location for one reducer.
    val originalResponse = GetReducerFileGroupResponse(
      StatusCode.SUCCESS,
      Map(Integer.valueOf(shuffleId) -> Set(new PartitionLocation(
        0,
        1,
        "localhost",
        1,
        2,
        3,
        4,
        PartitionLocation.Mode.REPLICA)).asJava).asJava,
      Array(1),
      Set(Integer.valueOf(shuffleId)).asJava)

    // Serializing should register a valid broadcast and yield a non-empty payload.
    val encoded =
      SparkUtils.serializeGetReducerFileGroupResponse(shuffleId, originalResponse)
    assert(encoded != null && encoded.length > 0)
    val registeredBroadcast =
      SparkUtils.getReducerFileGroupResponseBroadcasts.values().asScala.head._1
    assert(registeredBroadcast.isValid)

    // Round-trip the payload and verify each field survives intact.
    val roundTripped =
      SparkUtils.deserializeGetReducerFileGroupResponse(shuffleId, encoded)
    assert(roundTripped.status == originalResponse.status)
    assert(roundTripped.fileGroup == originalResponse.fileGroup)
    assert(java.util.Arrays.equals(roundTripped.attempts, originalResponse.attempts))
    assert(roundTripped.partitionIds == originalResponse.partitionIds)
    assert(roundTripped.pushFailedBatches == originalResponse.pushFailedBatches)

    // Invalidation must drop the cached entry and destroy the broadcast.
    assert(!SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
    SparkUtils.invalidateSerializedGetReducerFileGroupResponse(shuffleId)
    assert(SparkUtils.getReducerFileGroupResponseBroadcasts.isEmpty)
    assert(!registeredBroadcast.isValid)
  } finally {
    sparkSession.stop()
    SparkUtils.getReducerFileGroupResponseBroadcasts.clear()
    SparkUtils.getReducerFileGroupResponseBroadcastNum.set(0)
  }
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user