From b5fab4260453b384c12bb520570622aa3c9844e0 Mon Sep 17 00:00:00 2001 From: zhengtao Date: Fri, 14 Mar 2025 16:08:05 +0800 Subject: [PATCH] [CELEBORN-1822] Respond to RegisterShuffle with max epoch PartitionLocation to avoid revive ### What changes were proposed in this pull request? LifecycleManager respond to RegisterShuffle with max epoch PartitionLocation. ### Why are the changes needed? Newly spun up executors in a Spark job will still get the partitionLocations with the minEpoch of the celeborn lost worker. These executors will fail to connect to the lost worker and then call into revive to get the latest PartitionLocation for a given partitionId in `ChangePartitionManager.getLatestPartition()`. Return with max epoch can reduce such revive requests. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT. Closes #3135 from zaynt4606/clb1822. Authored-by: zhengtao Signed-off-by: Shuang --- .../celeborn/client/LifecycleManager.scala | 25 ++- .../meta/ShufflePartitionLocationInfo.scala | 12 +- .../ShufflePartitionLocationInfoSuite.scala | 6 +- .../LifecycleManagerReserveSlotsSuite.scala | 174 ++++++++++++++++++ 4 files changed, 203 insertions(+), 14 deletions(-) create mode 100644 tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 73b0dc7d7..a718a2b8b 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -145,6 +145,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends locations: util.List[PartitionLocation]): Unit = { val map = latestPartitionLocation.computeIfAbsent(shuffleId, newMapFunc) locations.asScala.foreach(location => map.put(location.getId, location)) + invalidateLatestMaxLocsCache(shuffleId) } case class RegisterCallContext(context: RpcCallContext, partitionId: Int = -1) { @@ -547,12 +548,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends shuffleId, rpcContext, partitionId, - getInitialLocs(shuffleId, p => p.getId == partitionId)) + getLatestLocs(shuffleId, p => p.getId == partitionId)) case PartitionType.REDUCE => if (rpcContext.isInstanceOf[LocalNettyRpcCallContext]) { context.reply(RegisterShuffleResponse( StatusCode.SUCCESS, - getInitialLocs(shuffleId, p => p.getEpoch == 0))) + getLatestLocs(shuffleId, _ => true))) } else { val cachedMsg = registerShuffleResponseRpcCache.get( shuffleId, @@ -561,7 +562,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends rpcContext.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize( RegisterShuffleResponse( StatusCode.SUCCESS, - getInitialLocs(shuffleId, p => p.getEpoch == 0))) + getLatestLocs(shuffleId, _ => true))) } }) rpcContext.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg) @@ -580,13 +581,23 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends } } - def getInitialLocs( + def getLatestLocs( shuffleId: Int, partitionLocationFilter: PartitionLocation => Boolean): Array[PartitionLocation] = { workerSnapshots(shuffleId) .values() .asScala - .flatMap(_.getAllPrimaryLocationsWithMinEpoch()) + .flatMap( + _.getAllPrimaryLocationsWithMaxEpoch() + ) // get the partition with latest epoch of each worker + .foldLeft(Map.empty[Int, PartitionLocation]) { (partitionLocationMap, partitionLocation) => + partitionLocationMap.get(partitionLocation.getId) match { + case Some(existing) if existing.getEpoch >= partitionLocation.getEpoch => + partitionLocationMap + case _ => partitionLocationMap + (partitionLocation.getId -> partitionLocation) + } + } // get the partition with latest epoch of all the partitions + .values .filter(partitionLocationFilter) .toArray } @@ -1824,6 +1835,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends cancelShuffleCallback = Some(callback) } + def invalidateLatestMaxLocsCache(shuffleId: Int): Unit = { + registerShuffleResponseRpcCache.invalidate(shuffleId) + } + // Initialize at the end of LifecycleManager construction. initialize() diff --git a/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala b/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala index a19f182eb..bfb6c7b39 100644 --- a/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala +++ b/common/src/main/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfo.scala @@ -89,22 +89,22 @@ class ShufflePartitionLocationInfo(val workerInfo: WorkerInfo) { } } - def getAllPrimaryLocationsWithMinEpoch(): ArrayBuffer[PartitionLocation] = { + def getAllPrimaryLocationsWithMaxEpoch(): ArrayBuffer[PartitionLocation] = { val set = new ArrayBuffer[PartitionLocation](primaryPartitionLocations.size()) val locationsIterator = primaryPartitionLocations.values().iterator() while (locationsIterator.hasNext) { val locationIterator = locationsIterator.next().iterator() - var min: PartitionLocation = null + var max: PartitionLocation = null if (locationIterator.hasNext) { - min = locationIterator.next(); + max = locationIterator.next(); } while (locationIterator.hasNext) { val next = locationIterator.next() - if (min.getEpoch > next.getEpoch) { - min = next; + if (max.getEpoch < next.getEpoch) { + max = next; } } - set += min; + set += max; } set } diff --git a/common/src/test/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfoSuite.scala b/common/src/test/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfoSuite.scala index 6b7851109..03c3cd491 100644 --- a/common/src/test/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfoSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/meta/ShufflePartitionLocationInfoSuite.scala @@ -60,9 +60,9 @@ class ShufflePartitionLocationInfoSuite extends CelebornFunSuite { assertEquals(shufflePartitionLocationInfo.getReplicaPartitions(Some(0)).size(), 1) assertEquals(shufflePartitionLocationInfo.getReplicaPartitions(Some(1)).size(), 1) - // test get min epoch - val locations = shufflePartitionLocationInfo.getAllPrimaryLocationsWithMinEpoch() - assertTrue(locations.contains(partitionLocation00) && locations.contains(partitionLocation11)) + // test get max epoch + val locations = shufflePartitionLocationInfo.getAllPrimaryLocationsWithMaxEpoch() + assertTrue(locations.contains(partitionLocation02) && locations.contains(partitionLocation12)) // test remove assertEquals(shufflePartitionLocationInfo.removePrimaryPartitions(0).size(), 3) diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala new file mode 100644 index 000000000..ee9317a5e --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerReserveSlotsSuite.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.tests.client + +import java.nio.charset.StandardCharsets + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.commons.lang3.RandomStringUtils +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl} +import org.apache.celeborn.common.CelebornConf +import org.apache.celeborn.common.identity.UserIdentifier +import org.apache.celeborn.common.internal.Logging +import org.apache.celeborn.common.meta.WorkerInfo +import org.apache.celeborn.common.protocol.PartitionLocation +import org.apache.celeborn.service.deploy.MiniClusterFeature + +class LifecycleManagerReserveSlotsSuite extends AnyFunSuite + with Logging with MiniClusterFeature with BeforeAndAfterAll { + + var masterEndpoint = "" + override def beforeAll(): Unit = { + val conf = Map("celeborn.worker.flusher.buffer.size" -> "0") + + logInfo("test initialized , setup Celeborn mini cluster") + val (master, _) = setupMiniClusterWithRandomPorts(conf, conf, 2) + masterEndpoint = master.conf.get(CelebornConf.MASTER_ENDPOINTS.key) + } + + override def afterAll(): Unit = { + logInfo("all test complete , stop Celeborn mini cluster") + super.shutdownMiniCluster() + } + + test("LifecycleManager Respond to RegisterShuffle with max epoch PartitionLocation") { + val SHUFFLE_ID = 0 + val MAP_ID = 0 + val ATTEMPT_ID = 0 + val MAP_NUM = 1 + val PARTITION_NUM = 3 + val APP = s"app-${System.currentTimeMillis()}" + + val clientConf = new CelebornConf() + .set(CelebornConf.MASTER_ENDPOINTS.key, masterEndpoint) + .set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "5") + .set(CelebornConf.SHUFFLE_PARTITION_SPLIT_THRESHOLD.key, "1K") + val lifecycleManager = new LifecycleManager(APP, clientConf) + val shuffleClient1 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + shuffleClient1.setupLifecycleManagerRef(lifecycleManager.self) + + // ping and reserveSlots + val DATA0 = RandomStringUtils.secure().next(10).getBytes(StandardCharsets.UTF_8) + shuffleClient1.pushData( + SHUFFLE_ID, + MAP_ID, + ATTEMPT_ID, + 0, + DATA0, + 0, + DATA0.length, + MAP_NUM, + PARTITION_NUM) + + // find the worker that has at least 2 partitions + val partitionLocationMap1 = + shuffleClient1.getPartitionLocation(SHUFFLE_ID, MAP_NUM, PARTITION_NUM) + + val worker2PartitionIds = mutable.Map.empty[WorkerInfo, ArrayBuffer[Int]] + for (partitionId <- 0 until PARTITION_NUM) { + val partitionLocation = partitionLocationMap1.get(partitionId) + assert(partitionLocation.getEpoch == 0) + worker2PartitionIds + .getOrElseUpdate(partitionLocation.getWorker, ArrayBuffer.empty) + .append(partitionId) + } + val partitions = worker2PartitionIds.values.filter(_.size >= 2).head + assert(partitions.length >= 2) + + // prepare merged data + val PARTITION0_DATA = RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8) + shuffleClient1.mergeData( + SHUFFLE_ID, + MAP_ID, + ATTEMPT_ID, + partitions(0), + PARTITION0_DATA, + 0, + PARTITION0_DATA.length, + MAP_NUM, + PARTITION_NUM) + + val PARTITION1_DATA = RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8) + shuffleClient1.mergeData( + SHUFFLE_ID, + MAP_ID, + ATTEMPT_ID, + partitions(1), + PARTITION1_DATA, + 0, + PARTITION1_DATA.length, + MAP_NUM, + PARTITION_NUM) + + // pushData until partition(0) is split + val GIANT_DATA = + RandomStringUtils.secure().next(1024 * 100).getBytes(StandardCharsets.UTF_8) + for (_ <- 0 until 5) { + shuffleClient1.pushData( + SHUFFLE_ID, + MAP_ID, + ATTEMPT_ID, + partitions(0), + GIANT_DATA, + 0, + GIANT_DATA.length, + MAP_NUM, + PARTITION_NUM) + } + + for (_ <- 0 until 5) { + val TRIGGER_DATA = RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8) + shuffleClient1.pushData( + SHUFFLE_ID, + MAP_ID, + ATTEMPT_ID, + partitions(0), + TRIGGER_DATA, + 0, + TRIGGER_DATA.length, + MAP_NUM, + PARTITION_NUM) + Thread.sleep(5 * 1000) // wait for flush + } + + assert( + partitionLocationMap1.get(partitions(0)).getEpoch > 0 + ) // means partition(0) will be split + + // push merged data, we expect that partition(0) will be split, while partition(1) will not be split + shuffleClient1.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID) + shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM) + // partition(1) will not be split + assert(partitionLocationMap1.get(partitions(1)).getEpoch == 0) + + val shuffleClient2 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock")) + shuffleClient2.setupLifecycleManagerRef(lifecycleManager.self) + val partitionLocationMap2 = + shuffleClient2.getPartitionLocation(SHUFFLE_ID, MAP_NUM, PARTITION_NUM) + + // lifecycleManager response with the latest epoch(epoch of partition(0) is larger than 0 caused by split) + assert(partitionLocationMap2.get(partitions(0)).getEpoch > 0) + // epoch of partition(1) is 0 without split + assert(partitionLocationMap2.get(partitions(1)).getEpoch == 0) + } +}