[CELEBORN-297] don't cache file groups for map partition shuffle type (#1237)

This commit is contained in:
Shuang 2023-02-17 11:28:47 +08:00 committed by GitHub
parent 1dcfdb0c8f
commit b7ef9cf216
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 63 additions and 48 deletions

View File

@ -17,16 +17,13 @@
package org.apache.celeborn.client.commit
import java.nio.ByteBuffer
import java.util
import java.util.concurrent.{Callable, ConcurrentHashMap, TimeUnit}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicLong, LongAdder}
import scala.collection.JavaConverters._
import scala.collection.mutable
import com.google.common.cache.{Cache, CacheBuilder}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers, ShuffleFileGroups}
import org.apache.celeborn.client.ShuffleCommittedInfo
@ -34,10 +31,9 @@ import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
import org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles, CommitFilesResponse, GetReducerFileGroupResponse}
import org.apache.celeborn.common.protocol.message.ControlMessages.{CommitFiles, CommitFilesResponse}
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc.{RpcCallContext, RpcEndpointRef}
import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext}
import org.apache.celeborn.common.util.{CollectionUtils, ThreadUtils, Utils}
// Can Remove this if celeborn don't support scala211 in future
import org.apache.celeborn.common.util.FunctionConverter._
@ -55,20 +51,11 @@ abstract class CommitHandler(
private val pushReplicateEnabled = conf.pushReplicateEnabled
private val testRetryCommitFiles = conf.testRetryCommitFiles
private val rpcCacheSize = conf.rpcCacheSize
private val rpcCacheConcurrencyLevel = conf.rpcCacheConcurrencyLevel
private val rpcCacheExpireTime = conf.rpcCacheExpireTime
private val commitEpoch = new AtomicLong()
private val totalWritten = new LongAdder
private val fileCount = new LongAdder
private val reducerFileGroupsMap = new ShuffleFileGroups
// noinspection UnstableApiUsage
private val getReducerFileGroupRpcCache: Cache[Int, ByteBuffer] = CacheBuilder.newBuilder()
.concurrencyLevel(rpcCacheConcurrencyLevel)
.expireAfterWrite(rpcCacheExpireTime, TimeUnit.MILLISECONDS)
.maximumSize(rpcCacheSize)
.build().asInstanceOf[Cache[Int, ByteBuffer]]
protected val reducerFileGroupsMap = new ShuffleFileGroups
def getPartitionType(): PartitionType
@ -151,36 +138,12 @@ abstract class CommitHandler(
shuffleId: Int,
recordWorkerFailure: ShuffleFailedWorkers => Unit): Boolean
def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit = {
if (isStageDataLost(shuffleId)) {
context.reply(
GetReducerFileGroupResponse(
StatusCode.SHUFFLE_DATA_LOST,
new ConcurrentHashMap(),
Array.empty))
} else {
if (context.isInstanceOf[LocalNettyRpcCallContext]) {
// This branch is for the UTs
context.reply(GetReducerFileGroupResponse(
StatusCode.SUCCESS,
reducerFileGroupsMap.getOrDefault(shuffleId, new ConcurrentHashMap()),
getMapperAttempts(shuffleId)))
} else {
val cachedMsg = getReducerFileGroupRpcCache.get(
shuffleId,
new Callable[ByteBuffer]() {
override def call(): ByteBuffer = {
val returnedMsg = GetReducerFileGroupResponse(
StatusCode.SUCCESS,
reducerFileGroupsMap.getOrDefault(shuffleId, new ConcurrentHashMap()),
getMapperAttempts(shuffleId))
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
}
})
context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
}
}
}
/**
* Only Reduce partition mode supports cache all file groups for reducer. Map partition doesn't guarantee that all
* partitions are complete by the time the method is called, as downstream tasks may start early before all tasks
* are completed.So map partition may need refresh reducer file group if needed.
*/
def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit
def removeExpiredShuffle(shuffleId: Int): Unit = {
reducerFileGroupsMap.remove(shuffleId)

View File

@ -32,6 +32,9 @@ import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{ShufflePartitionLocationInfo, WorkerInfo}
import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc.RpcCallContext
// Can Remove this if celeborn don't support scala211 in future
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.Utils
@ -211,4 +214,11 @@ class MapPartitionCommitHandler(
inProcessingPartitionIds.remove(partitionId)
(dataCommitSuccess, false)
}
override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit = {
context.reply(GetReducerFileGroupResponse(
StatusCode.SUCCESS,
reducerFileGroupsMap.getOrDefault(shuffleId, new ConcurrentHashMap()),
getMapperAttempts(shuffleId)))
}
}

View File

@ -17,12 +17,15 @@
package org.apache.celeborn.client.commit
import java.nio.ByteBuffer
import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.{Callable, ConcurrentHashMap, TimeUnit}
import scala.collection.JavaConverters._
import scala.collection.mutable
import com.google.common.cache.{Cache, CacheBuilder}
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, ShuffleFailedWorkers}
import org.apache.celeborn.client.ShuffleCommittedInfo
@ -33,6 +36,7 @@ import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
import org.apache.celeborn.common.protocol.message.ControlMessages.GetReducerFileGroupResponse
import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.rpc.RpcCallContext
import org.apache.celeborn.common.rpc.netty.{LocalNettyRpcCallContext, RemoteNettyRpcCallContext}
/**
* This commit handler is for ReducePartition ShuffleType, which means that a Reduce Partition contains all data
@ -55,6 +59,17 @@ class ReducePartitionCommitHandler(
private val shuffleMapperAttempts = new ConcurrentHashMap[Int, Array[Int]]()
private val stageEndTimeout = conf.pushStageEndTimeout
private val rpcCacheSize = conf.rpcCacheSize
private val rpcCacheConcurrencyLevel = conf.rpcCacheConcurrencyLevel
private val rpcCacheExpireTime = conf.rpcCacheExpireTime
// noinspection UnstableApiUsage
private val getReducerFileGroupRpcCache: Cache[Int, ByteBuffer] = CacheBuilder.newBuilder()
.concurrencyLevel(rpcCacheConcurrencyLevel)
.expireAfterWrite(rpcCacheExpireTime, TimeUnit.MILLISECONDS)
.maximumSize(rpcCacheSize)
.build().asInstanceOf[Cache[Int, ByteBuffer]]
override def getPartitionType(): PartitionType = {
PartitionType.REDUCE
}
@ -244,7 +259,34 @@ class ReducePartitionCommitHandler(
} else {
logDebug("[handleGetReducerFileGroup] Wait for handleStageEnd complete cost" +
s" ${cost}ms")
super.handleGetReducerFileGroup(context, shuffleId)
if (isStageDataLost(shuffleId)) {
context.reply(
GetReducerFileGroupResponse(
StatusCode.SHUFFLE_DATA_LOST,
new ConcurrentHashMap(),
Array.empty))
} else {
// LocalNettyRpcCallContext is for the UTs
if (context.isInstanceOf[LocalNettyRpcCallContext]) {
context.reply(GetReducerFileGroupResponse(
StatusCode.SUCCESS,
reducerFileGroupsMap.getOrDefault(shuffleId, new ConcurrentHashMap()),
getMapperAttempts(shuffleId)))
} else {
val cachedMsg = getReducerFileGroupRpcCache.get(
shuffleId,
new Callable[ByteBuffer]() {
override def call(): ByteBuffer = {
val returnedMsg = GetReducerFileGroupResponse(
StatusCode.SUCCESS,
reducerFileGroupsMap.getOrDefault(shuffleId, new ConcurrentHashMap()),
getMapperAttempts(shuffleId))
context.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(returnedMsg)
}
})
context.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
}
}
}
}