[CELEBORN-1720] Prevent stage re-run if another task attempt is running or successful

### What changes were proposed in this pull request?
Prevent stage re-run if another task attempt is running.

If a shuffle read task cannot read the shuffle data and another attempt of the same task is running or successful, just throw CelebornIOException instead of FetchFailedException.

The application will not fail before reaching the task maxFailures limit.

<img width="1610" alt="image" src="https://github.com/user-attachments/assets/ffc6d80e-7c90-4729-adf7-6f8c46a8f226">

### Why are the changes needed?
I hit the issue below because I set the wrong parameter: I should have set `spark.celeborn.data.io.connectTime=30s` but instead set `spark.celeborn.data.io.connectionTime=30s`, and disk IO utilization was high at that time.

0. speculation is enabled
1. one task failed to fetch shuffle 0 in stage 5.
2. then it triggered the stage 0 re-run (stage 4)
3. then stage 5 retry, however, no task run in stage 5 (retry 1)
<img width="1212" alt="image" src="https://github.com/user-attachments/assets/555f36b0-0f0d-452d-af0b-1573601165e2">
4. because the speculation task succeeded, so no task in stage 5(retry 1)
<img width="1715" alt="image" src="https://github.com/user-attachments/assets/7f315149-1d5c-4c32-ae9b-87b099b3297f">

Because a stage re-run is heavy, we should ignore the shuffle fetch failure if there is another task attempt still running.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

Unit test for the SparkUtils method only, because it is not feasible to add a unit test for speculation.

d5da49d56d/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala (L236-L244)

<img width="867" alt="image" src="https://github.com/user-attachments/assets/f93bd14f-0f34-4c81-a8db-13be511405d9">

For local master, it would not start the speculationScheduler.

d5da49d56d/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala (L322-L346)

<img width="1010" alt="image" src="https://github.com/user-attachments/assets/477729a4-2fc1-47e9-b128-522c6e2ceb48">

and it is also not allowed to launch speculative task on the same host.

Closes #2921 from turboFei/task_id.

Lead-authored-by: Wang, Fei <fwang12@ebay.com>
Co-authored-by: Fei Wang <cn.feiwang@gmail.com>
Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
This commit is contained in:
Wang, Fei 2025-01-16 11:09:44 +08:00 committed by Shuang
parent 45450e793c
commit ad933815b6
20 changed files with 645 additions and 75 deletions

View File

@ -155,6 +155,7 @@ public class CelebornShuffleConsumer<K, V>
reduceId.getTaskID().getId(),
reduceId.getId(),
0,
0,
Integer.MAX_VALUE,
metricsCallback);
CelebornShuffleFetcher<K, V> shuffleReader =

View File

@ -101,6 +101,11 @@ public class SparkShuffleManager implements ShuffleManager {
if (celebornConf.clientStageRerunEnabled()) {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
taskId -> !SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId));
SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener());
lifecycleManager.registerShuffleTrackerCallback(
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
}

View File

@ -20,12 +20,19 @@ package org.apache.spark.shuffle.celeborn;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
import scala.Option;
import scala.Some;
import scala.Tuple2;
import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.BarrierTaskContext;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
@ -35,6 +42,10 @@ import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.scheduler.SparkListener;
import org.apache.spark.scheduler.TaskInfo;
import org.apache.spark.scheduler.TaskSchedulerImpl;
import org.apache.spark.scheduler.TaskSetManager;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.storage.BlockManagerId;
@ -43,6 +54,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynFields;
@ -203,4 +215,135 @@ public class SparkUtils {
logger.error("Can not get active SparkContext, skip cancelShuffle.");
}
}
// Reflective accessor for TaskSchedulerImpl's private `taskIdToTaskSetManager` map
// (taskId -> TaskSetManager). defaultAlwaysNull() makes the bound field yield null when
// the Spark version does not expose this field — presumably; verify DynFields semantics.
private static final DynFields.UnboundField<ConcurrentHashMap<Long, TaskSetManager>>
    TASK_ID_TO_TASK_SET_MANAGER_FIELD =
        DynFields.builder()
            .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
            .defaultAlwaysNull()
            .build();

// Reflective accessor for TaskSetManager's private `taskInfos` map (taskId -> TaskInfo).
private static final DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
    TASK_INFOS_FIELD =
        DynFields.builder()
            .hiddenImpl(TaskSetManager.class, "taskInfos")
            .defaultAlwaysNull()
            .build();
/**
 * Resolves the TaskSetManager that owns {@code taskId} by reflectively reading
 * TaskSchedulerImpl's private taskIdToTaskSetManager map. The scheduler lock is held while
 * reading. TaskSetManager is not designed to be used outside the Spark scheduler — be careful.
 *
 * @return the owning TaskSetManager, or null when the task id is unknown
 */
@VisibleForTesting
protected static TaskSetManager getTaskSetManager(TaskSchedulerImpl taskScheduler, long taskId) {
  synchronized (taskScheduler) {
    ConcurrentHashMap<Long, TaskSetManager> managers =
        TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
    return managers.get(taskId);
  }
}
/**
 * Looks up the TaskInfo for {@code taskId} together with all attempts sharing the same task
 * index (read reflectively from the TaskSetManager).
 *
 * @return a (taskInfo, allAttempts) tuple, or null (after logging) when the TaskSetManager
 *     or the TaskInfo cannot be found
 */
@VisibleForTesting
protected static Tuple2<TaskInfo, List<TaskInfo>> getTaskAttempts(
    TaskSetManager taskSetManager, long taskId) {
  if (taskSetManager == null) {
    logger.error("Can not get TaskSetManager for taskId: {}", taskId);
    return null;
  }
  scala.Option<TaskInfo> maybeTaskInfo = TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
  if (maybeTaskInfo.isEmpty()) {
    logger.error("Can not get TaskInfo for taskId: {}", taskId);
    return null;
  }
  TaskInfo info = maybeTaskInfo.get();
  // Every attempt recorded for this task's partition index.
  List<TaskInfo> attempts =
      scala.collection.JavaConverters.asJavaCollectionConverter(
              taskSetManager.taskAttempts()[info.index()])
          .asJavaCollection()
          .stream()
          .collect(Collectors.toList());
  return Tuple2.apply(info, attempts);
}
// Per stage attempt (key "stageId-stageAttemptId"): ids of tasks that already reported a
// shuffle fetch failure. Consulted by taskAnotherAttemptRunningOrSuccessful so a sibling
// attempt that also failed to fetch is not mistaken for a live attempt.
protected static Map<String, Set<Long>> reportedStageShuffleFetchFailureTaskIds =
    JavaUtils.newConcurrentHashMap();

// Drops the bookkeeping for a finished stage attempt; invoked from
// ShuffleFetchFailureReportTaskCleanListener when the stage completes.
protected static void removeStageReportedShuffleFetchFailureTaskIds(
    int stageId, int stageAttemptId) {
  reportedStageShuffleFetchFailureTaskIds.remove(stageId + "-" + stageAttemptId);
}
/**
 * Only used to check for the shuffle fetch failure task whether another attempt is running or
 * successful. If another attempt(excluding the reported shuffle fetch failure tasks in current
 * stage) is running or successful, return true. Otherwise, return false.
 */
public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
  // NOTE(review): scala.Option.getOrElse takes a by-name default (a Function0 when called
  // from Java); passing a bare null may NPE when no SparkContext is active instead of
  // returning null — verify, and consider `.getOrElse(() -> null)`.
  SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
  if (sparkContext == null) {
    logger.error("Can not get active SparkContext.");
    return false;
  }
  TaskSchedulerImpl taskScheduler = (TaskSchedulerImpl) sparkContext.taskScheduler();
  // Hold the scheduler lock so attempt state cannot change while it is inspected.
  synchronized (taskScheduler) {
    TaskSetManager taskSetManager = getTaskSetManager(taskScheduler, taskId);
    if (taskSetManager != null) {
      int stageId = taskSetManager.stageId();
      int stageAttemptId = taskSetManager.taskSet().stageAttemptId();
      String stageUniqId = stageId + "-" + stageAttemptId;
      // Record this task as a fetch-failure reporter for its stage attempt, so a sibling
      // attempt that reported the same failure is not counted as alive below.
      Set<Long> reportedStageTaskIds =
          reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(
              stageUniqId, k -> new HashSet<>());
      reportedStageTaskIds.add(taskId);
      Tuple2<TaskInfo, List<TaskInfo>> taskAttempts = getTaskAttempts(taskSetManager, taskId);
      if (taskAttempts == null) return false;
      TaskInfo taskInfo = taskAttempts._1();
      for (TaskInfo ti : taskAttempts._2()) {
        if (ti.taskId() != taskId) {
          if (reportedStageTaskIds.contains(ti.taskId())) {
            // The other attempt itself hit a shuffle fetch failure: do not treat it as alive.
            logger.info(
                "StageId={} index={} taskId={} attempt={} another attempt {} has reported shuffle fetch failure, ignore it.",
                stageId,
                taskInfo.index(),
                taskId,
                taskInfo.attemptNumber(),
                ti.attemptNumber());
          } else if (ti.successful()) {
            logger.info(
                "StageId={} index={} taskId={} attempt={} another attempt {} is successful.",
                stageId,
                taskInfo.index(),
                taskId,
                taskInfo.attemptNumber(),
                ti.attemptNumber());
            return true;
          } else if (ti.running()) {
            logger.info(
                "StageId={} index={} taskId={} attempt={} another attempt {} is running.",
                stageId,
                taskInfo.index(),
                taskId,
                taskInfo.attemptNumber(),
                ti.attemptNumber());
            return true;
          }
        }
      }
      return false;
    } else {
      logger.error("Can not get TaskSetManager for taskId: {}", taskId);
      return false;
    }
  }
}
/**
 * Adds the given listener to the active SparkContext, if one exists; otherwise a no-op.
 *
 * @param listener the Spark listener to register
 */
public static void addSparkListener(SparkListener listener) {
  // scala.Option.getOrElse takes a by-name default (a Function0 from Java); the previous
  // `.getOrElse(null)` passed null as that thunk and would NPE when no SparkContext is
  // active. Unwrap the Option explicitly instead.
  scala.Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
  if (activeContext.isDefined()) {
    activeContext.get().addSparkListener(listener);
  }
}
}

View File

@ -98,6 +98,7 @@ class CelebornShuffleReader[K, C](
shuffleId,
partitionId,
encodedAttemptId,
context.taskAttemptId(),
startMapIndex,
endMapIndex,
metricsCallback)
@ -124,7 +125,10 @@ class CelebornShuffleReader[K, C](
exceptionRef.get() match {
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
if (handle.throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
shuffleClient.reportShuffleFetchFailure(
handle.shuffleId,
shuffleId,
context.taskAttemptId())) {
throw new FetchFailedException(
null,
handle.shuffleId,
@ -158,7 +162,10 @@ class CelebornShuffleReader[K, C](
} catch {
case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
if (handle.throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
shuffleClient.reportShuffleFetchFailure(
handle.shuffleId,
shuffleId,
context.taskAttemptId())) {
throw new FetchFailedException(
null,
handle.shuffleId,

View File

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.celeborn
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
/**
 * Spark listener that, when a stage attempt completes, clears SparkUtils' record of tasks
 * that reported a shuffle fetch failure for that attempt, so the bookkeeping map does not
 * retain entries for finished stages.
 */
class ShuffleFetchFailureReportTaskCleanListener extends SparkListener {
  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    SparkUtils.removeStageReportedShuffleFetchFailureTaskIds(
      stageCompleted.stageInfo.stageId,
      stageCompleted.stageInfo.attemptNumber())
  }
}

View File

@ -144,6 +144,10 @@ public class SparkShuffleManager implements ShuffleManager {
MapOutputTrackerMaster mapOutputTracker =
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
taskId -> !SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId));
SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener());
lifecycleManager.registerShuffleTrackerCallback(
shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
}

View File

@ -17,12 +17,19 @@
package org.apache.spark.shuffle.celeborn;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
import scala.Option;
import scala.Some;
import scala.Tuple2;
import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.BarrierTaskContext;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
@ -33,6 +40,10 @@ import org.apache.spark.scheduler.DAGScheduler;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.scheduler.ShuffleMapStage;
import org.apache.spark.scheduler.SparkListener;
import org.apache.spark.scheduler.TaskInfo;
import org.apache.spark.scheduler.TaskSchedulerImpl;
import org.apache.spark.scheduler.TaskSetManager;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
import org.apache.spark.shuffle.ShuffleReader;
@ -46,6 +57,7 @@ import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.reflect.DynConstructors;
import org.apache.celeborn.reflect.DynFields;
import org.apache.celeborn.reflect.DynMethods;
@ -319,4 +331,135 @@ public class SparkUtils {
LOG.error("Can not get active SparkContext, skip cancelShuffle.");
}
}
// Reflective accessor for TaskSchedulerImpl's private `taskIdToTaskSetManager` map
// (taskId -> TaskSetManager). defaultAlwaysNull() makes the bound field yield null when
// the Spark version does not expose this field — presumably; verify DynFields semantics.
private static final DynFields.UnboundField<ConcurrentHashMap<Long, TaskSetManager>>
    TASK_ID_TO_TASK_SET_MANAGER_FIELD =
        DynFields.builder()
            .hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
            .defaultAlwaysNull()
            .build();

// Reflective accessor for TaskSetManager's private `taskInfos` map (taskId -> TaskInfo).
private static final DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
    TASK_INFOS_FIELD =
        DynFields.builder()
            .hiddenImpl(TaskSetManager.class, "taskInfos")
            .defaultAlwaysNull()
            .build();
/**
 * TaskSetManager - it is not designed to be used outside the spark scheduler. Please be careful.
 */
@VisibleForTesting
protected static TaskSetManager getTaskSetManager(TaskSchedulerImpl taskScheduler, long taskId) {
  // Read Spark's private taskId -> TaskSetManager map under the scheduler lock;
  // returns null when the task id is unknown.
  synchronized (taskScheduler) {
    ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
        TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
    return taskIdToTaskSetManager.get(taskId);
  }
}
// Returns the TaskInfo for taskId plus all attempts sharing the same task index, or null
// (after logging) when the TaskSetManager or TaskInfo cannot be found.
@VisibleForTesting
protected static Tuple2<TaskInfo, List<TaskInfo>> getTaskAttempts(
    TaskSetManager taskSetManager, long taskId) {
  if (taskSetManager != null) {
    // Reflectively look up the TaskInfo inside the TaskSetManager's private taskInfos map.
    scala.Option<TaskInfo> taskInfoOption =
        TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
    if (taskInfoOption.isDefined()) {
      TaskInfo taskInfo = taskInfoOption.get();
      // Every attempt recorded for this task's partition index.
      List<TaskInfo> taskAttempts =
          scala.collection.JavaConverters.asJavaCollectionConverter(
                  taskSetManager.taskAttempts()[taskInfo.index()])
              .asJavaCollection().stream()
              .collect(Collectors.toList());
      return Tuple2.apply(taskInfo, taskAttempts);
    } else {
      LOG.error("Can not get TaskInfo for taskId: {}", taskId);
      return null;
    }
  } else {
    LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
    return null;
  }
}
// Per stage attempt (key "stageId-stageAttemptId"): ids of tasks that already reported a
// shuffle fetch failure. Consulted by taskAnotherAttemptRunningOrSuccessful so a sibling
// attempt that also failed to fetch is not mistaken for a live attempt.
protected static Map<String, Set<Long>> reportedStageShuffleFetchFailureTaskIds =
    JavaUtils.newConcurrentHashMap();

// Drops the bookkeeping for a finished stage attempt; invoked from
// ShuffleFetchFailureReportTaskCleanListener when the stage completes.
protected static void removeStageReportedShuffleFetchFailureTaskIds(
    int stageId, int stageAttemptId) {
  reportedStageShuffleFetchFailureTaskIds.remove(stageId + "-" + stageAttemptId);
}
/**
 * Only used to check for the shuffle fetch failure task whether another attempt is running or
 * successful. If another attempt(excluding the reported shuffle fetch failure tasks in current
 * stage) is running or successful, return true. Otherwise, return false.
 */
public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
  // NOTE(review): scala.Option.getOrElse takes a by-name default (a Function0 when called
  // from Java); passing a bare null may NPE when no SparkContext is active instead of
  // returning null — verify, and consider `.getOrElse(() -> null)`.
  SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
  if (sparkContext == null) {
    LOG.error("Can not get active SparkContext.");
    return false;
  }
  TaskSchedulerImpl taskScheduler = (TaskSchedulerImpl) sparkContext.taskScheduler();
  // Hold the scheduler lock so attempt state cannot change while it is inspected.
  synchronized (taskScheduler) {
    TaskSetManager taskSetManager = getTaskSetManager(taskScheduler, taskId);
    if (taskSetManager != null) {
      int stageId = taskSetManager.stageId();
      int stageAttemptId = taskSetManager.taskSet().stageAttemptId();
      String stageUniqId = stageId + "-" + stageAttemptId;
      // Record this task as a fetch-failure reporter for its stage attempt, so a sibling
      // attempt that reported the same failure is not counted as alive below.
      Set<Long> reportedStageTaskIds =
          reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(
              stageUniqId, k -> new HashSet<>());
      reportedStageTaskIds.add(taskId);
      Tuple2<TaskInfo, List<TaskInfo>> taskAttempts = getTaskAttempts(taskSetManager, taskId);
      if (taskAttempts == null) return false;
      TaskInfo taskInfo = taskAttempts._1();
      for (TaskInfo ti : taskAttempts._2()) {
        if (ti.taskId() != taskId) {
          if (reportedStageTaskIds.contains(ti.taskId())) {
            // The other attempt itself hit a shuffle fetch failure: do not treat it as alive.
            LOG.info(
                "StageId={} index={} taskId={} attempt={} another attempt {} has reported shuffle fetch failure, ignore it.",
                stageId,
                taskInfo.index(),
                taskId,
                taskInfo.attemptNumber(),
                ti.attemptNumber());
          } else if (ti.successful()) {
            LOG.info(
                "StageId={} index={} taskId={} attempt={} another attempt {} is successful.",
                stageId,
                taskInfo.index(),
                taskId,
                taskInfo.attemptNumber(),
                ti.attemptNumber());
            return true;
          } else if (ti.running()) {
            LOG.info(
                "StageId={} index={} taskId={} attempt={} another attempt {} is running.",
                stageId,
                taskInfo.index(),
                taskId,
                taskInfo.attemptNumber(),
                ti.attemptNumber());
            return true;
          }
        }
      }
      return false;
    } else {
      LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
      return false;
    }
  }
}
/**
 * Adds the given listener to the active SparkContext, if one exists; otherwise a no-op.
 *
 * @param listener the Spark listener to register
 */
public static void addSparkListener(SparkListener listener) {
  // scala.Option.getOrElse takes a by-name default (a Function0 from Java); the previous
  // `.getOrElse(null)` passed null as that thunk and would NPE when no SparkContext is
  // active. Unwrap the Option explicitly instead.
  scala.Option<SparkContext> activeContext = SparkContext$.MODULE$.getActive();
  if (activeContext.isDefined()) {
    activeContext.get().addSparkListener(listener);
  }
}
}

View File

@ -215,6 +215,7 @@ class CelebornShuffleReader[K, C](
handle.shuffleId,
partitionId,
encodedAttemptId,
context.taskAttemptId(),
startMapIndex,
endMapIndex,
if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER
@ -375,7 +376,7 @@ class CelebornShuffleReader[K, C](
partitionId: Int,
ce: Throwable) = {
if (throwsFetchFailure &&
shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId)) {
shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId, context.taskAttemptId())) {
logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", ce)
throw new FetchFailedException(
null,

View File

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.celeborn
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
/**
 * Spark listener that, when a stage attempt completes, clears SparkUtils' record of tasks
 * that reported a shuffle fetch failure for that attempt, so the bookkeeping map does not
 * retain entries for finished stages.
 */
class ShuffleFetchFailureReportTaskCleanListener extends SparkListener {
  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    SparkUtils.removeStageReportedShuffleFetchFailureTaskIds(
      stageCompleted.stageInfo.stageId,
      stageCompleted.stageInfo.attemptNumber())
  }
}

View File

@ -224,6 +224,7 @@ public abstract class ShuffleClient {
int shuffleId,
int partitionId,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
MetricsCallback metricsCallback)
@ -233,6 +234,7 @@ public abstract class ShuffleClient {
shuffleId,
partitionId,
attemptNumber,
taskId,
startMapIndex,
endMapIndex,
null,
@ -247,6 +249,7 @@ public abstract class ShuffleClient {
int appShuffleId,
int partitionId,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
ExceptionMaker exceptionMaker,
@ -276,7 +279,7 @@ public abstract class ShuffleClient {
* cleanup for spark app. It must be a sync call and make sure the cleanup is done, otherwise,
* incorrect shuffle data can be fetched in re-run tasks
*/
public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId);
public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId);
/**
* Report barrier task failure. When any barrier task fails, all (shuffle) output for that stage

View File

@ -630,11 +630,12 @@ public class ShuffleClientImpl extends ShuffleClient {
}
@Override
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) {
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) {
PbReportShuffleFetchFailure pbReportShuffleFetchFailure =
PbReportShuffleFetchFailure.newBuilder()
.setAppShuffleId(appShuffleId)
.setShuffleId(shuffleId)
.setTaskId(taskId)
.build();
PbReportShuffleFetchFailureResponse pbReportShuffleFetchFailureResponse =
lifecycleManagerRef.askSync(
@ -1845,6 +1846,7 @@ public class ShuffleClientImpl extends ShuffleClient {
int appShuffleId,
int partitionId,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
ExceptionMaker exceptionMaker,
@ -1883,6 +1885,7 @@ public class ShuffleClientImpl extends ShuffleClient {
streamHandlers,
mapAttempts,
attemptNumber,
taskId,
startMapIndex,
endMapIndex,
fetchExcludedWorkers,

View File

@ -55,6 +55,7 @@ public abstract class CelebornInputStream extends InputStream {
ArrayList<PbStreamHandler> streamHandlers,
int[] attempts,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
ConcurrentHashMap<String, Long> fetchExcludedWorkers,
@ -76,6 +77,7 @@ public abstract class CelebornInputStream extends InputStream {
streamHandlers,
attempts,
attemptNumber,
taskId,
startMapIndex,
endMapIndex,
fetchExcludedWorkers,
@ -129,6 +131,7 @@ public abstract class CelebornInputStream extends InputStream {
private ArrayList<PbStreamHandler> streamHandlers;
private int[] attempts;
private final int attemptNumber;
private final long taskId;
private final int startMapIndex;
private final int endMapIndex;
@ -178,6 +181,7 @@ public abstract class CelebornInputStream extends InputStream {
ArrayList<PbStreamHandler> streamHandlers,
int[] attempts,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
ConcurrentHashMap<String, Long> fetchExcludedWorkers,
@ -197,6 +201,7 @@ public abstract class CelebornInputStream extends InputStream {
}
this.attempts = attempts;
this.attemptNumber = attemptNumber;
this.taskId = taskId;
this.startMapIndex = startMapIndex;
this.endMapIndex = endMapIndex;
this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled();
@ -654,7 +659,7 @@ public abstract class CelebornInputStream extends InputStream {
ioe = new IOException(e);
}
if (exceptionMaker != null) {
if (shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId)) {
if (shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId, taskId)) {
/*
* [[ExceptionMaker.makeException]], for spark applications with celeborn.client.spark.fetch.throwsFetchFailure enabled will result in creating
* a FetchFailedException; and that will make the TaskContext as failed with shuffle fetch issues - see SPARK-19276 for more.

View File

@ -445,8 +445,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
case pb: PbReportShuffleFetchFailure =>
val appShuffleId = pb.getAppShuffleId
val shuffleId = pb.getShuffleId
logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId")
handleReportShuffleFetchFailure(context, appShuffleId, shuffleId)
val taskId = pb.getTaskId
logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId taskId $taskId")
handleReportShuffleFetchFailure(context, appShuffleId, shuffleId, taskId)
case pb: PbReportBarrierStageAttemptFailure =>
val appShuffleId = pb.getAppShuffleId
@ -935,7 +936,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
private def handleReportShuffleFetchFailure(
context: RpcCallContext,
appShuffleId: Int,
shuffleId: Int): Unit = {
shuffleId: Int,
taskId: Long): Unit = {
val shuffleIds = shuffleIdMapping.get(appShuffleId)
if (shuffleIds == null) {
@ -945,9 +947,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
shuffleIds.synchronized {
shuffleIds.find(e => e._2._1 == shuffleId) match {
case Some((appShuffleIdentifier, (shuffleId, true))) =>
logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId")
ret = invokeAppShuffleTrackerCallback(appShuffleId)
shuffleIds.put(appShuffleIdentifier, (shuffleId, false))
if (invokeReportTaskShuffleFetchFailurePreCheck(taskId)) {
logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId")
ret = invokeAppShuffleTrackerCallback(appShuffleId)
shuffleIds.put(appShuffleIdentifier, (shuffleId, false))
} else {
logInfo(
s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId taskId $taskId")
ret = false
}
case Some((appShuffleIdentifier, (shuffleId, false))) =>
logInfo(
s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId, " +
@ -1010,6 +1018,20 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
}
/**
 * Runs the registered fetch-failure pre-check (if any) for the reporting task.
 * Returns true when no pre-check is registered; returns false (after logging) when the
 * pre-check itself throws, so a broken pre-check causes the report to be ignored rather
 * than failing the RPC handler.
 */
private def invokeReportTaskShuffleFetchFailurePreCheck(taskId: Long): Boolean = {
  reportTaskShuffleFetchFailurePreCheck match {
    case Some(preCheck) =>
      try {
        preCheck.apply(taskId)
      } catch {
        case t: Throwable =>
          logError(s"Error preChecking the shuffle fetch failure reported by task: $taskId", t)
          false
      }
    case None => true
  }
}
private def handleStageEnd(shuffleId: Int): Unit = {
// check whether shuffle has registered
if (!registeredShuffle.contains(shuffleId)) {
@ -1770,6 +1792,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
workerStatusTracker.registerWorkerStatusListener(workerStatusListener)
}
// Optional hook consulted before acting on a task's shuffle fetch failure report; when it
// returns false the report is ignored instead of triggering unregisterAllMapOutput.
@volatile private var reportTaskShuffleFetchFailurePreCheck
  : Option[java.util.function.Function[java.lang.Long, Boolean]] = None

/** Registers the fetch-failure pre-check; the last registration wins. */
def registerReportTaskShuffleFetchFailurePreCheck(preCheck: java.util.function.Function[
  java.lang.Long,
  Boolean]): Unit = {
  reportTaskShuffleFetchFailurePreCheck = Some(preCheck)
}
@volatile private var appShuffleTrackerCallback: Option[Consumer[Integer]] = None
def registerShuffleTrackerCallback(callback: Consumer[Integer]): Unit = {
appShuffleTrackerCallback = Some(callback)

View File

@ -130,6 +130,7 @@ public class DummyShuffleClient extends ShuffleClient {
int appShuffleId,
int partitionId,
int attemptNumber,
long taskId,
int startMapIndex,
int endMapIndex,
ExceptionMaker exceptionMaker,
@ -179,7 +180,7 @@ public class DummyShuffleClient extends ShuffleClient {
}
// Test stub: always report the fetch failure as accepted.
@Override
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) {
  return true;
}
}

View File

@ -158,6 +158,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
1,
1,
0,
0,
Integer.MAX_VALUE,
null,
null,
@ -173,6 +174,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
3,
1,
0,
0,
Integer.MAX_VALUE,
null,
null,

View File

@ -391,6 +391,7 @@ message PbGetShuffleIdResponse {
message PbReportShuffleFetchFailure {
  int32 appShuffleId = 1;
  int32 shuffleId = 2;
  // Spark task (attempt) id of the reporting reader; forwarded to the LifecycleManager's
  // pre-check to decide whether this failure should trigger a stage re-run.
  int64 taskId = 3;
}
message PbReportShuffleFetchFailureResponse {

View File

@ -17,22 +17,19 @@
package org.apache.celeborn.tests.spark
import java.io.{File, IOException}
import java.io.IOException
import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark.{BarrierTaskContext, ShuffleDependency, SparkConf, SparkContextHelper, SparkException, TaskContext}
import org.apache.spark.celeborn.ExceptionMakerHelper
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
import org.apache.spark.shuffle.celeborn.{SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.protocol.ShuffleMode
import org.apache.celeborn.service.deploy.worker.Worker
class CelebornFetchFailureSuite extends AnyFunSuite
with SparkTestBase
@ -46,57 +43,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite
System.gc()
}
var workerDirs: Seq[String] = Seq.empty
override def createWorker(map: Map[String, String]): Worker = {
val storageDir = createTmpDir()
this.synchronized {
workerDirs = workerDirs :+ storageDir
}
super.createWorker(map, storageDir)
}
class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook {
var executed: AtomicBoolean = new AtomicBoolean(false)
val lock = new Object
override def exec(
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): Unit = {
if (executed.get() == true) return
lock.synchronized {
handle match {
case h: CelebornShuffleHandle[_, _, _] => {
val appUniqueId = h.appUniqueId
val shuffleClient = ShuffleClient.get(
h.appUniqueId,
h.lifecycleManagerHost,
h.lifecycleManagerPort,
conf,
h.userIdentifier,
h.extension)
val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
val allFiles = workerDirs.map(dir => {
new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
})
val datafile = allFiles.filter(_.exists())
.flatMap(_.listFiles().iterator).headOption
datafile match {
case Some(file) => file.delete()
case None => throw new RuntimeException("unexpected, there must be some data file" +
s" under ${workerDirs.mkString(",")}")
}
}
case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here")
}
executed.set(true)
}
}
}
test("celeborn spark integration test - Fetch Failure") {
if (Spark3OrNewer) {
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
@ -111,7 +57,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()
val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)
val value = Range(1, 10000).mkString(",")
@ -184,7 +130,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()
val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)
import sparkSession.implicits._
@ -215,7 +161,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()
val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)
val sc = sparkSession.sparkContext
@ -255,7 +201,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
.getOrCreate()
val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
val hook = new ShuffleReaderGetHook(celebornConf)
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
TestCelebornShuffleManager.registerReaderGetHook(hook)
val sc = sparkSession.sparkContext

View File

@ -17,19 +17,26 @@
package org.apache.celeborn.tests.spark
import java.io.File
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.Random
import org.apache.spark.SPARK_VERSION
import org.apache.spark.SparkConf
import org.apache.spark.{SPARK_VERSION, SparkConf, TaskContext}
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkUtils}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.funsuite.AnyFunSuite
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.CelebornConf._
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.protocol.ShuffleMode
import org.apache.celeborn.service.deploy.MiniClusterFeature
import org.apache.celeborn.service.deploy.worker.Worker
trait SparkTestBase extends AnyFunSuite
with Logging with MiniClusterFeature with BeforeAndAfterAll with BeforeAndAfterEach {
@ -52,6 +59,16 @@ trait SparkTestBase extends AnyFunSuite
shutdownMiniCluster()
}
// Storage directories of every worker launched by this suite. Test hooks
// (e.g. ShuffleReaderFetchFailureGetHook) walk these dirs to locate and
// delete shuffle data files on disk to inject fetch failures.
// NOTE(review): appends are guarded by this.synchronized, but reads elsewhere
// appear unsynchronized — presumably safe because readers only run after all
// workers have been created; confirm against MiniClusterFeature setup order.
var workerDirs: Seq[String] = Seq.empty

// Creates a worker backed by a fresh temp storage dir and records that dir
// so tests can later reach into the worker's on-disk shuffle data.
override def createWorker(map: Map[String, String]): Worker = {
  val storageDir = createTmpDir()
  this.synchronized {
    workerDirs = workerDirs :+ storageDir
  }
  super.createWorker(map, storageDir)
}
def updateSparkConf(sparkConf: SparkConf, mode: ShuffleMode): SparkConf = {
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.set(
@ -98,4 +115,45 @@ trait SparkTestBase extends AnyFunSuite
val outMap = result.collect().map(row => row.getString(0) -> row.getLong(1)).toMap
outMap
}
/**
 * Shuffle-manager hook that injects exactly one fetch failure: on the first
 * reader-get it deletes a single shuffle data file from one worker's storage
 * dir, so the reading task hits missing data and reports a fetch failure.
 *
 * Thread-safe: the deletion happens at most once even when multiple reader
 * tasks invoke exec() concurrently.
 */
class ShuffleReaderFetchFailureGetHook(conf: CelebornConf) extends ShuffleManagerHook {
  // Flipped to true only after a data file has actually been deleted.
  var executed: AtomicBoolean = new AtomicBoolean(false)
  val lock = new Object

  override def exec(
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): Unit = {
    // Fast path: fault already injected, nothing to do.
    if (executed.get() == true) return

    lock.synchronized {
      // Re-check under the lock (double-checked locking): two readers can both
      // pass the unsynchronized check above; without this re-check the second
      // one would delete a *second* data file after the first releases the
      // lock, injecting the fault twice.
      if (executed.get() == true) return
      handle match {
        case h: CelebornShuffleHandle[_, _, _] => {
          val appUniqueId = h.appUniqueId
          // Resolve the celeborn shuffle id for the shuffle being read so we
          // can locate its data files on disk.
          val shuffleClient = ShuffleClient.get(
            h.appUniqueId,
            h.lifecycleManagerHost,
            h.lifecycleManagerPort,
            conf,
            h.userIdentifier,
            h.extension)
          val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
          val allFiles = workerDirs.map(dir => {
            new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
          })
          // listFiles() returns null on I/O error or if the path is not a
          // directory; guard it so we fail with the diagnostic below instead
          // of an NPE. Sort by name so the deleted file is deterministic.
          val datafile = allFiles.filter(_.exists())
            .flatMap(f => Option(f.listFiles()).map(_.toSeq).getOrElse(Seq.empty))
            .sortBy(_.getName).headOption
          datafile match {
            case Some(file) => file.delete()
            case None => throw new RuntimeException("unexpected, there must be some data file" +
                s" under ${workerDirs.mkString(",")}")
          }
        }
        case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here")
      }
      executed.set(true)
    }
  }
}
}

View File

@ -0,0 +1,160 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle.celeborn
import scala.collection.JavaConverters._
import org.apache.spark.SparkConf
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach
import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.{interval, timeout}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.protocol.ShuffleMode
import org.apache.celeborn.tests.spark.SparkTestBase
/**
 * Unit tests for the SparkUtils task-attempt inspection helpers that back the
 * "don't re-run a stage when another task attempt is running or successful"
 * logic. Runs against a local[2,3] Spark master with the test shuffle manager.
 */
class SparkUtilsSuite extends AnyFunSuite
  with SparkTestBase
  with BeforeAndAfterEach {

  override def beforeEach(): Unit = {
    // Drop any cached shuffle client so each test starts from a clean slate.
    ShuffleClient.reset()
  }

  override def afterEach(): Unit = {
    System.gc()
  }

  test("check if fetch failure task another attempt is running or successful") {
    if (Spark3OrNewer) {
      // local[2,3]: 2 cores, up to 3 task failures before the job aborts.
      val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
      val sparkSession = SparkSession.builder()
        .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
        .config("spark.sql.shuffle.partitions", 2)
        .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
        .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
        .config(
          "spark.shuffle.manager",
          "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
        .getOrCreate()

      val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
      // Hook deletes one shuffle data file on first read, forcing a fetch failure.
      val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
      TestCelebornShuffleManager.registerReaderGetHook(hook)

      try {
        val sc = sparkSession.sparkContext
        // Run the job on a background thread so the test thread can inspect
        // scheduler state while tasks are still in flight.
        val jobThread = new Thread {
          override def run(): Unit = {
            try {
              val value = Range(1, 10000).mkString(",")
              sc.parallelize(1 to 10000, 2)
                .map { i => (i, value) }
                .groupByKey(10)
                .mapPartitions { iter =>
                  // Keep reducer tasks alive long enough to be observed below.
                  Thread.sleep(3000)
                  iter
                }.collect()
            } catch {
              // Expected when the test interrupts the job at the end.
              case _: InterruptedException =>
            }
          }
        }
        jobThread.start()

        val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
        // Wait until the injected fetch failure has been reported, then verify
        // the helper methods see exactly one attempt and no *other* attempt
        // running or successful (local master never starts speculation).
        eventually(timeout(30.seconds), interval(0.milliseconds)) {
          assert(hook.executed.get() == true)
          val reportedTaskId =
            SparkUtils.reportedStageShuffleFetchFailureTaskIds.values().asScala.flatMap(
              _.asScala).head
          val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, reportedTaskId)
          assert(taskSetManager != null)
          assert(SparkUtils.getTaskAttempts(taskSetManager, reportedTaskId)._2.size() == 1)
          assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(reportedTaskId))
        }

        sparkSession.sparkContext.cancelAllJobs()

        jobThread.interrupt()

        // The reported-task bookkeeping must be cleaned up once the job ends.
        eventually(timeout(3.seconds), interval(100.milliseconds)) {
          assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 0)
        }
      } finally {
        sparkSession.stop()
      }
    }
  }

  test("getTaskSetManager/getTaskAttempts test") {
    val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
    val sparkSession = SparkSession.builder()
      .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
      .config("spark.sql.shuffle.partitions", 2)
      .config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
      .config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
      .config(
        "spark.shuffle.manager",
        "org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
      .getOrCreate()

    try {
      val sc = sparkSession.sparkContext
      // Background job with a sleep so task 0 is still running when the
      // assertions below inspect the scheduler.
      val jobThread = new Thread {
        override def run(): Unit = {
          try {
            sc.parallelize(1 to 100, 2)
              .repartition(1)
              .mapPartitions { iter =>
                Thread.sleep(3000)
                iter
              }.collect()
          } catch {
            case _: InterruptedException =>
          }
        }
      }
      jobThread.start()

      val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
      eventually(timeout(3.seconds), interval(100.milliseconds)) {
        // Task id 0 is the first task launched by the job.
        val taskId = 0
        val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId)
        assert(taskSetManager != null)
        assert(SparkUtils.getTaskAttempts(taskSetManager, taskId)._2.size() == 1)
        // NOTE(review): taskAnotherAttemptRunningOrSuccessful apparently also
        // records the task id in reportedStageShuffleFetchFailureTaskIds,
        // hence the size == 1 check below — confirm against SparkUtils.
        assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId))
        assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 1)
      }

      sparkSession.sparkContext.cancelAllJobs()

      jobThread.interrupt()

      // Bookkeeping must be cleared after the job is cancelled.
      eventually(timeout(3.seconds), interval(100.milliseconds)) {
        assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 0)
      }
    } finally {
      sparkSession.stop()
    }
  }
}

View File

@ -110,6 +110,7 @@ trait ReadWriteTestBase extends AnyFunSuite
0,
0,
0,
0,
Integer.MAX_VALUE,
null,
null,