diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 73b0dc7d7..a718a2b8b 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -145,6 +145,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends locations: util.List[PartitionLocation]): Unit = { val map = latestPartitionLocation.computeIfAbsent(shuffleId, newMapFunc) locations.asScala.foreach(location => map.put(location.getId, location)) + invalidateLatestMaxLocsCache(shuffleId) } case class RegisterCallContext(context: RpcCallContext, partitionId: Int = -1) { @@ -547,12 +548,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleId, rpcContext, partitionId, - getInitialLocs(shuffleId, p => p.getId == partitionId)) + getLatestLocs(shuffleId, p => p.getId == partitionId)) case PartitionType.REDUCE => if (rpcContext.isInstanceOf[LocalNettyRpcCallContext]) { context.reply(RegisterShuffleResponse( StatusCode.SUCCESS, - getInitialLocs(shuffleId, p => p.getEpoch == 0))) + getLatestLocs(shuffleId, _ => true))) } else { val cachedMsg = registerShuffleResponseRpcCache.get( shuffleId, @@ -561,7 +562,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends rpcContext.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize( RegisterShuffleResponse( StatusCode.SUCCESS, - getInitialLocs(shuffleId, p => p.getEpoch == 0))) + getLatestLocs(shuffleId, _ => true))) } }) rpcContext.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg) @@ -580,13 +581,23 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } } - def getInitialLocs( + def getLatestLocs( shuffleId: Int, partitionLocationFilter: PartitionLocation => Boolean): Array[PartitionLocation] = { 
workerSnapshots(shuffleId) .values() .asScala - .flatMap(_.getAllPrimaryLocationsWithMinEpoch()) + .flatMap( + _.getAllPrimaryLocationsWithMaxEpoch() + ) // from each worker, take the max-epoch primary location of every partition it holds + .foldLeft(Map.empty[Int, PartitionLocation]) { (partitionLocationMap, partitionLocation) => + partitionLocationMap.get(partitionLocation.getId) match { + case Some(existing) if existing.getEpoch >= partitionLocation.getEpoch => + partitionLocationMap + case _ => partitionLocationMap + (partitionLocation.getId -> partitionLocation) + } + } // then, across workers, keep only the max-epoch location for each partition id + .values .filter(partitionLocationFilter) .toArray } @@ -1824,6 +1835,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends cancelShuffleCallback = Some(callback) } + def invalidateLatestMaxLocsCache(shuffleId: Int): Unit = { + registerShuffleResponseRpcCache.invalidate(shuffleId) + } + // Initialize at the end of LifecycleManager construction. initialize() diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala b/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala index a19f182eb..bfb6c7b39 100644 --- a/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala +++ b/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala @@ -89,22 +89,22 @@ class ShufflePartitionLocationInfo(val workerInfo: WorkerInfo) { } } - def getAllPrimaryLocationsWithMinEpoch(): ArrayBuffer[PartitionLocation] = { + def getAllPrimaryLocationsWithMaxEpoch(): ArrayBuffer[PartitionLocation] = { val set = new ArrayBuffer[PartitionLocation](primaryPartitionLocations.size()) val locationsIterator = primaryPartitionLocations.values().iterator() while (locationsIterator.hasNext) { val locationIterator = locationsIterator.next().iterator() - var min: PartitionLocation = null + var max: PartitionLocation = null if 
(locationIterator.hasNext) { - min = locationIterator.next(); + max = locationIterator.next(); } while (locationIterator.hasNext) { val next = locationIterator.next() - if (min.getEpoch > next.getEpoch) { - min = next; + if (max.getEpoch < next.getEpoch) { + max = next; } } - set += min; + set += max; } set } } diff --git a/common/src/test/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfoSuite.scala b/common/src/test/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfoSuite.scala index 6b7851109..03c3cd491 100644 --- a/common/src/test/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfoSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfoSuite.scala @@ -60,9 +60,9 @@ class ShufflePartitionLocationInfoSuite extends CelebornFunSuite { assertEquals(shufflePartitionLocationInfo.getReplicaPartitions(Some(0)).size(), 1) assertEquals(shufflePartitionLocationInfo.getReplicaPartitions(Some(1)).size(), 1) - // test get min epoch - val locations = shufflePartitionLocationInfo.getAllPrimaryLocationsWithMinEpoch() - assertTrue(locations.contains(partitionLocation00) && locations.contains(partitionLocation11)) + // test get max epoch + val locations = shufflePartitionLocationInfo.getAllPrimaryLocationsWithMaxEpoch() + assertTrue(locations.contains(partitionLocation02) && locations.contains(partitionLocation12)) // test remove assertEquals(shufflePartitionLocationInfo.removePrimaryPartitions(0).size(), 3) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala new file mode 100644 index 000000000..ee9317a5e --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.tests.client + +import java.nio.charset.StandardCharsets + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.commons.lang3.RandomStringUtils +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl} +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.identity.UserIdentifier +import org.apache.celeborn.common.internal.Logging +import org.apache.celeborn.common.meta.WorkerInfo +import org.apache.celeborn.common.protocol.PartitionLocation +import org.apache.celeborn.service.deploy.MiniClusterFeature + +class LifecycleManagerReserveSlotsSuite extends AnyFunSuite + with Logging with MiniClusterFeature with BeforeAndAfterAll { + + var masterEndpoint = "" + override def beforeAll(): Unit = { + val conf = Map("celeborn.worker.flusher.buffer.size" -> "0") + + logInfo("test initialized , setup Celeborn mini cluster") + val (master, _) = setupMiniClusterWithRandomPorts(conf, conf, 2) + masterEndpoint = master.conf.get(CelebornConf.MASTER_ENDPOINTS.key) + } + + override def afterAll(): Unit = { + logInfo("all test complete , stop Celeborn mini 
cluster") + super.shutdownMiniCluster() + } + + test("LifecycleManager Respond to RegisterShuffle with max epoch PartitionLocation") { + val SHUFFLE_ID = 0 + val MAP_ID = 0 + val ATTEMPT_ID = 0 + val MAP_NUM = 1 + val PARTITION_NUM = 3 + val APP = s"app-${System.currentTimeMillis()}" + + val clientConf = new CelebornConf() + .set(CelebornConf.MASTER_ENDPOINTS.key, masterEndpoint) + .set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "5") + .set(CelebornConf.SHUFFLE_PARTITION_SPLIT_THRESHOLD.key, "1K") + val lifecycleManager = new LifecycleManager(APP, clientConf) + val shuffleClient1 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + shuffleClient1.setupLifecycleManagerRef(lifecycleManager.self) + + // ping and reserveSlots + val DATA0 = RandomStringUtils.secure().next(10).getBytes(StandardCharsets.UTF_8) + shuffleClient1.pushData( + SHUFFLE_ID, + MAP_ID, + ATTEMPT_ID, + 0, + DATA0, + 0, + DATA0.length, + MAP_NUM, + PARTITION_NUM) + + // find the worker that has at least 2 partitions + val partitionLocationMap1 = + shuffleClient1.getPartitionLocation(SHUFFLE_ID, MAP_NUM, PARTITION_NUM) + + val worker2PartitionIds = mutable.Map.empty[WorkerInfo, ArrayBuffer[Int]] + for (partitionId <- 0 until PARTITION_NUM) { + val partitionLocation = partitionLocationMap1.get(partitionId) + assert(partitionLocation.getEpoch == 0) + worker2PartitionIds + .getOrElseUpdate(partitionLocation.getWorker, ArrayBuffer.empty) + .append(partitionId) + } + val partitions = worker2PartitionIds.values.filter(_.size >= 2).head + assert(partitions.length >= 2) + + // prepare merged data + val PARTITION0_DATA = RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8) + shuffleClient1.mergeData( + SHUFFLE_ID, + MAP_ID, + ATTEMPT_ID, + partitions(0), + PARTITION0_DATA, + 0, + PARTITION0_DATA.length, + MAP_NUM, + PARTITION_NUM) + + val PARTITION1_DATA = RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8) + shuffleClient1.mergeData( + 
SHUFFLE_ID, + MAP_ID, + ATTEMPT_ID, + partitions(1), + PARTITION1_DATA, + 0, + PARTITION1_DATA.length, + MAP_NUM, + PARTITION_NUM) + + // pushData until partition(0) is split + val GIANT_DATA = + RandomStringUtils.secure().next(1024 * 100).getBytes(StandardCharsets.UTF_8) + for (_ <- 0 until 5) { + shuffleClient1.pushData( + SHUFFLE_ID, + MAP_ID, + ATTEMPT_ID, + partitions(0), + GIANT_DATA, + 0, + GIANT_DATA.length, + MAP_NUM, + PARTITION_NUM) + } + + for (_ <- 0 until 5) { + val TRIGGER_DATA = RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8) + shuffleClient1.pushData( + SHUFFLE_ID, + MAP_ID, + ATTEMPT_ID, + partitions(0), + TRIGGER_DATA, + 0, + TRIGGER_DATA.length, + MAP_NUM, + PARTITION_NUM) + Thread.sleep(5 * 1000) // wait for flush + } + + assert( + partitionLocationMap1.get(partitions(0)).getEpoch > 0 + ) // an epoch above 0 means partition(0) has been split + + // push the merged data; partition(0) is expected to be split, while partition(1) is not + shuffleClient1.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID) + shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM) + // partition(1) must still be at epoch 0, i.e. not split + assert(partitionLocationMap1.get(partitions(1)).getEpoch == 0) + + val shuffleClient2 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + shuffleClient2.setupLifecycleManagerRef(lifecycleManager.self) + val partitionLocationMap2 = + shuffleClient2.getPartitionLocation(SHUFFLE_ID, MAP_NUM, PARTITION_NUM) + + // the LifecycleManager responds with the latest epoch: partition(0) has an epoch above 0 because of the split + assert(partitionLocationMap2.get(partitions(0)).getEpoch > 0) + // partition(1) was never split, so its epoch stays 0 + assert(partitionLocationMap2.get(partitions(1)).getEpoch == 0) + } +}