[ISSUE-739][REFACTOR] Use object wrap pb message method (#740)

This commit is contained in:
AngersZhuuuu 2022-10-09 11:53:48 +08:00 committed by GitHub
parent e221dbc117
commit f2a234f870
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 173 additions and 159 deletions

View File

@ -51,7 +51,6 @@ import org.apache.celeborn.common.network.protocol.PushMergedData;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.*;
import org.apache.celeborn.common.protocol.message.ControlMessages;
import org.apache.celeborn.common.protocol.message.ControlMessages.*;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.common.rpc.RpcAddress;
@ -262,7 +261,7 @@ public class ShuffleClientImpl extends ShuffleClient {
try {
PbRegisterShuffleResponse response =
driverRssMetaService.askSync(
ControlMessages.pbRegisterShuffle(appId, shuffleId, numMappers, numPartitions),
RegisterShuffle$.MODULE$.apply(appId, shuffleId, numMappers, numPartitions),
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class));
StatusCode respStatus = Utils.toStatusCode(response.getStatus());
if (StatusCode.SUCCESS.equals(respStatus)) {
@ -400,7 +399,7 @@ public class ShuffleClientImpl extends ShuffleClient {
try {
PbChangeLocationResponse response =
driverRssMetaService.askSync(
ControlMessages.pbRevive(
Revive$.MODULE$.apply(
applicationId,
shuffleId,
mapId,
@ -718,8 +717,7 @@ public class ShuffleClientImpl extends ShuffleClient {
ShuffleClientHelper.sendShuffleSplitAsync(
driverRssMetaService,
ControlMessages.pbPartitionSplit(
applicationId, shuffleId, partitionId, loc.getEpoch(), loc),
PartitionSplit$.MODULE$.apply(applicationId, shuffleId, partitionId, loc.getEpoch(), loc),
partitionSplitPool,
splittingSet,
partitionId,
@ -982,7 +980,7 @@ public class ShuffleClientImpl extends ShuffleClient {
if (isDriver) {
try {
driverRssMetaService.send(
ControlMessages.pbUnregisterShuffle(
UnregisterShuffle$.MODULE$.apply(
applicationId, shuffleId, RssHARetryClient.genRequestId()));
} catch (Exception e) {
// If some exceptions need to be ignored, they shouldn't be logged as error-level,

View File

@ -312,7 +312,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
.flatMap(_.getAllMasterLocationsWithMinEpoch(shuffleId.toString).asScala)
.filter(_.getEpoch == 0)
.toArray
context.reply(ControlMessages.pbRegisterShuffleResponse(StatusCode.SUCCESS, initialLocs))
context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, initialLocs))
return
}
logInfo(s"New shuffle request, shuffleId $shuffleId, partitionType: $partitionType " +
@ -345,11 +345,11 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
res.status match {
case StatusCode.FAILED =>
logError(s"OfferSlots RPC request failed for $shuffleId!")
reply(ControlMessages.pbRegisterShuffleResponse(StatusCode.FAILED, Array.empty))
reply(RegisterShuffleResponse(StatusCode.FAILED, Array.empty))
return
case StatusCode.SLOT_NOT_AVAILABLE =>
logError(s"OfferSlots for $shuffleId failed!")
reply(ControlMessages.pbRegisterShuffleResponse(StatusCode.SLOT_NOT_AVAILABLE, Array.empty))
reply(RegisterShuffleResponse(StatusCode.SLOT_NOT_AVAILABLE, Array.empty))
return
case StatusCode.SUCCESS =>
logInfo(s"OfferSlots for ${Utils.makeShuffleKey(applicationId, shuffleId)} Success!")
@ -388,7 +388,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
// If reserve slots failed, clear allocated resources, reply ReserveSlotFailed and return.
if (!reserveSlotsSuccess) {
logError(s"reserve buffer for $shuffleId failed, reply to all.")
reply(ControlMessages.pbRegisterShuffleResponse(StatusCode.RESERVE_SLOTS_FAILED, Array.empty))
reply(RegisterShuffleResponse(StatusCode.RESERVE_SLOTS_FAILED, Array.empty))
// tell Master to release slots
requestReleaseSlots(
rssHARetryClient,
@ -423,9 +423,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
// Fifth, reply the allocated partition location to ShuffleClient.
logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
val allMasterPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
reply(ControlMessages.pbRegisterShuffleResponse(
StatusCode.SUCCESS,
allMasterPartitionLocations))
reply(RegisterShuffleResponse(StatusCode.SUCCESS, allMasterPartitionLocations))
}
}
@ -461,7 +459,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
// If shuffle not registered, reply ShuffleNotRegistered and return
if (!registeredShuffle.contains(shuffleId)) {
logError(s"[handleRevive] shuffle $shuffleId not registered!")
context.reply(pbChangeLocationResponse(StatusCode.SHUFFLE_NOT_REGISTERED, None))
context.reply(ChangeLocationResponse(StatusCode.SHUFFLE_NOT_REGISTERED, None))
return
}
@ -470,7 +468,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
&& shuffleMapperAttempts.get(shuffleId)(mapId) != -1) {
logWarning(s"[handleRevive] Mapper ended, mapId $mapId, current attemptId $attemptId, " +
s"ended attemptId ${shuffleMapperAttempts.get(shuffleId)(mapId)}, shuffleId $shuffleId.")
context.reply(pbChangeLocationResponse(StatusCode.MAP_ENDED, None))
context.reply(ChangeLocationResponse(StatusCode.MAP_ENDED, None))
return
}
@ -514,7 +512,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
// If new slot for the partition has been allocated, reply and return.
// Else register and allocate for it.
getLatestPartition(shuffleId, partitionId, oldEpoch).foreach { latestLoc =>
context.reply(pbChangeLocationResponse(StatusCode.SUCCESS, Some(latestLoc)))
context.reply(ChangeLocationResponse(StatusCode.SUCCESS, Some(latestLoc)))
logDebug(s"New partition found, old partition $partitionId-$oldEpoch return it." +
s" shuffleId: $shuffleId $latestLoc")
return
@ -538,7 +536,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
val candidates = workersNotBlacklisted(shuffleId)
if (candidates.size < 1 || (ShouldReplicate && candidates.size < 2)) {
logError("[Update partition] failed for not enough candidates for revive.")
reply(pbChangeLocationResponse(StatusCode.SLOT_NOT_AVAILABLE, None))
reply(ChangeLocationResponse(StatusCode.SLOT_NOT_AVAILABLE, None))
return null
}
@ -551,7 +549,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
if (!reserveSlotsWithRetry(applicationId, shuffleId, candidates, newlyAllocatedLocation)) {
logError(s"[Update partition] failed for $shuffleId.")
reply(pbChangeLocationResponse(StatusCode.RESERVE_SLOTS_FAILED, None))
reply(ChangeLocationResponse(StatusCode.RESERVE_SLOTS_FAILED, None))
return
}
@ -572,7 +570,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
slavePartitions.asScala.headOption.map(_.getPeer)
}
reply(pbChangeLocationResponse(StatusCode.SUCCESS, newMasterLocation))
reply(ChangeLocationResponse(StatusCode.SUCCESS, newMasterLocation))
// Both halves of the message need the `s` interpolator; without it the epoch
// placeholders are logged verbatim. `newMasterLocation` is an Option, so the new
// epoch is read through `map` rather than by calling `getEpoch` on the Option.
// NOTE(review): assumes `oldEpoch` is in scope here, as the original text implies — confirm.
logDebug(s"Renew $shuffleId $partitionId-$oldEpoch->" +
  s"${newMasterLocation.map(_.getEpoch.toString).getOrElse("unknown")} partition success.")
}
@ -1249,7 +1247,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
requestUnregisterShuffle(
rssHARetryClient,
ControlMessages.pbUnregisterShuffle(appId, shuffleId, RssHARetryClient.genRequestId()))
UnregisterShuffle(appId, shuffleId, RssHARetryClient.genRequestId()))
}
}
}
@ -1359,7 +1357,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit
} catch {
case e: Exception =>
logError(s"AskSync UnregisterShuffle for ${message.getShuffleId} failed.", e)
pbUnregisterShuffleResponse(StatusCode.FAILED)
UnregisterShuffleResponse(StatusCode.FAILED)
}
}

View File

@ -45,7 +45,7 @@ import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbRegisterShuffleResponse;
import org.apache.celeborn.common.protocol.message.ControlMessages;
import org.apache.celeborn.common.protocol.message.ControlMessages.*;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
@ -194,15 +194,15 @@ public class ShuffleClientSuiteJ {
conf.set("rss.client.compression.codec", codec.name());
conf.set("rss.pushdata.retry.thread.num", "1");
conf.set("rss.push.data.buffer.size", "1K");
shuffleClient = new ShuffleClientImpl(conf, new ControlMessages.UserIdentifier("mock", "mock"));
shuffleClient = new ShuffleClientImpl(conf, new UserIdentifier("mock", "mock"));
masterLocation.setPeer(slaveLocation);
when(endpointRef.askSync(
ControlMessages.pbRegisterShuffle(TEST_APPLICATION_ID, TEST_SHUFFLE_ID, 1, 1),
RegisterShuffle$.MODULE$.apply(TEST_APPLICATION_ID, TEST_SHUFFLE_ID, 1, 1),
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)))
.thenAnswer(
t ->
ControlMessages.pbRegisterShuffleResponse(
RegisterShuffleResponse$.MODULE$.apply(
StatusCode.SUCCESS, new PartitionLocation[] {masterLocation}));
shuffleClient.setupMetaServiceRef(endpointRef);

View File

@ -57,8 +57,8 @@ object ControlMessages extends Logging {
/**
* ==========================================
* handled by master
* ==========================================
* handled by master
* ==========================================
*/
val pbCheckForWorkerTimeout: PbCheckForWorkerTimeout =
PbCheckForWorkerTimeout.newBuilder().build()
@ -74,32 +74,34 @@ object ControlMessages extends Logging {
*/
case object OneWayMessageResponse extends Message
def pbRegisterWorker(
host: String,
rpcPort: Int,
pushPort: Int,
fetchPort: Int,
replicatePort: Int,
disks: Map[String, DiskInfo],
userResourceConsumption: Map[UserIdentifier, ResourceConsumption],
requestId: String): PbRegisterWorker = {
val pbDisks = disks.values.map(PbSerDeUtils.toPbDiskInfo).asJava
val pbUserResourceConsumption = userResourceConsumption
.map { case (userIdentifier, resourceConsumption) =>
(userIdentifier.toString, resourceConsumption)
}
.mapValues(PbSerDeUtils.toPbResourceConsumption)
.asJava
PbRegisterWorker.newBuilder()
.setHost(host)
.setRpcPort(rpcPort)
.setPushPort(pushPort)
.setFetchPort(fetchPort)
.setReplicatePort(replicatePort)
.addAllDisks(pbDisks)
.putAllUserResourceConsumption(pbUserResourceConsumption)
.setRequestId(requestId)
.build()
object RegisterWorker {

  /**
   * Builds a `PbRegisterWorker` message from a worker's address, ports, disks and
   * per-user resource consumption.
   *
   * @param host                    value stored as the message's host
   * @param rpcPort                 value stored as the message's rpc port
   * @param pushPort                value stored as the message's push port
   * @param fetchPort               value stored as the message's fetch port
   * @param replicatePort           value stored as the message's replicate port
   * @param disks                   disk infos; only the map's values are sent, each converted
   *                                with `PbSerDeUtils.toPbDiskInfo`
   * @param userResourceConsumption consumption per user; keys are sent as
   *                                `userIdentifier.toString`, values are converted with
   *                                `PbSerDeUtils.toPbResourceConsumption`
   * @param requestId               value stored as the message's request id
   * @return the built `PbRegisterWorker`
   */
  def apply(
      host: String,
      rpcPort: Int,
      pushPort: Int,
      fetchPort: Int,
      replicatePort: Int,
      disks: Map[String, DiskInfo],
      userResourceConsumption: Map[UserIdentifier, ResourceConsumption],
      requestId: String): PbRegisterWorker = {
    val pbDisks = disks.values.map(PbSerDeUtils.toPbDiskInfo).asJava
    // Keys and values are converted in a single strict pass. `mapValues` is avoided
    // on purpose: it yields a lazy view and is deprecated from Scala 2.13 onwards.
    val pbUserResourceConsumption = userResourceConsumption
      .map { case (userIdentifier, resourceConsumption) =>
        userIdentifier.toString -> PbSerDeUtils.toPbResourceConsumption(resourceConsumption)
      }
      .asJava
    PbRegisterWorker.newBuilder()
      .setHost(host)
      .setRpcPort(rpcPort)
      .setPushPort(pushPort)
      .setFetchPort(fetchPort)
      .setReplicatePort(replicatePort)
      .addAllDisks(pbDisks)
      .putAllUserResourceConsumption(pbUserResourceConsumption)
      .setRequestId(requestId)
      .build()
  }
}
case class HeartbeatFromWorker(
@ -117,26 +119,30 @@ object ControlMessages extends Logging {
expiredShuffleKeys: util.HashSet[String],
registered: Boolean) extends MasterMessage
def pbRegisterShuffle(
appId: String,
shuffleId: Int,
numMappers: Int,
numPartitions: Int): PbRegisterShuffle =
PbRegisterShuffle.newBuilder()
.setApplicationId(appId)
.setShuffleId(shuffleId)
.setNumMapppers(numMappers)
.setNumPartitions(numPartitions)
.build()
object RegisterShuffle {

  /**
   * Assembles a `PbRegisterShuffle` message.
   *
   * @param appId         value stored as the message's application id
   * @param shuffleId     value stored as the message's shuffle id
   * @param numMappers    value stored through the builder's `setNumMapppers` setter
   * @param numPartitions value stored as the message's partition count
   * @return the built `PbRegisterShuffle`
   */
  def apply(
      appId: String,
      shuffleId: Int,
      numMappers: Int,
      numPartitions: Int): PbRegisterShuffle = {
    val builder = PbRegisterShuffle.newBuilder()
    builder.setApplicationId(appId)
    builder.setShuffleId(shuffleId)
    // The setter is spelled with three p's on the Pb builder; the call must match it.
    builder.setNumMapppers(numMappers)
    builder.setNumPartitions(numPartitions)
    builder.build()
  }
}
def pbRegisterShuffleResponse(
status: StatusCode,
partitionLocations: Array[PartitionLocation]): PbRegisterShuffleResponse =
PbRegisterShuffleResponse.newBuilder()
.setStatus(status.getValue)
.addAllPartitionLocations(
partitionLocations.map(PartitionLocation.toPbPartitionLocation).toSeq.asJava)
.build()
object RegisterShuffleResponse {

  /**
   * Assembles a `PbRegisterShuffleResponse` carrying a status and partition locations.
   *
   * @param status             status whose `getValue` is written into the message
   * @param partitionLocations locations to attach; each one is converted with
   *                           `PartitionLocation.toPbPartitionLocation`
   * @return the built `PbRegisterShuffleResponse`
   */
  def apply(
      status: StatusCode,
      partitionLocations: Array[PartitionLocation]): PbRegisterShuffleResponse = {
    val statusValue = status.getValue
    val pbLocations =
      partitionLocations.map(PartitionLocation.toPbPartitionLocation).toSeq.asJava
    val builder = PbRegisterShuffleResponse.newBuilder()
    builder.setStatus(statusValue)
    builder.addAllPartitionLocations(pbLocations)
    builder.build()
  }
}
case class RequestSlots(
applicationId: String,
@ -164,49 +170,55 @@ object ControlMessages extends Logging {
workerResource: WorkerResource)
extends MasterMessage
def pbRevive(
appId: String,
shuffleId: Int,
mapId: Int,
attemptId: Int,
partitionId: Int,
epoch: Int,
oldPartition: PartitionLocation,
cause: StatusCode): PbRevive =
PbRevive.newBuilder()
.setApplicationId(appId)
.setShuffleId(shuffleId)
.setMapId(mapId)
.setAttemptId(attemptId)
.setPartitionId(partitionId)
.setEpoch(epoch)
.setOldPartition(PartitionLocation.toPbPartitionLocation(oldPartition))
.setStatus(cause.getValue)
.build()
object Revive {

  /**
   * Assembles a `PbRevive` message.
   *
   * @param appId        value stored as the message's application id
   * @param shuffleId    value stored as the message's shuffle id
   * @param mapId        value stored as the message's map id
   * @param attemptId    value stored as the message's attempt id
   * @param partitionId  value stored as the message's partition id
   * @param epoch        value stored as the message's epoch
   * @param oldPartition location converted with `PartitionLocation.toPbPartitionLocation`
   *                     and stored as the message's old partition
   * @param cause        status whose `getValue` is written into the message's status field
   * @return the built `PbRevive`
   */
  def apply(
      appId: String,
      shuffleId: Int,
      mapId: Int,
      attemptId: Int,
      partitionId: Int,
      epoch: Int,
      oldPartition: PartitionLocation,
      cause: StatusCode): PbRevive = {
    val builder = PbRevive.newBuilder()
    builder.setApplicationId(appId)
    builder.setShuffleId(shuffleId)
    builder.setMapId(mapId)
    builder.setAttemptId(attemptId)
    builder.setPartitionId(partitionId)
    builder.setEpoch(epoch)
    builder.setOldPartition(PartitionLocation.toPbPartitionLocation(oldPartition))
    builder.setStatus(cause.getValue)
    builder.build()
  }
}
def pbPartitionSplit(
appId: String,
shuffleId: Int,
partitionId: Int,
epoch: Int,
oldPartition: PartitionLocation): PbPartitionSplit =
PbPartitionSplit.newBuilder()
.setApplicationId(appId)
.setShuffleId(shuffleId)
.setPartitionId(partitionId)
.setEpoch(epoch)
.setOldPartition(PartitionLocation.toPbPartitionLocation(oldPartition))
.build()
object PartitionSplit {

  /**
   * Assembles a `PbPartitionSplit` message.
   *
   * @param appId        value stored as the message's application id
   * @param shuffleId    value stored as the message's shuffle id
   * @param partitionId  value stored as the message's partition id
   * @param epoch        value stored as the message's epoch
   * @param oldPartition location converted with `PartitionLocation.toPbPartitionLocation`
   *                     and stored as the message's old partition
   * @return the built `PbPartitionSplit`
   */
  def apply(
      appId: String,
      shuffleId: Int,
      partitionId: Int,
      epoch: Int,
      oldPartition: PartitionLocation): PbPartitionSplit = {
    val builder = PbPartitionSplit.newBuilder()
    builder.setApplicationId(appId)
    builder.setShuffleId(shuffleId)
    builder.setPartitionId(partitionId)
    builder.setEpoch(epoch)
    builder.setOldPartition(PartitionLocation.toPbPartitionLocation(oldPartition))
    builder.build()
  }
}
def pbChangeLocationResponse(
status: StatusCode,
partitionLocationOpt: Option[PartitionLocation]): PbChangeLocationResponse = {
val builder = PbChangeLocationResponse.newBuilder()
builder.setStatus(status.getValue)
partitionLocationOpt.foreach { partitionLocation =>
builder.setLocation(PartitionLocation.toPbPartitionLocation(partitionLocation))
object ChangeLocationResponse {
def apply(
status: StatusCode,
partitionLocationOpt: Option[PartitionLocation]): PbChangeLocationResponse = {
val builder = PbChangeLocationResponse.newBuilder()
builder.setStatus(status.getValue)
partitionLocationOpt.foreach { partitionLocation =>
builder.setLocation(PartitionLocation.toPbPartitionLocation(partitionLocation))
}
builder.build()
}
builder.build()
}
case class MapperEnd(
@ -229,45 +241,53 @@ object ControlMessages extends Logging {
attempts: Array[Int])
extends MasterMessage
def pbWorkerLost(
host: String,
rpcPort: Int,
pushPort: Int,
fetchPort: Int,
replicatePort: Int,
requestId: String): PbWorkerLost = PbWorkerLost.newBuilder()
.setHost(host)
.setRpcPort(rpcPort)
.setPushPort(pushPort)
.setFetchPort(fetchPort)
.setReplicatePort(replicatePort)
.setRequestId(requestId)
.build()
def pbWorkerLostResponse(success: Boolean): PbWorkerLostResponse =
PbWorkerLostResponse.newBuilder()
.setSuccess(success)
object WorkerLost {

  /**
   * Assembles a `PbWorkerLost` message from a worker's host, ports and a request id.
   *
   * @param host          value stored as the message's host
   * @param rpcPort       value stored as the message's rpc port
   * @param pushPort      value stored as the message's push port
   * @param fetchPort     value stored as the message's fetch port
   * @param replicatePort value stored as the message's replicate port
   * @param requestId     value stored as the message's request id
   * @return the built `PbWorkerLost`
   */
  def apply(
      host: String,
      rpcPort: Int,
      pushPort: Int,
      fetchPort: Int,
      replicatePort: Int,
      requestId: String): PbWorkerLost = {
    val builder = PbWorkerLost.newBuilder()
    builder.setHost(host)
    builder.setRpcPort(rpcPort)
    builder.setPushPort(pushPort)
    builder.setFetchPort(fetchPort)
    builder.setReplicatePort(replicatePort)
    builder.setRequestId(requestId)
    builder.build()
  }
}
object WorkerLostResponse {

  /**
   * Assembles a `PbWorkerLostResponse` message.
   *
   * @param success flag written into the message's success field
   * @return the built `PbWorkerLostResponse`
   */
  def apply(success: Boolean): PbWorkerLostResponse = {
    val builder = PbWorkerLostResponse.newBuilder()
    builder.setSuccess(success)
    builder.build()
  }
}
case class StageEnd(applicationId: String, shuffleId: Int) extends MasterMessage
case class StageEndResponse(status: StatusCode)
extends MasterMessage
def pbUnregisterShuffle(
appId: String,
shuffleId: Int,
requestId: String): PbUnregisterShuffle =
PbUnregisterShuffle.newBuilder()
.setAppId(appId)
.setShuffleId(shuffleId)
.setRequestId(requestId)
.build()
object UnregisterShuffle {

  /**
   * Assembles a `PbUnregisterShuffle` message.
   *
   * @param appId     value stored as the message's app id
   * @param shuffleId value stored as the message's shuffle id
   * @param requestId value stored as the message's request id
   * @return the built `PbUnregisterShuffle`
   */
  def apply(
      appId: String,
      shuffleId: Int,
      requestId: String): PbUnregisterShuffle = {
    val builder = PbUnregisterShuffle.newBuilder()
    builder.setAppId(appId)
    builder.setShuffleId(shuffleId)
    builder.setRequestId(requestId)
    builder.build()
  }
}
def pbUnregisterShuffleResponse(status: StatusCode): PbUnregisterShuffleResponse =
PbUnregisterShuffleResponse.newBuilder()
.setStatus(status.getValue)
.build()
object UnregisterShuffleResponse {

  /**
   * Assembles a `PbUnregisterShuffleResponse` message.
   *
   * @param status status whose `getValue` is written into the message's status field
   * @return the built `PbUnregisterShuffleResponse`
   */
  def apply(status: StatusCode): PbUnregisterShuffleResponse = {
    val builder = PbUnregisterShuffleResponse.newBuilder()
    builder.setStatus(status.getValue)
    builder.build()
  }
}
case class ApplicationLost(
appId: String,
@ -301,11 +321,13 @@ object ControlMessages extends Logging {
* handled by worker
* ==========================================
*/
def pbRegisterWorkerResponse(success: Boolean, message: String): PbRegisterWorkerResponse =
PbRegisterWorkerResponse.newBuilder()
.setSuccess(success)
.setMessage(message)
.build()
object RegisterWorkerResponse {

  /**
   * Assembles a `PbRegisterWorkerResponse` message.
   *
   * @param success flag written into the message's success field
   * @param message text written into the message's message field
   * @return the built `PbRegisterWorkerResponse`
   */
  def apply(success: Boolean, message: String): PbRegisterWorkerResponse = {
    val builder = PbRegisterWorkerResponse.newBuilder()
    builder.setSuccess(success)
    builder.setMessage(message)
    builder.build()
  }
}
case class ReregisterWorkerResponse(success: Boolean) extends WorkerMessage

View File

@ -316,7 +316,7 @@ private[celeborn] class Master(
&& !statusSystem.workerLostEvents.contains(worker)) {
logWarning(s"Worker ${worker.readableAddress()} timeout! Trigger WorkerLost event.")
// trigger WorkerLost event
self.send(ControlMessages.pbWorkerLost(
self.send(WorkerLost(
worker.host,
worker.rpcPort,
worker.pushPort,
@ -415,7 +415,7 @@ private[celeborn] class Master(
statusSystem.handleWorkerLost(host, rpcPort, pushPort, fetchPort, replicatePort, requestId)
if (context != null) {
context.reply(ControlMessages.pbWorkerLostResponse(true))
context.reply(WorkerLostResponse(true))
}
}
@ -444,9 +444,7 @@ private[celeborn] class Master(
disks,
userResourceConsumption,
requestId)
context.reply(ControlMessages.pbRegisterWorkerResponse(
true,
"Worker in snapshot, re-register."))
context.reply(RegisterWorkerResponse(true, "Worker in snapshot, re-register."))
} else if (statusSystem.workerLostEvents.contains(workerToRegister)) {
logWarning(s"Receive RegisterWorker while worker $workerToRegister " +
s"in workerLostEvents.")
@ -460,9 +458,7 @@ private[celeborn] class Master(
disks,
userResourceConsumption,
requestId)
context.reply(ControlMessages.pbRegisterWorkerResponse(
true,
"Worker in workerLostEvents, re-register."))
context.reply(RegisterWorkerResponse(true, "Worker in workerLostEvents, re-register."))
} else {
statusSystem.handleRegisterWorker(
host,
@ -474,7 +470,7 @@ private[celeborn] class Master(
userResourceConsumption,
requestId)
logInfo(s"Registered worker $workerToRegister.")
context.reply(ControlMessages.pbRegisterWorkerResponse(true, ""))
context.reply(RegisterWorkerResponse(true, ""))
}
}
@ -564,7 +560,7 @@ private[celeborn] class Master(
val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
statusSystem.handleUnRegisterShuffle(shuffleKey, requestId)
logInfo(s"Unregister shuffle $shuffleKey")
context.reply(ControlMessages.pbUnregisterShuffleResponse(StatusCode.SUCCESS))
context.reply(UnregisterShuffleResponse(StatusCode.SUCCESS))
}
def handleGetBlacklist(context: RpcCallContext, msg: GetBlacklist): Unit = {

View File

@ -334,7 +334,7 @@ private[celeborn] class Worker(
val resp =
try {
rssHARetryClient.askSync[PbRegisterWorkerResponse](
ControlMessages.pbRegisterWorker(
RegisterWorker(
host,
rpcPort,
pushPort,