[CELEBORN-1822] Respond to RegisterShuffle with max epoch PartitionLocation to avoid revive
### What changes were proposed in this pull request? LifecycleManager responds to RegisterShuffle with the max epoch PartitionLocation. ### Why are the changes needed? Newly spun up executors in a Spark job will still get the PartitionLocations with the min epoch of a lost Celeborn worker. These executors will fail to connect to the lost worker and then call into revive to get the latest PartitionLocation for a given partitionId in `ChangePartitionManager.getLatestPartition()`. Returning the max-epoch locations can reduce such revive requests. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT. Closes #3135 from zaynt4606/clb1822. Authored-by: zhengtao <shuaizhentao.szt@alibaba-inc.com> Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
This commit is contained in:
parent
c1fb94d6e3
commit
b5fab42604
@ -145,6 +145,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
locations: util.List[PartitionLocation]): Unit = {
|
||||
val map = latestPartitionLocation.computeIfAbsent(shuffleId, newMapFunc)
|
||||
locations.asScala.foreach(location => map.put(location.getId, location))
|
||||
invalidateLatestMaxLocsCache(shuffleId)
|
||||
}
|
||||
|
||||
case class RegisterCallContext(context: RpcCallContext, partitionId: Int = -1) {
|
||||
@ -547,12 +548,12 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
shuffleId,
|
||||
rpcContext,
|
||||
partitionId,
|
||||
getInitialLocs(shuffleId, p => p.getId == partitionId))
|
||||
getLatestLocs(shuffleId, p => p.getId == partitionId))
|
||||
case PartitionType.REDUCE =>
|
||||
if (rpcContext.isInstanceOf[LocalNettyRpcCallContext]) {
|
||||
context.reply(RegisterShuffleResponse(
|
||||
StatusCode.SUCCESS,
|
||||
getInitialLocs(shuffleId, p => p.getEpoch == 0)))
|
||||
getLatestLocs(shuffleId, _ => true)))
|
||||
} else {
|
||||
val cachedMsg = registerShuffleResponseRpcCache.get(
|
||||
shuffleId,
|
||||
@ -561,7 +562,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
rpcContext.asInstanceOf[RemoteNettyRpcCallContext].nettyEnv.serialize(
|
||||
RegisterShuffleResponse(
|
||||
StatusCode.SUCCESS,
|
||||
getInitialLocs(shuffleId, p => p.getEpoch == 0)))
|
||||
getLatestLocs(shuffleId, _ => true)))
|
||||
}
|
||||
})
|
||||
rpcContext.asInstanceOf[RemoteNettyRpcCallContext].callback.onSuccess(cachedMsg)
|
||||
@ -580,13 +581,23 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
}
|
||||
}
|
||||
|
||||
def getInitialLocs(
|
||||
def getLatestLocs(
|
||||
shuffleId: Int,
|
||||
partitionLocationFilter: PartitionLocation => Boolean): Array[PartitionLocation] = {
|
||||
workerSnapshots(shuffleId)
|
||||
.values()
|
||||
.asScala
|
||||
.flatMap(_.getAllPrimaryLocationsWithMinEpoch())
|
||||
.flatMap(
|
||||
_.getAllPrimaryLocationsWithMaxEpoch()
|
||||
) // get the partition with latest epoch of each worker
|
||||
.foldLeft(Map.empty[Int, PartitionLocation]) { (partitionLocationMap, partitionLocation) =>
|
||||
partitionLocationMap.get(partitionLocation.getId) match {
|
||||
case Some(existing) if existing.getEpoch >= partitionLocation.getEpoch =>
|
||||
partitionLocationMap
|
||||
case _ => partitionLocationMap + (partitionLocation.getId -> partitionLocation)
|
||||
}
|
||||
} // get the partition with latest epoch of all the partitions
|
||||
.values
|
||||
.filter(partitionLocationFilter)
|
||||
.toArray
|
||||
}
|
||||
@ -1824,6 +1835,10 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
cancelShuffleCallback = Some(callback)
|
||||
}
|
||||
|
||||
def invalidateLatestMaxLocsCache(shuffleId: Int): Unit = {
|
||||
registerShuffleResponseRpcCache.invalidate(shuffleId)
|
||||
}
|
||||
|
||||
// Initialize at the end of LifecycleManager construction.
|
||||
initialize()
|
||||
|
||||
|
||||
@ -89,22 +89,22 @@ class ShufflePartitionLocationInfo(val workerInfo: WorkerInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
def getAllPrimaryLocationsWithMinEpoch(): ArrayBuffer[PartitionLocation] = {
|
||||
def getAllPrimaryLocationsWithMaxEpoch(): ArrayBuffer[PartitionLocation] = {
|
||||
val set = new ArrayBuffer[PartitionLocation](primaryPartitionLocations.size())
|
||||
val locationsIterator = primaryPartitionLocations.values().iterator()
|
||||
while (locationsIterator.hasNext) {
|
||||
val locationIterator = locationsIterator.next().iterator()
|
||||
var min: PartitionLocation = null
|
||||
var max: PartitionLocation = null
|
||||
if (locationIterator.hasNext) {
|
||||
min = locationIterator.next();
|
||||
max = locationIterator.next();
|
||||
}
|
||||
while (locationIterator.hasNext) {
|
||||
val next = locationIterator.next()
|
||||
if (min.getEpoch > next.getEpoch) {
|
||||
min = next;
|
||||
if (max.getEpoch < next.getEpoch) {
|
||||
max = next;
|
||||
}
|
||||
}
|
||||
set += min;
|
||||
set += max;
|
||||
}
|
||||
set
|
||||
}
|
||||
|
||||
@ -60,9 +60,9 @@ class ShufflePartitionLocationInfoSuite extends CelebornFunSuite {
|
||||
assertEquals(shufflePartitionLocationInfo.getReplicaPartitions(Some(0)).size(), 1)
|
||||
assertEquals(shufflePartitionLocationInfo.getReplicaPartitions(Some(1)).size(), 1)
|
||||
|
||||
// test get min epoch
|
||||
val locations = shufflePartitionLocationInfo.getAllPrimaryLocationsWithMinEpoch()
|
||||
assertTrue(locations.contains(partitionLocation00) && locations.contains(partitionLocation11))
|
||||
// test get max epoch
|
||||
val locations = shufflePartitionLocationInfo.getAllPrimaryLocationsWithMaxEpoch()
|
||||
assertTrue(locations.contains(partitionLocation02) && locations.contains(partitionLocation12))
|
||||
|
||||
// test remove
|
||||
assertEquals(shufflePartitionLocationInfo.removePrimaryPartitions(0).size(), 3)
|
||||
|
||||
@ -0,0 +1,174 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.tests.client
|
||||
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.commons.lang3.RandomStringUtils
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.identity.UserIdentifier
|
||||
import org.apache.celeborn.common.internal.Logging
|
||||
import org.apache.celeborn.common.meta.WorkerInfo
|
||||
import org.apache.celeborn.common.protocol.PartitionLocation
|
||||
import org.apache.celeborn.service.deploy.MiniClusterFeature
|
||||
|
||||
class LifecycleManagerReserveSlotsSuite extends AnyFunSuite
|
||||
with Logging with MiniClusterFeature with BeforeAndAfterAll {
|
||||
|
||||
var masterEndpoint = ""
|
||||
override def beforeAll(): Unit = {
|
||||
val conf = Map("celeborn.worker.flusher.buffer.size" -> "0")
|
||||
|
||||
logInfo("test initialized , setup Celeborn mini cluster")
|
||||
val (master, _) = setupMiniClusterWithRandomPorts(conf, conf, 2)
|
||||
masterEndpoint = master.conf.get(CelebornConf.MASTER_ENDPOINTS.key)
|
||||
}
|
||||
|
||||
override def afterAll(): Unit = {
|
||||
logInfo("all test complete , stop Celeborn mini cluster")
|
||||
super.shutdownMiniCluster()
|
||||
}
|
||||
|
||||
test("LifecycleManager Respond to RegisterShuffle with max epoch PartitionLocation") {
|
||||
val SHUFFLE_ID = 0
|
||||
val MAP_ID = 0
|
||||
val ATTEMPT_ID = 0
|
||||
val MAP_NUM = 1
|
||||
val PARTITION_NUM = 3
|
||||
val APP = s"app-${System.currentTimeMillis()}"
|
||||
|
||||
val clientConf = new CelebornConf()
|
||||
.set(CelebornConf.MASTER_ENDPOINTS.key, masterEndpoint)
|
||||
.set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "5")
|
||||
.set(CelebornConf.SHUFFLE_PARTITION_SPLIT_THRESHOLD.key, "1K")
|
||||
val lifecycleManager = new LifecycleManager(APP, clientConf)
|
||||
val shuffleClient1 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock"))
|
||||
shuffleClient1.setupLifecycleManagerRef(lifecycleManager.self)
|
||||
|
||||
// ping and reserveSlots
|
||||
val DATA0 = RandomStringUtils.secure().next(10).getBytes(StandardCharsets.UTF_8)
|
||||
shuffleClient1.pushData(
|
||||
SHUFFLE_ID,
|
||||
MAP_ID,
|
||||
ATTEMPT_ID,
|
||||
0,
|
||||
DATA0,
|
||||
0,
|
||||
DATA0.length,
|
||||
MAP_NUM,
|
||||
PARTITION_NUM)
|
||||
|
||||
// find the worker that has at least 2 partitions
|
||||
val partitionLocationMap1 =
|
||||
shuffleClient1.getPartitionLocation(SHUFFLE_ID, MAP_NUM, PARTITION_NUM)
|
||||
|
||||
val worker2PartitionIds = mutable.Map.empty[WorkerInfo, ArrayBuffer[Int]]
|
||||
for (partitionId <- 0 until PARTITION_NUM) {
|
||||
val partitionLocation = partitionLocationMap1.get(partitionId)
|
||||
assert(partitionLocation.getEpoch == 0)
|
||||
worker2PartitionIds
|
||||
.getOrElseUpdate(partitionLocation.getWorker, ArrayBuffer.empty)
|
||||
.append(partitionId)
|
||||
}
|
||||
val partitions = worker2PartitionIds.values.filter(_.size >= 2).head
|
||||
assert(partitions.length >= 2)
|
||||
|
||||
// prepare merged data
|
||||
val PARTITION0_DATA = RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8)
|
||||
shuffleClient1.mergeData(
|
||||
SHUFFLE_ID,
|
||||
MAP_ID,
|
||||
ATTEMPT_ID,
|
||||
partitions(0),
|
||||
PARTITION0_DATA,
|
||||
0,
|
||||
PARTITION0_DATA.length,
|
||||
MAP_NUM,
|
||||
PARTITION_NUM)
|
||||
|
||||
val PARTITION1_DATA = RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8)
|
||||
shuffleClient1.mergeData(
|
||||
SHUFFLE_ID,
|
||||
MAP_ID,
|
||||
ATTEMPT_ID,
|
||||
partitions(1),
|
||||
PARTITION1_DATA,
|
||||
0,
|
||||
PARTITION1_DATA.length,
|
||||
MAP_NUM,
|
||||
PARTITION_NUM)
|
||||
|
||||
// pushData until partition(0) is split
|
||||
val GIANT_DATA =
|
||||
RandomStringUtils.secure().next(1024 * 100).getBytes(StandardCharsets.UTF_8)
|
||||
for (_ <- 0 until 5) {
|
||||
shuffleClient1.pushData(
|
||||
SHUFFLE_ID,
|
||||
MAP_ID,
|
||||
ATTEMPT_ID,
|
||||
partitions(0),
|
||||
GIANT_DATA,
|
||||
0,
|
||||
GIANT_DATA.length,
|
||||
MAP_NUM,
|
||||
PARTITION_NUM)
|
||||
}
|
||||
|
||||
for (_ <- 0 until 5) {
|
||||
val TRIGGER_DATA = RandomStringUtils.secure().next(1024).getBytes(StandardCharsets.UTF_8)
|
||||
shuffleClient1.pushData(
|
||||
SHUFFLE_ID,
|
||||
MAP_ID,
|
||||
ATTEMPT_ID,
|
||||
partitions(0),
|
||||
TRIGGER_DATA,
|
||||
0,
|
||||
TRIGGER_DATA.length,
|
||||
MAP_NUM,
|
||||
PARTITION_NUM)
|
||||
Thread.sleep(5 * 1000) // wait for flush
|
||||
}
|
||||
|
||||
assert(
|
||||
partitionLocationMap1.get(partitions(0)).getEpoch > 0
|
||||
) // means partition(0) will be split
|
||||
|
||||
// push merged data, we expect that partition(0) will be split, while partition(1) will not be split
|
||||
shuffleClient1.pushMergedData(SHUFFLE_ID, MAP_ID, ATTEMPT_ID)
|
||||
shuffleClient1.mapperEnd(SHUFFLE_ID, MAP_ID, ATTEMPT_ID, MAP_NUM)
|
||||
// partition(1) will not be split
|
||||
assert(partitionLocationMap1.get(partitions(1)).getEpoch == 0)
|
||||
|
||||
val shuffleClient2 = new ShuffleClientImpl(APP, clientConf, UserIdentifier("mock", "mock"))
|
||||
shuffleClient2.setupLifecycleManagerRef(lifecycleManager.self)
|
||||
val partitionLocationMap2 =
|
||||
shuffleClient2.getPartitionLocation(SHUFFLE_ID, MAP_NUM, PARTITION_NUM)
|
||||
|
||||
// lifecycleManager response with the latest epoch(epoch of partition(0) is larger than 0 caused by split)
|
||||
assert(partitionLocationMap2.get(partitions(0)).getEpoch > 0)
|
||||
// epoch of partition(1) is 0 without split
|
||||
assert(partitionLocationMap2.get(partitions(1)).getEpoch == 0)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user