diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala index 16d781069..f885de1ad 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala @@ -17,16 +17,13 @@ package org.apache.celeborn.client.commit -import java.nio.ByteBuffer import java.util -import java.util.concurrent.{Callable, ConcurrentHashMap, TimeUnit} +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.{AtomicLong, LongAdder} import scala.collection.JavaConverters._ import scala.collection.mutable -import com.google.common.cache.{Cache, CacheBuilder} - import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers, ShuffleFileGroups} import org.apache.celeborn.client.ShuffleCommittedInfo @@ -34,10 +31,9 @@ import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo} import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType} -import org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles, CommitFilesResponse, GetReducerFileGroupResponse} +import org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles, CommitFilesResponse} import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.rpc.{RpcCallContext, RpcEndpointRef} -import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext} import org.apache.celeborn.common.util.{CollectionUtils, ThreadUtils, Utils} // Can Remove this if celeborn don't support scala211 in future import org.apache.celeborn.common.util.FunctionConverter._ @@ -55,20 +51,11 @@ abstract class CommitHandler( private val pushReplicateEnabled = conf.pushReplicateEnabled private val testRetryCommitFiles = conf.testRetryCommitFiles - private val rpcCacheSize = conf.rpcCacheSize - private val rpcCacheConcurrencyLevel = conf.rpcCacheConcurrencyLevel - private val rpcCacheExpireTime = conf.rpcCacheExpireTime private val commitEpoch = new AtomicLong() private val totalWritten = new LongAdder private val fileCount = new LongAdder - private val reducerFileGroupsMap = new ShuffleFileGroups - // noinspection UnstableApiUsage - private val getReducerFileGroupRpcCache: Cache[Int, ByteBuffer] = CacheBuilder.newBuilder() - .concurrencyLevel(rpcCacheConcurrencyLevel) - .expireAfterWrite(rpcCacheExpireTime, TimeUnit.MILLISECONDS) - .maximumSize(rpcCacheSize) - .build().asInstanceOf[Cache[Int, ByteBuffer]] + protected val reducerFileGroupsMap = new ShuffleFileGroups def getPartitionType(): PartitionType @@ -151,36 +138,12 @@ abstract class CommitHandler( shuffleId: Int, recordWorkerFailure: ShuffleFailedWorkers => Unit): Boolean - def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit = { - if (isStageDataLost(shuffleId)) { - context.reply( - GetReducerFileGroupResponse( - StatusCode.SHUFFLE_DATA_LOST, - new ConcurrentHashMap(), - Array.empty)) - } else { - if (context.isInstanceOf[LocalNettyRpcCallContext]) { - // This branch is for the UTs - context.reply(GetReducerFileGroupResponse( - StatusCode.SUCCESS, - reducerFileGroupsMap.getOrDefault(shuffleId, new ConcurrentHashMap()), - getMapperAttempts(shuffleId))) - } else { - val cachedMsg = getReducerFileGroupRpcCache.get( - shuffleId, - new Callable[ByteBuffer]() { - override def call(): ByteBuffer = { - val returnedMsg = GetReducerFileGroupResponse( - StatusCode.SUCCESS, - reducerFileGroupsMap.getOrDefault(shuffleId, new ConcurrentHashMap()), - getMapperAttempts(shuffleId)) - context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg) - } - }) - context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg) - } - } - } + /** + * Only Reduce partition mode supports cache all file groups for reducer. Map partition doesn't guarantee that all + * partitions are complete by the time the method is called, as downstream tasks may start early before all tasks + * are completed.So map partition may need refresh reducer file group if needed. + */ + def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit def removeExpiredShuffle(shuffleId: Int): Unit = { reducerFileGroupsMap.remove(shuffleId) diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala index 34e6608a8..1a4f1e128 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala @@ -32,6 +32,9 @@ import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo} import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType} +import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse +import org.apache.celeborn.common.protocol.message.StatusCode +import org.apache.celeborn.common.rpc.RpcCallContext // Can Remove this if celeborn don't support scala211 in future import org.apache.celeborn.common.util.FunctionConverter._ import org.apache.celeborn.common.util.Utils @@ -211,4 +214,11 @@ class MapPartitionCommitHandler( inProcessingPartitionIds.remove(partitionId) (dataCommitSuccess, false) } + + override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit = { + context.reply(GetReducerFileGroupResponse( + StatusCode.SUCCESS, + reducerFileGroupsMap.getOrDefault(shuffleId, new ConcurrentHashMap()), + getMapperAttempts(shuffleId))) + } } diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala index 9d6423d7e..12475a4cd 100644 --- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala +++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala @@ -17,12 +17,15 @@ package org.apache.celeborn.client.commit +import java.nio.ByteBuffer import java.util -import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.{Callable, ConcurrentHashMap, TimeUnit} import scala.collection.JavaConverters._ import scala.collection.mutable +import com.google.common.cache.{Cache, CacheBuilder} + import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers} import org.apache.celeborn.client.ShuffleCommittedInfo @@ -33,6 +36,7 @@ import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType} import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.rpc.RpcCallContext +import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext} /** * This commit handler is for ReducePartition ShuffleType, which means that a Reduce Partition contains all data @@ -55,6 +59,17 @@ class ReducePartitionCommitHandler( private val shuffleMapperAttempts = new ConcurrentHashMap[Int, Array[Int]]() private val stageEndTimeout = conf.pushStageEndTimeout + private val rpcCacheSize = conf.rpcCacheSize + private val rpcCacheConcurrencyLevel = conf.rpcCacheConcurrencyLevel + private val rpcCacheExpireTime = conf.rpcCacheExpireTime + + // noinspection UnstableApiUsage + private val getReducerFileGroupRpcCache: Cache[Int, ByteBuffer] = CacheBuilder.newBuilder() + .concurrencyLevel(rpcCacheConcurrencyLevel) + .expireAfterWrite(rpcCacheExpireTime, TimeUnit.MILLISECONDS) + .maximumSize(rpcCacheSize) + .build().asInstanceOf[Cache[Int, ByteBuffer]] + override def getPartitionType(): PartitionType = { PartitionType.REDUCE } @@ -244,7 +259,34 @@ class ReducePartitionCommitHandler( } else { logDebug("[handleGetReducerFileGroup] Wait for handleStageEnd complete cost" + s" ${cost}ms") - super.handleGetReducerFileGroup(context, shuffleId) + if (isStageDataLost(shuffleId)) { + context.reply( + GetReducerFileGroupResponse( + StatusCode.SHUFFLE_DATA_LOST, + new ConcurrentHashMap(), + Array.empty)) + } else { + // LocalNettyRpcCallContext is for the UTs + if (context.isInstanceOf[LocalNettyRpcCallContext]) { + context.reply(GetReducerFileGroupResponse( + StatusCode.SUCCESS, + reducerFileGroupsMap.getOrDefault(shuffleId, new ConcurrentHashMap()), + getMapperAttempts(shuffleId))) + } else { + val cachedMsg = getReducerFileGroupRpcCache.get( + shuffleId, + new Callable[ByteBuffer]() { + override def call(): ByteBuffer = { + val returnedMsg = GetReducerFileGroupResponse( + StatusCode.SUCCESS, + reducerFileGroupsMap.getOrDefault(shuffleId, new ConcurrentHashMap()), + getMapperAttempts(shuffleId)) + context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg) + } + }) + context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg) + } + } } }