[CELEBORN-8] [ISSUE-952][FEATURE] support register shuffle task in map partition mode (#973)

This commit is contained in:
Shuang 2022-11-16 21:46:19 +08:00 committed by GitHub
parent a6e89f3b63
commit fb6d1de108
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 359 additions and 14 deletions

View File

@ -20,6 +20,7 @@ package org.apache.celeborn.client;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
@ -27,6 +28,7 @@ import java.util.concurrent.TimeUnit;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
@ -56,6 +58,7 @@ import org.apache.celeborn.common.rpc.RpcAddress;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.rpc.RpcEnv;
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.PackedPartitionId;
import org.apache.celeborn.common.util.PbSerDeUtils;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.common.util.Utils;
@ -257,13 +260,58 @@ public class ShuffleClientImpl extends ShuffleClient {
private ConcurrentHashMap<Integer, PartitionLocation> registerShuffle(
String appId, int shuffleId, int numMappers, int numPartitions) {
return registerShuffleInternal(
shuffleId,
numMappers,
numMappers,
() ->
driverRssMetaService.askSync(
RegisterShuffle$.MODULE$.apply(appId, shuffleId, numMappers, numPartitions),
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
}
@VisibleForTesting
public PartitionLocation registerMapPartitionTask(
String appId, int shuffleId, int numMappers, int mapId, int attemptId) {
int partitionId = PackedPartitionId.packedPartitionId(mapId, attemptId);
logger.info(
"register mapPartitionTask, mapId: {}, attemptId: {}, partitionId: {}",
mapId,
attemptId,
partitionId);
if (attemptId == 0) {
return registerMapPartitionTaskWithFirstAttempt(
appId, shuffleId, numMappers, mapId, attemptId, partitionId);
}
// TODO
throw new UnsupportedOperationException("can not register shuffle task with attempt beyond 0");
}
private PartitionLocation registerMapPartitionTaskWithFirstAttempt(
String appId, int shuffleId, int numMappers, int mapId, int attemptId, int partitionId) {
ConcurrentHashMap<Integer, PartitionLocation> partitionLocationMap =
registerShuffleInternal(
shuffleId,
numMappers,
numMappers,
() ->
driverRssMetaService.askSync(
RegisterMapPartitionTask$.MODULE$.apply(
appId, shuffleId, numMappers, mapId, attemptId, partitionId),
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
return partitionLocationMap.get(partitionId);
}
private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
int shuffleId,
int numMappers,
int numPartitions,
Callable<PbRegisterShuffleResponse> callable) {
int numRetries = registerShuffleMaxRetries;
while (numRetries > 0) {
try {
PbRegisterShuffleResponse response =
driverRssMetaService.askSync(
RegisterShuffle$.MODULE$.apply(appId, shuffleId, numMappers, numPartitions),
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class));
PbRegisterShuffleResponse response = callable.call();
StatusCode respStatus = Utils.toStatusCode(response.getStatus());
if (StatusCode.SUCCESS.equals(respStatus)) {
ConcurrentHashMap<Integer, PartitionLocation> result = new ConcurrentHashMap<>();

View File

@ -27,6 +27,7 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random
import com.google.common.annotations.VisibleForTesting
import com.google.common.cache.{Cache, CacheBuilder}
import org.roaringbitmap.RoaringBitmap
@ -53,7 +54,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
private val pushReplicateEnabled = conf.pushReplicateEnabled
private val partitionSplitThreshold = conf.partitionSplitThreshold
private val partitionSplitMode = conf.partitionSplitMode
private val partitionType = conf.shufflePartitionType
// shuffle id -> partition type
private val shufflePartitionType = new ConcurrentHashMap[Int, PartitionType]()
private val rangeReadFilter = conf.shuffleRangeReadFilterEnabled
private val unregisterShuffleTime = new ConcurrentHashMap[Int, Long]()
private val stageEndTimeout = conf.pushStageEndTimeout
@ -83,7 +85,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
.maximumSize(rpcCacheSize)
.build().asInstanceOf[Cache[Int, ByteBuffer]]
private def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, PartitionLocationInfo] =
@VisibleForTesting
def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, PartitionLocationInfo] =
shuffleAllocatedWorkers.get(shuffleId)
val newMapFunc: function.Function[Int, ConcurrentHashMap[Int, PartitionLocation]] =
@ -293,6 +296,10 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
rpcEnv.address.port
}
def getPartitionType(shuffleId: Int): PartitionType = {
shufflePartitionType.getOrDefault(shuffleId, conf.shufflePartitionType)
}
override def receive: PartialFunction[Any, Unit] = {
case RemoveExpiredShuffle =>
removeExpiredShuffle()
@ -319,6 +326,22 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
s"$applicationId, $shuffleId, $numMappers, $numPartitions.")
handleRegisterShuffle(context, applicationId, shuffleId, numMappers, numPartitions)
case pb: PbRegisterMapPartitionTask =>
val applicationId = pb.getApplicationId
val shuffleId = pb.getShuffleId
val numMappers = pb.getNumMappers
val mapId = pb.getMapId
val attemptId = pb.getAttemptId
val partitionId = pb.getPartitionId
logDebug(s"Received Register map partition task request, " +
s"$applicationId, $shuffleId, $numMappers, $mapId, $attemptId, $partitionId.")
handleRegisterMapPartitionTask(
context,
applicationId,
shuffleId,
numMappers,
partitionId)
case pb: PbRevive =>
val applicationId = pb.getApplicationId
val shuffleId = pb.getShuffleId
@ -379,6 +402,33 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
shuffleId: Int,
numMappers: Int,
numReducers: Int): Unit = {
handleOfferAndReserveSlots(context, applicationId, shuffleId, numMappers, numReducers)
}
private def handleRegisterMapPartitionTask(
context: RpcCallContext,
applicationId: String,
shuffleId: Int,
numMappers: Int,
partitionId: Int): Unit = {
shufflePartitionType.putIfAbsent(shuffleId, PartitionType.MAP)
handleOfferAndReserveSlots(
context,
applicationId,
shuffleId,
numMappers,
numMappers,
partitionId)
}
private def handleOfferAndReserveSlots(
context: RpcCallContext,
applicationId: String,
shuffleId: Int,
numMappers: Int,
numReducers: Int,
partitionId: Int = -1): Unit = {
val partitionType = getPartitionType(shuffleId)
registeringShuffleRequest.synchronized {
if (registeringShuffleRequest.containsKey(shuffleId)) {
// If same request already exists in the registering request list for the same shuffle,
@ -394,7 +444,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
.values()
.asScala
.flatMap(_.getAllMasterLocationsWithMinEpoch(shuffleId.toString).asScala)
.filter(_.getEpoch == 0)
.filter(p =>
(partitionType == PartitionType.REDUCE && p.getEpoch == 0) || (partitionType == PartitionType.MAP && p.getId == partitionId))
.toArray
context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, initialLocs))
return
@ -513,7 +564,9 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
// Fifth, reply the allocated partition location to ShuffleClient.
logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
val allMasterPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
val allMasterPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).filter(p =>
partitionType ==
PartitionType.REDUCE || (partitionType == PartitionType.MAP && p.getId == partitionId)).toArray
reply(RegisterShuffleResponse(StatusCode.SUCCESS, allMasterPartitionLocations))
}
}
@ -1123,7 +1176,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
slaveLocations,
partitionSplitThreshold,
partitionSplitMode,
partitionType,
getPartitionType(shuffleId),
rangeReadFilter,
userIdentifier))
if (res.status.equals(StatusCode.SUCCESS)) {

View File

@ -22,6 +22,7 @@ import java.io.Serializable;
import org.roaringbitmap.RoaringBitmap;
import org.apache.celeborn.common.meta.WorkerInfo;
import org.apache.celeborn.common.util.PackedPartitionId;
public class PartitionLocation implements Serializable {
public enum Mode {
@ -277,9 +278,13 @@ public class PartitionLocation implements Serializable {
peerAddr = peer.hostAndPorts();
}
return "PartitionLocation["
+ "\n id-epoch:"
+ "\n id(rawId-attemptId)-epoch:"
+ id
+ "("
+ getRawId()
+ "-"
+ getAttemptId()
+ ")-"
+ epoch
+ "\n host-rpcPort-pushPort-fetchPort-replicatePort:"
+ host
@ -313,4 +318,12 @@ public class PartitionLocation implements Serializable {
public void setMapIdBitMap(RoaringBitmap mapIdBitMap) {
this.mapIdBitMap = mapIdBitMap;
}
public int getRawId() {
return PackedPartitionId.getRawPartitionId(id);
}
public int getAttemptId() {
return PackedPartitionId.getAttemptId(id);
}
}

View File

@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common.util;
import com.google.common.base.Preconditions;
/**
* Pack for encode/decode id of partition Location for id of partitionLocation attemptId
* raw_partitionId <br>
* (upper 8 bits = attemptId) (lower 24 bits = raw id) <br>
* (0000 0000) (0000 0000 0000 0000 0000 0000)<br>
*
* @see org.apache.celeborn.common.protocol.PartitionLocation#id
*/
public class PackedPartitionId {
/** The maximum partition identifier that can be encoded. Note that partition ids start from 0. */
static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
/** The maximum partition attempt id that can be encoded. Note that attempt ids start from 0. */
static final int MAXIMUM_ATTEMPT_ID = (1 << 8) - 1; // 255
static final int MASK_INT_LOWER_24_BITS = (int) (1L << 24) - 1;
public static int packedPartitionId(int partitionRawId, int attemptId) {
Preconditions.checkArgument(
partitionRawId <= MAXIMUM_PARTITION_ID,
"packedPartitionId called with invalid partitionRawId: " + partitionRawId);
Preconditions.checkArgument(
attemptId <= MAXIMUM_ATTEMPT_ID,
"packedPartitionId called with invalid attemptId: " + attemptId);
return (attemptId << 24) | partitionRawId;
}
public static int getRawPartitionId(int packedPartitionId) {
return packedPartitionId & MASK_INT_LOWER_24_BITS;
}
public static int getAttemptId(int packedPartitionId) {
return packedPartitionId >>> 24;
}
}

View File

@ -155,6 +155,15 @@ message PbRegisterShuffle {
int32 numPartitions = 4;
}
message PbRegisterMapPartitionTask {
string applicationId = 1;
int32 shuffleId = 2;
int32 numMappers = 3;
int32 mapId = 4;
int32 attemptId = 5;
int32 partitionId = 6;
}
message PbRegisterShuffleResponse {
int32 status = 1;
repeated PbPartitionLocation partitionLocations = 2;

View File

@ -130,6 +130,24 @@ object ControlMessages extends Logging {
.build()
}
object RegisterMapPartitionTask {
def apply(
appId: String,
shuffleId: Int,
numMappers: Int,
mapId: Int,
attemptId: Int,
partitionId: Int): PbRegisterMapPartitionTask =
PbRegisterMapPartitionTask.newBuilder()
.setApplicationId(appId)
.setShuffleId(shuffleId)
.setNumMappers(numMappers)
.setMapId(mapId)
.setAttemptId(attemptId)
.setPartitionId(partitionId)
.build()
}
object RegisterShuffleResponse {
def apply(
status: StatusCode,

View File

@ -20,6 +20,8 @@ package org.apache.celeborn.common.protocol;
import org.junit.Test;
import org.roaringbitmap.RoaringBitmap;
import org.apache.celeborn.common.util.PackedPartitionId;
public class PartitionLocationSuiteJ {
private final int partitionId = 0;
@ -183,9 +185,13 @@ public class PartitionLocationSuiteJ {
bitmap.add(1);
bitmap.add(2);
bitmap.add(3);
int attemptId = 10;
int rawPartitionId = 1000;
int newPartitionId = PackedPartitionId.packedPartitionId(rawPartitionId, attemptId);
PartitionLocation location3 =
new PartitionLocation(
partitionId,
newPartitionId,
epoch,
host,
rpcPort,
@ -199,7 +205,7 @@ public class PartitionLocationSuiteJ {
String exp1 =
"PartitionLocation[\n"
+ " id-epoch:0-0\n"
+ " id(rawId-attemptId)-epoch:0(0-0)-0\n"
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
+ " mode:MASTER\n"
+ " peer:(empty)\n"
@ -207,7 +213,7 @@ public class PartitionLocationSuiteJ {
+ " mapIdBitMap:{}]";
String exp2 =
"PartitionLocation[\n"
+ " id-epoch:0-0\n"
+ " id(rawId-attemptId)-epoch:0(0-0)-0\n"
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
+ " mode:MASTER\n"
+ " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
@ -215,7 +221,7 @@ public class PartitionLocationSuiteJ {
+ " mapIdBitMap:{}]";
String exp3 =
"PartitionLocation[\n"
+ " id-epoch:0-0\n"
+ " id(rawId-attemptId)-epoch:167773160(1000-10)-0\n"
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
+ " mode:MASTER\n"
+ " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"

View File

@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.common.util;
import org.junit.Assert;
import org.junit.Test;
public class PackedPartitionIdSuiteJ {
@Test
public void testNormalPackedPartitionId() {
assertTest(0, 0);
assertTest(555, 1);
assertTest(888, 1);
assertTest(10001, 100);
// testUseMaxPartitionId or MaxAttemptId
assertTest(PackedPartitionId.MAXIMUM_PARTITION_ID, 11);
assertTest(100, PackedPartitionId.MAXIMUM_ATTEMPT_ID);
assertTest(PackedPartitionId.MAXIMUM_PARTITION_ID, PackedPartitionId.MAXIMUM_ATTEMPT_ID);
}
@Test(expected = IllegalArgumentException.class)
public void testAttemptIdGreaterThanMaximumAttemptId() {
PackedPartitionId.packedPartitionId(0, PackedPartitionId.MAXIMUM_ATTEMPT_ID + 1);
}
@Test(expected = IllegalArgumentException.class)
public void testPartitionIdGreaterThanMaximumPartitionId() {
PackedPartitionId.packedPartitionId(PackedPartitionId.MAXIMUM_PARTITION_ID + 1, 1);
}
private void assertTest(int partitionRawId, int attemptId) {
int packedPartitionId = PackedPartitionId.packedPartitionId(partitionRawId, attemptId);
Assert.assertTrue(partitionRawId == PackedPartitionId.getRawPartitionId(packedPartitionId));
Assert.assertTrue(attemptId == PackedPartitionId.getAttemptId(packedPartitionId));
}
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.tests.client
import scala.collection.JavaConverters._
import scala.language.implicitConversions
import org.junit.Assert
import org.scalatest.BeforeAndAfterAll
import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.client.{LifecycleManager, ShuffleClient, ShuffleClientImpl}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.identity.UserIdentifier
import org.apache.celeborn.common.util.PackedPartitionId
import org.apache.celeborn.service.deploy.MiniClusterFeature
class ShuffleClientSuite extends AnyFunSuite with MiniClusterFeature
with BeforeAndAfterAll {
val masterPort = 19097
val APP = "app-1"
var shuffleClient: ShuffleClientImpl = _
var lifecycleManager: LifecycleManager = _
override def beforeAll(): Unit = {
val masterConf = Map(
"celeborn.master.host" -> "localhost",
"celeborn.master.port" -> masterPort.toString)
val workerConf = Map(
"celeborn.master.endpoints" -> s"localhost:$masterPort")
setUpMiniCluster(masterConf, workerConf)
val clientConf = new CelebornConf()
.set("celeborn.master.endpoints", s"localhost:$masterPort")
.set("celeborn.push.replicate.enabled", "true")
.set("celeborn.push.buffer.size", "256K")
lifecycleManager = new LifecycleManager(APP, clientConf)
shuffleClient = new ShuffleClientImpl(clientConf, UserIdentifier("mock", "mock"))
shuffleClient.setupMetaServiceRef(lifecycleManager.self)
}
test(s"test register map partition task with first attemptId") {
val shuffleId = 1
val numMappers = 8
val mapId = 1
val attemptId = 0
var location =
shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId)
Assert.assertEquals(location.getId, PackedPartitionId.packedPartitionId(mapId, attemptId))
// retry register
location = shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId)
Assert.assertEquals(location.getId, PackedPartitionId.packedPartitionId(mapId, attemptId))
// another mapId
location =
shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 1, attemptId)
Assert.assertEquals(location.getId, PackedPartitionId.packedPartitionId(mapId + 1, attemptId))
// offer and reserve all slots
val partitionLocationInfos = lifecycleManager.workerSnapshots(shuffleId).values().asScala
val count =
partitionLocationInfos.map(r => r.getAllMasterLocations(shuffleId.toString).size()).sum
Assert.assertEquals(count, numMappers)
}
override def afterAll(): Unit = {
// TODO refactor MiniCluster later
println("test done")
sys.exit(0)
}
}