[CELEBORN-8] [ISSUE-952][FEATURE] support register shuffle task in map partition mode (#973)
This commit is contained in:
parent
a6e89f3b63
commit
fb6d1de108
@ -20,6 +20,7 @@ package org.apache.celeborn.client;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
@ -27,6 +28,7 @@ import java.util.concurrent.TimeUnit;
|
||||
import scala.reflect.ClassTag;
|
||||
import scala.reflect.ClassTag$;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import io.netty.buffer.CompositeByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
@ -56,6 +58,7 @@ import org.apache.celeborn.common.rpc.RpcAddress;
|
||||
import org.apache.celeborn.common.rpc.RpcEndpointRef;
|
||||
import org.apache.celeborn.common.rpc.RpcEnv;
|
||||
import org.apache.celeborn.common.unsafe.Platform;
|
||||
import org.apache.celeborn.common.util.PackedPartitionId;
|
||||
import org.apache.celeborn.common.util.PbSerDeUtils;
|
||||
import org.apache.celeborn.common.util.ThreadUtils;
|
||||
import org.apache.celeborn.common.util.Utils;
|
||||
@ -257,13 +260,58 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
|
||||
private ConcurrentHashMap<Integer, PartitionLocation> registerShuffle(
|
||||
String appId, int shuffleId, int numMappers, int numPartitions) {
|
||||
return registerShuffleInternal(
|
||||
shuffleId,
|
||||
numMappers,
|
||||
numMappers,
|
||||
() ->
|
||||
driverRssMetaService.askSync(
|
||||
RegisterShuffle$.MODULE$.apply(appId, shuffleId, numMappers, numPartitions),
|
||||
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
public PartitionLocation registerMapPartitionTask(
|
||||
String appId, int shuffleId, int numMappers, int mapId, int attemptId) {
|
||||
int partitionId = PackedPartitionId.packedPartitionId(mapId, attemptId);
|
||||
logger.info(
|
||||
"register mapPartitionTask, mapId: {}, attemptId: {}, partitionId: {}",
|
||||
mapId,
|
||||
attemptId,
|
||||
partitionId);
|
||||
if (attemptId == 0) {
|
||||
return registerMapPartitionTaskWithFirstAttempt(
|
||||
appId, shuffleId, numMappers, mapId, attemptId, partitionId);
|
||||
}
|
||||
|
||||
// TODO
|
||||
throw new UnsupportedOperationException("can not register shuffle task with attempt beyond 0");
|
||||
}
|
||||
|
||||
private PartitionLocation registerMapPartitionTaskWithFirstAttempt(
|
||||
String appId, int shuffleId, int numMappers, int mapId, int attemptId, int partitionId) {
|
||||
ConcurrentHashMap<Integer, PartitionLocation> partitionLocationMap =
|
||||
registerShuffleInternal(
|
||||
shuffleId,
|
||||
numMappers,
|
||||
numMappers,
|
||||
() ->
|
||||
driverRssMetaService.askSync(
|
||||
RegisterMapPartitionTask$.MODULE$.apply(
|
||||
appId, shuffleId, numMappers, mapId, attemptId, partitionId),
|
||||
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
|
||||
return partitionLocationMap.get(partitionId);
|
||||
}
|
||||
|
||||
private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(
|
||||
int shuffleId,
|
||||
int numMappers,
|
||||
int numPartitions,
|
||||
Callable<PbRegisterShuffleResponse> callable) {
|
||||
int numRetries = registerShuffleMaxRetries;
|
||||
while (numRetries > 0) {
|
||||
try {
|
||||
PbRegisterShuffleResponse response =
|
||||
driverRssMetaService.askSync(
|
||||
RegisterShuffle$.MODULE$.apply(appId, shuffleId, numMappers, numPartitions),
|
||||
ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class));
|
||||
PbRegisterShuffleResponse response = callable.call();
|
||||
StatusCode respStatus = Utils.toStatusCode(response.getStatus());
|
||||
if (StatusCode.SUCCESS.equals(respStatus)) {
|
||||
ConcurrentHashMap<Integer, PartitionLocation> result = new ConcurrentHashMap<>();
|
||||
|
||||
@ -27,6 +27,7 @@ import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable
|
||||
import scala.util.Random
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting
|
||||
import com.google.common.cache.{Cache, CacheBuilder}
|
||||
import org.roaringbitmap.RoaringBitmap
|
||||
|
||||
@ -53,7 +54,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
|
||||
private val pushReplicateEnabled = conf.pushReplicateEnabled
|
||||
private val partitionSplitThreshold = conf.partitionSplitThreshold
|
||||
private val partitionSplitMode = conf.partitionSplitMode
|
||||
private val partitionType = conf.shufflePartitionType
|
||||
// shuffle id -> partition type
|
||||
private val shufflePartitionType = new ConcurrentHashMap[Int, PartitionType]()
|
||||
private val rangeReadFilter = conf.shuffleRangeReadFilterEnabled
|
||||
private val unregisterShuffleTime = new ConcurrentHashMap[Int, Long]()
|
||||
private val stageEndTimeout = conf.pushStageEndTimeout
|
||||
@ -83,7 +85,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
|
||||
.maximumSize(rpcCacheSize)
|
||||
.build().asInstanceOf[Cache[Int, ByteBuffer]]
|
||||
|
||||
private def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, PartitionLocationInfo] =
|
||||
@VisibleForTesting
|
||||
def workerSnapshots(shuffleId: Int): util.Map[WorkerInfo, PartitionLocationInfo] =
|
||||
shuffleAllocatedWorkers.get(shuffleId)
|
||||
|
||||
val newMapFunc: function.Function[Int, ConcurrentHashMap[Int, PartitionLocation]] =
|
||||
@ -293,6 +296,10 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
|
||||
rpcEnv.address.port
|
||||
}
|
||||
|
||||
def getPartitionType(shuffleId: Int): PartitionType = {
|
||||
shufflePartitionType.getOrDefault(shuffleId, conf.shufflePartitionType)
|
||||
}
|
||||
|
||||
override def receive: PartialFunction[Any, Unit] = {
|
||||
case RemoveExpiredShuffle =>
|
||||
removeExpiredShuffle()
|
||||
@ -319,6 +326,22 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
|
||||
s"$applicationId, $shuffleId, $numMappers, $numPartitions.")
|
||||
handleRegisterShuffle(context, applicationId, shuffleId, numMappers, numPartitions)
|
||||
|
||||
case pb: PbRegisterMapPartitionTask =>
|
||||
val applicationId = pb.getApplicationId
|
||||
val shuffleId = pb.getShuffleId
|
||||
val numMappers = pb.getNumMappers
|
||||
val mapId = pb.getMapId
|
||||
val attemptId = pb.getAttemptId
|
||||
val partitionId = pb.getPartitionId
|
||||
logDebug(s"Received Register map partition task request, " +
|
||||
s"$applicationId, $shuffleId, $numMappers, $mapId, $attemptId, $partitionId.")
|
||||
handleRegisterMapPartitionTask(
|
||||
context,
|
||||
applicationId,
|
||||
shuffleId,
|
||||
numMappers,
|
||||
partitionId)
|
||||
|
||||
case pb: PbRevive =>
|
||||
val applicationId = pb.getApplicationId
|
||||
val shuffleId = pb.getShuffleId
|
||||
@ -379,6 +402,33 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
|
||||
shuffleId: Int,
|
||||
numMappers: Int,
|
||||
numReducers: Int): Unit = {
|
||||
handleOfferAndReserveSlots(context, applicationId, shuffleId, numMappers, numReducers)
|
||||
}
|
||||
|
||||
private def handleRegisterMapPartitionTask(
|
||||
context: RpcCallContext,
|
||||
applicationId: String,
|
||||
shuffleId: Int,
|
||||
numMappers: Int,
|
||||
partitionId: Int): Unit = {
|
||||
shufflePartitionType.putIfAbsent(shuffleId, PartitionType.MAP)
|
||||
handleOfferAndReserveSlots(
|
||||
context,
|
||||
applicationId,
|
||||
shuffleId,
|
||||
numMappers,
|
||||
numMappers,
|
||||
partitionId)
|
||||
}
|
||||
|
||||
private def handleOfferAndReserveSlots(
|
||||
context: RpcCallContext,
|
||||
applicationId: String,
|
||||
shuffleId: Int,
|
||||
numMappers: Int,
|
||||
numReducers: Int,
|
||||
partitionId: Int = -1): Unit = {
|
||||
val partitionType = getPartitionType(shuffleId)
|
||||
registeringShuffleRequest.synchronized {
|
||||
if (registeringShuffleRequest.containsKey(shuffleId)) {
|
||||
// If same request already exists in the registering request list for the same shuffle,
|
||||
@ -394,7 +444,8 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
|
||||
.values()
|
||||
.asScala
|
||||
.flatMap(_.getAllMasterLocationsWithMinEpoch(shuffleId.toString).asScala)
|
||||
.filter(_.getEpoch == 0)
|
||||
.filter(p =>
|
||||
(partitionType == PartitionType.REDUCE && p.getEpoch == 0) || (partitionType == PartitionType.MAP && p.getId == partitionId))
|
||||
.toArray
|
||||
context.reply(RegisterShuffleResponse(StatusCode.SUCCESS, initialLocs))
|
||||
return
|
||||
@ -513,7 +564,9 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
|
||||
|
||||
// Fifth, reply the allocated partition location to ShuffleClient.
|
||||
logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
|
||||
val allMasterPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
|
||||
val allMasterPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).filter(p =>
|
||||
partitionType ==
|
||||
PartitionType.REDUCE || (partitionType == PartitionType.MAP && p.getId == partitionId)).toArray
|
||||
reply(RegisterShuffleResponse(StatusCode.SUCCESS, allMasterPartitionLocations))
|
||||
}
|
||||
}
|
||||
@ -1123,7 +1176,7 @@ class LifecycleManager(appId: String, val conf: CelebornConf) extends RpcEndpoin
|
||||
slaveLocations,
|
||||
partitionSplitThreshold,
|
||||
partitionSplitMode,
|
||||
partitionType,
|
||||
getPartitionType(shuffleId),
|
||||
rangeReadFilter,
|
||||
userIdentifier))
|
||||
if (res.status.equals(StatusCode.SUCCESS)) {
|
||||
|
||||
@ -22,6 +22,7 @@ import java.io.Serializable;
|
||||
import org.roaringbitmap.RoaringBitmap;
|
||||
|
||||
import org.apache.celeborn.common.meta.WorkerInfo;
|
||||
import org.apache.celeborn.common.util.PackedPartitionId;
|
||||
|
||||
public class PartitionLocation implements Serializable {
|
||||
public enum Mode {
|
||||
@ -277,9 +278,13 @@ public class PartitionLocation implements Serializable {
|
||||
peerAddr = peer.hostAndPorts();
|
||||
}
|
||||
return "PartitionLocation["
|
||||
+ "\n id-epoch:"
|
||||
+ "\n id(rawId-attemptId)-epoch:"
|
||||
+ id
|
||||
+ "("
|
||||
+ getRawId()
|
||||
+ "-"
|
||||
+ getAttemptId()
|
||||
+ ")-"
|
||||
+ epoch
|
||||
+ "\n host-rpcPort-pushPort-fetchPort-replicatePort:"
|
||||
+ host
|
||||
@ -313,4 +318,12 @@ public class PartitionLocation implements Serializable {
|
||||
public void setMapIdBitMap(RoaringBitmap mapIdBitMap) {
|
||||
this.mapIdBitMap = mapIdBitMap;
|
||||
}
|
||||
|
||||
public int getRawId() {
|
||||
return PackedPartitionId.getRawPartitionId(id);
|
||||
}
|
||||
|
||||
public int getAttemptId() {
|
||||
return PackedPartitionId.getAttemptId(id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,58 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.common.util;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
/**
|
||||
* Pack for encode/decode id of partition Location for id of partitionLocation attemptId
|
||||
* raw_partitionId <br>
|
||||
* (upper 8 bits = attemptId) (lower 24 bits = raw id) <br>
|
||||
* (0000 0000) (0000 0000 0000 0000 0000 0000)<br>
|
||||
*
|
||||
* @see org.apache.celeborn.common.protocol.PartitionLocation#id
|
||||
*/
|
||||
public class PackedPartitionId {
|
||||
|
||||
/** The maximum partition identifier that can be encoded. Note that partition ids start from 0. */
|
||||
static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
|
||||
|
||||
/** The maximum partition attempt id that can be encoded. Note that attempt ids start from 0. */
|
||||
static final int MAXIMUM_ATTEMPT_ID = (1 << 8) - 1; // 255
|
||||
|
||||
static final int MASK_INT_LOWER_24_BITS = (int) (1L << 24) - 1;
|
||||
|
||||
public static int packedPartitionId(int partitionRawId, int attemptId) {
|
||||
Preconditions.checkArgument(
|
||||
partitionRawId <= MAXIMUM_PARTITION_ID,
|
||||
"packedPartitionId called with invalid partitionRawId: " + partitionRawId);
|
||||
Preconditions.checkArgument(
|
||||
attemptId <= MAXIMUM_ATTEMPT_ID,
|
||||
"packedPartitionId called with invalid attemptId: " + attemptId);
|
||||
|
||||
return (attemptId << 24) | partitionRawId;
|
||||
}
|
||||
|
||||
public static int getRawPartitionId(int packedPartitionId) {
|
||||
return packedPartitionId & MASK_INT_LOWER_24_BITS;
|
||||
}
|
||||
|
||||
public static int getAttemptId(int packedPartitionId) {
|
||||
return packedPartitionId >>> 24;
|
||||
}
|
||||
}
|
||||
@ -155,6 +155,15 @@ message PbRegisterShuffle {
|
||||
int32 numPartitions = 4;
|
||||
}
|
||||
|
||||
message PbRegisterMapPartitionTask {
|
||||
string applicationId = 1;
|
||||
int32 shuffleId = 2;
|
||||
int32 numMappers = 3;
|
||||
int32 mapId = 4;
|
||||
int32 attemptId = 5;
|
||||
int32 partitionId = 6;
|
||||
}
|
||||
|
||||
message PbRegisterShuffleResponse {
|
||||
int32 status = 1;
|
||||
repeated PbPartitionLocation partitionLocations = 2;
|
||||
|
||||
@ -130,6 +130,24 @@ object ControlMessages extends Logging {
|
||||
.build()
|
||||
}
|
||||
|
||||
object RegisterMapPartitionTask {
|
||||
def apply(
|
||||
appId: String,
|
||||
shuffleId: Int,
|
||||
numMappers: Int,
|
||||
mapId: Int,
|
||||
attemptId: Int,
|
||||
partitionId: Int): PbRegisterMapPartitionTask =
|
||||
PbRegisterMapPartitionTask.newBuilder()
|
||||
.setApplicationId(appId)
|
||||
.setShuffleId(shuffleId)
|
||||
.setNumMappers(numMappers)
|
||||
.setMapId(mapId)
|
||||
.setAttemptId(attemptId)
|
||||
.setPartitionId(partitionId)
|
||||
.build()
|
||||
}
|
||||
|
||||
object RegisterShuffleResponse {
|
||||
def apply(
|
||||
status: StatusCode,
|
||||
|
||||
@ -20,6 +20,8 @@ package org.apache.celeborn.common.protocol;
|
||||
import org.junit.Test;
|
||||
import org.roaringbitmap.RoaringBitmap;
|
||||
|
||||
import org.apache.celeborn.common.util.PackedPartitionId;
|
||||
|
||||
public class PartitionLocationSuiteJ {
|
||||
|
||||
private final int partitionId = 0;
|
||||
@ -183,9 +185,13 @@ public class PartitionLocationSuiteJ {
|
||||
bitmap.add(1);
|
||||
bitmap.add(2);
|
||||
bitmap.add(3);
|
||||
|
||||
int attemptId = 10;
|
||||
int rawPartitionId = 1000;
|
||||
int newPartitionId = PackedPartitionId.packedPartitionId(rawPartitionId, attemptId);
|
||||
PartitionLocation location3 =
|
||||
new PartitionLocation(
|
||||
partitionId,
|
||||
newPartitionId,
|
||||
epoch,
|
||||
host,
|
||||
rpcPort,
|
||||
@ -199,7 +205,7 @@ public class PartitionLocationSuiteJ {
|
||||
|
||||
String exp1 =
|
||||
"PartitionLocation[\n"
|
||||
+ " id-epoch:0-0\n"
|
||||
+ " id(rawId-attemptId)-epoch:0(0-0)-0\n"
|
||||
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
|
||||
+ " mode:MASTER\n"
|
||||
+ " peer:(empty)\n"
|
||||
@ -207,7 +213,7 @@ public class PartitionLocationSuiteJ {
|
||||
+ " mapIdBitMap:{}]";
|
||||
String exp2 =
|
||||
"PartitionLocation[\n"
|
||||
+ " id-epoch:0-0\n"
|
||||
+ " id(rawId-attemptId)-epoch:0(0-0)-0\n"
|
||||
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
|
||||
+ " mode:MASTER\n"
|
||||
+ " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
|
||||
@ -215,7 +221,7 @@ public class PartitionLocationSuiteJ {
|
||||
+ " mapIdBitMap:{}]";
|
||||
String exp3 =
|
||||
"PartitionLocation[\n"
|
||||
+ " id-epoch:0-0\n"
|
||||
+ " id(rawId-attemptId)-epoch:167773160(1000-10)-0\n"
|
||||
+ " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n"
|
||||
+ " mode:MASTER\n"
|
||||
+ " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n"
|
||||
|
||||
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.common.util;
|
||||
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
public class PackedPartitionIdSuiteJ {
|
||||
|
||||
@Test
|
||||
public void testNormalPackedPartitionId() {
|
||||
assertTest(0, 0);
|
||||
assertTest(555, 1);
|
||||
assertTest(888, 1);
|
||||
assertTest(10001, 100);
|
||||
|
||||
// testUseMaxPartitionId or MaxAttemptId
|
||||
assertTest(PackedPartitionId.MAXIMUM_PARTITION_ID, 11);
|
||||
assertTest(100, PackedPartitionId.MAXIMUM_ATTEMPT_ID);
|
||||
assertTest(PackedPartitionId.MAXIMUM_PARTITION_ID, PackedPartitionId.MAXIMUM_ATTEMPT_ID);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testAttemptIdGreaterThanMaximumAttemptId() {
|
||||
PackedPartitionId.packedPartitionId(0, PackedPartitionId.MAXIMUM_ATTEMPT_ID + 1);
|
||||
}
|
||||
|
||||
@Test(expected = IllegalArgumentException.class)
|
||||
public void testPartitionIdGreaterThanMaximumPartitionId() {
|
||||
PackedPartitionId.packedPartitionId(PackedPartitionId.MAXIMUM_PARTITION_ID + 1, 1);
|
||||
}
|
||||
|
||||
private void assertTest(int partitionRawId, int attemptId) {
|
||||
int packedPartitionId = PackedPartitionId.packedPartitionId(partitionRawId, attemptId);
|
||||
Assert.assertTrue(partitionRawId == PackedPartitionId.getRawPartitionId(packedPartitionId));
|
||||
Assert.assertTrue(attemptId == PackedPartitionId.getAttemptId(packedPartitionId));
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.tests.client
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.language.implicitConversions
|
||||
|
||||
import org.junit.Assert
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import org.apache.celeborn.client.{LifecycleManager, ShuffleClient, ShuffleClientImpl}
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.identity.UserIdentifier
|
||||
import org.apache.celeborn.common.util.PackedPartitionId
|
||||
import org.apache.celeborn.service.deploy.MiniClusterFeature
|
||||
|
||||
class ShuffleClientSuite extends AnyFunSuite with MiniClusterFeature
|
||||
with BeforeAndAfterAll {
|
||||
val masterPort = 19097
|
||||
val APP = "app-1"
|
||||
var shuffleClient: ShuffleClientImpl = _
|
||||
var lifecycleManager: LifecycleManager = _
|
||||
|
||||
override def beforeAll(): Unit = {
|
||||
val masterConf = Map(
|
||||
"celeborn.master.host" -> "localhost",
|
||||
"celeborn.master.port" -> masterPort.toString)
|
||||
val workerConf = Map(
|
||||
"celeborn.master.endpoints" -> s"localhost:$masterPort")
|
||||
setUpMiniCluster(masterConf, workerConf)
|
||||
|
||||
val clientConf = new CelebornConf()
|
||||
.set("celeborn.master.endpoints", s"localhost:$masterPort")
|
||||
.set("celeborn.push.replicate.enabled", "true")
|
||||
.set("celeborn.push.buffer.size", "256K")
|
||||
lifecycleManager = new LifecycleManager(APP, clientConf)
|
||||
shuffleClient = new ShuffleClientImpl(clientConf, UserIdentifier("mock", "mock"))
|
||||
shuffleClient.setupMetaServiceRef(lifecycleManager.self)
|
||||
}
|
||||
|
||||
test(s"test register map partition task with first attemptId") {
|
||||
val shuffleId = 1
|
||||
val numMappers = 8
|
||||
val mapId = 1
|
||||
val attemptId = 0
|
||||
var location =
|
||||
shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId)
|
||||
Assert.assertEquals(location.getId, PackedPartitionId.packedPartitionId(mapId, attemptId))
|
||||
|
||||
// retry register
|
||||
location = shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId)
|
||||
Assert.assertEquals(location.getId, PackedPartitionId.packedPartitionId(mapId, attemptId))
|
||||
|
||||
// another mapId
|
||||
location =
|
||||
shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 1, attemptId)
|
||||
Assert.assertEquals(location.getId, PackedPartitionId.packedPartitionId(mapId + 1, attemptId))
|
||||
|
||||
// offer and reserve all slots
|
||||
val partitionLocationInfos = lifecycleManager.workerSnapshots(shuffleId).values().asScala
|
||||
val count =
|
||||
partitionLocationInfos.map(r => r.getAllMasterLocations(shuffleId.toString).size()).sum
|
||||
Assert.assertEquals(count, numMappers)
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
// TODO refactor MiniCluster later
|
||||
println("test done")
|
||||
sys.exit(0)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user