[CELEBORN-1071] Support stage rerun for shuffle data lost

### What changes were proposed in this pull request?
If shuffle data is lost and throwing fetch failures is enabled, trigger a stage rerun.

### Why are the changes needed?
To rerun the affected stage when shuffle data is lost.

### Does this PR introduce _any_ user-facing change?
NO.

### How was this patch tested?
GA.

Closes #2894 from FMX/b1701.

Authored-by: mingji <fengmingxiao.fmx@alibaba-inc.com>
Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
This commit is contained in:
mingji 2024-11-12 10:07:26 +08:00 committed by Shuang
parent 7d1da5e915
commit 42d5d426a1
5 changed files with 130 additions and 28 deletions

View File

@ -34,6 +34,7 @@ import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.client.ShuffleClientImpl.ReduceFileGroups
import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
@ -104,8 +105,16 @@ class CelebornShuffleReader[K, C](
val localFetchEnabled = conf.enableReadLocalShuffleFile
val localHostAddress = Utils.localHostName(conf)
val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
// startPartition is irrelevant
val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
var fileGroups: ReduceFileGroups = null
try {
// startPartition is irrelevant
fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
} catch {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
handleFetchExceptions(shuffleId, 0, ce)
case e: Throwable => throw e
}
// host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
val workerRequestMap = new util.HashMap[
String,
@ -245,18 +254,7 @@ class CelebornShuffleReader[K, C](
if (exceptionRef.get() != null) {
exceptionRef.get() match {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
if (throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
throw new FetchFailedException(
null,
handle.shuffleId,
-1,
-1,
partitionId,
SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId,
ce)
} else
throw ce
handleFetchExceptions(handle.shuffleId, partitionId, ce)
case e => throw e
}
}
@ -291,18 +289,7 @@ class CelebornShuffleReader[K, C](
iter
} catch {
case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
if (throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
throw new FetchFailedException(
null,
handle.shuffleId,
-1,
-1,
partitionId,
SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId,
e)
} else
throw e
handleFetchExceptions(handle.shuffleId, partitionId, e)
}
}
@ -382,6 +369,22 @@ class CelebornShuffleReader[K, C](
}
}
/**
 * Converts a Celeborn fetch error into the exception the task should fail with.
 *
 * When `throwsFetchFailure` is set and the shuffle client accepts the fetch-failure report,
 * the error is logged and wrapped in a `FetchFailedException`; in every other case `ce` is
 * rethrown unchanged. This method never returns normally.
 *
 * NOTE(review): the visible call sites pass either the reader's `shuffleId` or
 * `handle.shuffleId` for the `shuffleId` parameter -- confirm which id
 * `reportShuffleFetchFailure` expects as its second argument.
 *
 * @param shuffleId   shuffle id forwarded to `reportShuffleFetchFailure` and used in the messages
 * @param partitionId partition id recorded in the `FetchFailedException`
 * @param ce          the original failure, kept as the cause
 */
private def handleFetchExceptions(shuffleId: Int, partitionId: Int, ce: Throwable) = {
  // `&&` short-circuits: the failure is only reported when throwsFetchFailure is enabled.
  val reportedAsFetchFailure =
    throwsFetchFailure && shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)
  if (!reportedAsFetchFailure) {
    throw ce
  }
  logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", ce)
  val failureMessage = SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId
  throw new FetchFailedException(
    null,
    handle.shuffleId,
    -1,
    -1,
    partitionId,
    failureMessage,
    ce)
}
/**
 * Creates a fresh serializer instance from the serializer of the given shuffle dependency.
 *
 * @param dep shuffle dependency whose `serializer` is used
 * @return the result of `dep.serializer.newInstance()`
 */
def newSerializerInstance(dep: ShuffleDependency[K, _, C]): SerializerInstance =
  dep.serializer.newInstance()

View File

@ -139,7 +139,7 @@ public class ShuffleClientImpl extends ShuffleClient {
private final ReviveManager reviveManager;
protected static class ReduceFileGroups {
public static class ReduceFileGroups {
public Map<Integer, Set<PartitionLocation>> partitionGroups;
public int[] mapAttempts;
public Set<Integer> partitionIds;

View File

@ -68,6 +68,8 @@ class ReducePartitionCommitHandler(
private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
private val shuffleMapperAttempts = JavaUtils.newConcurrentHashMap[Int, Array[Int]]()
private val stageEndTimeout = conf.clientPushStageEndTimeout
// Test hook read from conf.testMockShuffleLost: when true, isStageDataLost reports loss
// based on mockShuffleLostShuffle instead of consulting dataLostShuffleSet.
private val mockShuffleLost = conf.testMockShuffleLost
// Shuffle id that isStageDataLost treats as lost while mockShuffleLost is enabled.
private val mockShuffleLostShuffle = conf.testMockShuffleLostShuffle
private val rpcCacheSize = conf.clientRpcCacheSize
private val rpcCacheConcurrencyLevel = conf.clientRpcCacheConcurrencyLevel
@ -94,7 +96,11 @@ class ReducePartitionCommitHandler(
}
/**
 * Reports whether the data of the given shuffle is considered lost.
 *
 * @param shuffleId shuffle id to check
 * @return with `mockShuffleLost` enabled, true only when `shuffleId` equals
 *         `mockShuffleLostShuffle`; otherwise whether `dataLostShuffleSet` contains it
 */
override def isStageDataLost(shuffleId: Int): Boolean = {
  if (mockShuffleLost) {
    // Test hook: pretend exactly one configured shuffle lost its data.
    mockShuffleLostShuffle == shuffleId
  } else {
    dataLostShuffleSet.contains(shuffleId)
  }
}
override def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean = {

View File

@ -1323,6 +1323,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def testFetchFailure: Boolean = get(TEST_CLIENT_FETCH_FAILURE)
def testMockDestroySlotsFailure: Boolean = get(TEST_CLIENT_MOCK_DESTROY_SLOTS_FAILURE)
def testMockCommitFilesFailure: Boolean = get(TEST_CLIENT_MOCK_COMMIT_FILES_FAILURE)
// Value of `celeborn.test.client.mockShuffleLost` (default false).
def testMockShuffleLost: Boolean = get(TEST_CLIENT_MOCK_SHUFFLE_LOST)
// Value of `celeborn.test.client.mockShuffleLostShuffle` (default 0).
def testMockShuffleLostShuffle: Int = get(TEST_CLIENT_MOCK_SHUFFLE_LOST_SHUFFLE)
def testPushPrimaryDataTimeout: Boolean = get(TEST_CLIENT_PUSH_PRIMARY_DATA_TIMEOUT)
def testPushReplicaDataTimeout: Boolean = get(TEST_WORKER_PUSH_REPLICA_DATA_TIMEOUT)
def testRetryRevive: Boolean = get(TEST_CLIENT_RETRY_REVIVE)
@ -4257,6 +4259,26 @@ object CelebornConf extends Logging {
.booleanConf
.createWithDefault(false)
// Internal test switch (categories "test"/"client"); disabled by default.
// `.internal` is applied once -- the original chain repeated the call.
val TEST_CLIENT_MOCK_SHUFFLE_LOST: ConfigEntry[Boolean] =
  buildConf("celeborn.test.client.mockShuffleLost")
    .internal
    .categories("test", "client")
    .doc("Mock shuffle lost.")
    .version("0.5.2")
    .booleanConf
    .createWithDefault(false)
// Internal test setting (categories "test"/"client") holding a shuffle id; defaults to 0.
// `.internal` is applied once -- the original chain repeated the call.
val TEST_CLIENT_MOCK_SHUFFLE_LOST_SHUFFLE: ConfigEntry[Int] =
  buildConf("celeborn.test.client.mockShuffleLostShuffle")
    .internal
    .categories("test", "client")
    .doc("Mock shuffle lost for shuffle")
    .version("0.5.2")
    .intConf
    .createWithDefault(0)
val CLIENT_PUSH_REPLICATE_ENABLED: ConfigEntry[Boolean] =
buildConf("celeborn.client.push.replicate.enabled")
.withAlternative("celeborn.push.replicate.enabled")

View File

@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.tests.spark
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.protocol.ShuffleMode
/**
 * Runs the same Spark workloads twice -- first on a plain local session, then on a
 * session configured for Celeborn hash shuffle with `throwsFetchFailure` and
 * `mockShuffleLost` enabled -- and checks that both runs produce equal results.
 *
 * `combine`, `groupBy`, `repartition`, `runsql` and `updateSparkConf` come from
 * `SparkTestBase`.
 */
class CelebornShuffleLostSuite extends AnyFunSuite
  with SparkTestBase
  with BeforeAndAfterEach {

  override def beforeEach(): Unit = {
    ShuffleClient.reset()
  }

  override def afterEach(): Unit = {
    System.gc()
  }

  test("celeborn shuffle data lost - hash") {
    val sparkConf = new SparkConf().setAppName("celeborn-demo").setMaster("local[2]")

    // Baseline run. The session is stopped even if a job fails so that the later
    // getOrCreate() cannot pick up a leaked session.
    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    val (combineResult, groupbyResult, repartitionResult, sqlResult) =
      try {
        val baseline = (
          combine(sparkSession),
          groupBy(sparkSession),
          repartition(sparkSession),
          runsql(sparkSession))
        Thread.sleep(3000L)
        baseline
      } finally {
        sparkSession.stop()
      }

    // Celeborn run with fetch failures surfaced to Spark and shuffle loss mocked.
    val conf = updateSparkConf(sparkConf, ShuffleMode.HASH)
    conf.set("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
    conf.set("spark.celeborn.test.client.mockShuffleLost", "true")

    val celebornSparkSession = SparkSession.builder()
      .config(conf)
      .getOrCreate()
    try {
      val celebornCombineResult = combine(celebornSparkSession)
      val celebornGroupbyResult = groupBy(celebornSparkSession)
      val celebornRepartitionResult = repartition(celebornSparkSession)
      val celebornSqlResult = runsql(celebornSparkSession)

      // Each workload is compared once (the original repeated the combine assertion).
      assert(combineResult.equals(celebornCombineResult))
      assert(groupbyResult.equals(celebornGroupbyResult))
      assert(repartitionResult.equals(celebornRepartitionResult))
      assert(sqlResult.equals(celebornSqlResult))
    } finally {
      celebornSparkSession.stop()
    }
  }
}