[CELEBORN-1720] Prevent stage re-run if another task attempt is running or successful
### What changes were proposed in this pull request? Prevent stage re-run if another task attempt is running. If a shuffle read task can not read the shuffle data and the task another attempt is running or successful, just throw the CelebornIOException instead of FetchFailureException. The app will not failure before reach the task maxFailures. <img width="1610" alt="image" src="https://github.com/user-attachments/assets/ffc6d80e-7c90-4729-adf7-6f8c46a8f226"> ### Why are the changes needed? I met below issue because I set the wrong parameters, I should set `spark.celeborn.data.io.connectTime=30s` but set the `spark.celeborn.data.io.connectionTime=30s`, and the Disk IO Utils was high at that time. 0. speculation is enabled 1. one task failed to fetch shuffle 0 in stage 5. 2. then it triggered the stage 0 re-run (stage 4) 3. then stage 5 retry, however, no task run in stage 5 (retry 1) <img width="1212" alt="image" src="https://github.com/user-attachments/assets/555f36b0-0f0d-452d-af0b-1573601165e2"> 4. because the speculation task succeeded, so no task in stage 5(retry 1) <img width="1715" alt="image" src="https://github.com/user-attachments/assets/7f315149-1d5c-4c32-ae9b-87b099b3297f"> Due the stage re-run is heavy, so I wonder that, we should ignore the shuffle fetch failure, if there is another task attempt running. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? UT for the SparkUtils method only, due it is impossible to add UT for speculation.d5da49d56d/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala (L236-L244)<img width="867" alt="image" src="https://github.com/user-attachments/assets/f93bd14f-0f34-4c81-a8db-13be511405d9"> For local master, it would not start the speculationScheduler.d5da49d56d/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala (L322-L346)<img width="1010" alt="image" src="https://github.com/user-attachments/assets/477729a4-2fc1-47e9-b128-522c6e2ceb48"> and it is also not allowed to launch speculative task on the same host. Closes #2921 from turboFei/task_id. Lead-authored-by: Wang, Fei <fwang12@ebay.com> Co-authored-by: Fei Wang <cn.feiwang@gmail.com> Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
This commit is contained in:
parent
45450e793c
commit
ad933815b6
@ -155,6 +155,7 @@ public class CelebornShuffleConsumer<K, V>
|
||||
reduceId.getTaskID().getId(),
|
||||
reduceId.getId(),
|
||||
0,
|
||||
0,
|
||||
Integer.MAX_VALUE,
|
||||
metricsCallback);
|
||||
CelebornShuffleFetcher<K, V> shuffleReader =
|
||||
|
||||
@ -101,6 +101,11 @@ public class SparkShuffleManager implements ShuffleManager {
|
||||
if (celebornConf.clientStageRerunEnabled()) {
|
||||
MapOutputTrackerMaster mapOutputTracker =
|
||||
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
|
||||
|
||||
lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
|
||||
taskId -> !SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId));
|
||||
SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener());
|
||||
|
||||
lifecycleManager.registerShuffleTrackerCallback(
|
||||
shuffleId -> mapOutputTracker.unregisterAllMapOutput(shuffleId));
|
||||
}
|
||||
|
||||
@ -20,12 +20,19 @@ package org.apache.spark.shuffle.celeborn;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.LongAdder;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import scala.Option;
|
||||
import scala.Some;
|
||||
import scala.Tuple2;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import org.apache.spark.BarrierTaskContext;
|
||||
import org.apache.spark.SparkConf;
|
||||
import org.apache.spark.SparkContext;
|
||||
@ -35,6 +42,10 @@ import org.apache.spark.scheduler.DAGScheduler;
|
||||
import org.apache.spark.scheduler.MapStatus;
|
||||
import org.apache.spark.scheduler.MapStatus$;
|
||||
import org.apache.spark.scheduler.ShuffleMapStage;
|
||||
import org.apache.spark.scheduler.SparkListener;
|
||||
import org.apache.spark.scheduler.TaskInfo;
|
||||
import org.apache.spark.scheduler.TaskSchedulerImpl;
|
||||
import org.apache.spark.scheduler.TaskSetManager;
|
||||
import org.apache.spark.sql.execution.UnsafeRowSerializer;
|
||||
import org.apache.spark.sql.execution.metric.SQLMetric;
|
||||
import org.apache.spark.storage.BlockManagerId;
|
||||
@ -43,6 +54,7 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient;
|
||||
import org.apache.celeborn.common.CelebornConf;
|
||||
import org.apache.celeborn.common.util.JavaUtils;
|
||||
import org.apache.celeborn.common.util.Utils;
|
||||
import org.apache.celeborn.reflect.DynFields;
|
||||
|
||||
@ -203,4 +215,135 @@ public class SparkUtils {
|
||||
logger.error("Can not get active SparkContext, skip cancelShuffle.");
|
||||
}
|
||||
}
|
||||
|
||||
private static final DynFields.UnboundField<ConcurrentHashMap<Long, TaskSetManager>>
|
||||
TASK_ID_TO_TASK_SET_MANAGER_FIELD =
|
||||
DynFields.builder()
|
||||
.hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
|
||||
.defaultAlwaysNull()
|
||||
.build();
|
||||
private static final DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
|
||||
TASK_INFOS_FIELD =
|
||||
DynFields.builder()
|
||||
.hiddenImpl(TaskSetManager.class, "taskInfos")
|
||||
.defaultAlwaysNull()
|
||||
.build();
|
||||
|
||||
/**
|
||||
* TaskSetManager - it is not designed to be used outside the spark scheduler. Please be careful.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
protected static TaskSetManager getTaskSetManager(TaskSchedulerImpl taskScheduler, long taskId) {
|
||||
synchronized (taskScheduler) {
|
||||
ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
|
||||
TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
|
||||
return taskIdToTaskSetManager.get(taskId);
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
protected static Tuple2<TaskInfo, List<TaskInfo>> getTaskAttempts(
|
||||
TaskSetManager taskSetManager, long taskId) {
|
||||
if (taskSetManager != null) {
|
||||
scala.Option<TaskInfo> taskInfoOption =
|
||||
TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
|
||||
if (taskInfoOption.isDefined()) {
|
||||
TaskInfo taskInfo = taskInfoOption.get();
|
||||
List<TaskInfo> taskAttempts =
|
||||
scala.collection.JavaConverters.asJavaCollectionConverter(
|
||||
taskSetManager.taskAttempts()[taskInfo.index()])
|
||||
.asJavaCollection().stream()
|
||||
.collect(Collectors.toList());
|
||||
return Tuple2.apply(taskInfo, taskAttempts);
|
||||
} else {
|
||||
logger.error("Can not get TaskInfo for taskId: {}", taskId);
|
||||
return null;
|
||||
}
|
||||
} else {
|
||||
logger.error("Can not get TaskSetManager for taskId: {}", taskId);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
protected static Map<String, Set<Long>> reportedStageShuffleFetchFailureTaskIds =
|
||||
JavaUtils.newConcurrentHashMap();
|
||||
|
||||
protected static void removeStageReportedShuffleFetchFailureTaskIds(
|
||||
int stageId, int stageAttemptId) {
|
||||
reportedStageShuffleFetchFailureTaskIds.remove(stageId + "-" + stageAttemptId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Only used to check for the shuffle fetch failure task whether another attempt is running or
|
||||
* successful. If another attempt(excluding the reported shuffle fetch failure tasks in current
|
||||
* stage) is running or successful, return true. Otherwise, return false.
|
||||
*/
|
||||
public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
|
||||
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
|
||||
if (sparkContext == null) {
|
||||
logger.error("Can not get active SparkContext.");
|
||||
return false;
|
||||
}
|
||||
TaskSchedulerImpl taskScheduler = (TaskSchedulerImpl) sparkContext.taskScheduler();
|
||||
synchronized (taskScheduler) {
|
||||
TaskSetManager taskSetManager = getTaskSetManager(taskScheduler, taskId);
|
||||
if (taskSetManager != null) {
|
||||
int stageId = taskSetManager.stageId();
|
||||
int stageAttemptId = taskSetManager.taskSet().stageAttemptId();
|
||||
String stageUniqId = stageId + "-" + stageAttemptId;
|
||||
Set<Long> reportedStageTaskIds =
|
||||
reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(
|
||||
stageUniqId, k -> new HashSet<>());
|
||||
reportedStageTaskIds.add(taskId);
|
||||
|
||||
Tuple2<TaskInfo, List<TaskInfo>> taskAttempts = getTaskAttempts(taskSetManager, taskId);
|
||||
|
||||
if (taskAttempts == null) return false;
|
||||
|
||||
TaskInfo taskInfo = taskAttempts._1();
|
||||
for (TaskInfo ti : taskAttempts._2()) {
|
||||
if (ti.taskId() != taskId) {
|
||||
if (reportedStageTaskIds.contains(ti.taskId())) {
|
||||
logger.info(
|
||||
"StageId={} index={} taskId={} attempt={} another attempt {} has reported shuffle fetch failure, ignore it.",
|
||||
stageId,
|
||||
taskInfo.index(),
|
||||
taskId,
|
||||
taskInfo.attemptNumber(),
|
||||
ti.attemptNumber());
|
||||
} else if (ti.successful()) {
|
||||
logger.info(
|
||||
"StageId={} index={} taskId={} attempt={} another attempt {} is successful.",
|
||||
stageId,
|
||||
taskInfo.index(),
|
||||
taskId,
|
||||
taskInfo.attemptNumber(),
|
||||
ti.attemptNumber());
|
||||
return true;
|
||||
} else if (ti.running()) {
|
||||
logger.info(
|
||||
"StageId={} index={} taskId={} attempt={} another attempt {} is running.",
|
||||
stageId,
|
||||
taskInfo.index(),
|
||||
taskId,
|
||||
taskInfo.attemptNumber(),
|
||||
ti.attemptNumber());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
} else {
|
||||
logger.error("Can not get TaskSetManager for taskId: {}", taskId);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void addSparkListener(SparkListener listener) {
|
||||
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
|
||||
if (sparkContext != null) {
|
||||
sparkContext.addSparkListener(listener);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -98,6 +98,7 @@ class CelebornShuffleReader[K, C](
|
||||
shuffleId,
|
||||
partitionId,
|
||||
encodedAttemptId,
|
||||
context.taskAttemptId(),
|
||||
startMapIndex,
|
||||
endMapIndex,
|
||||
metricsCallback)
|
||||
@ -124,7 +125,10 @@ class CelebornShuffleReader[K, C](
|
||||
exceptionRef.get() match {
|
||||
case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
|
||||
if (handle.throwsFetchFailure &&
|
||||
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
|
||||
shuffleClient.reportShuffleFetchFailure(
|
||||
handle.shuffleId,
|
||||
shuffleId,
|
||||
context.taskAttemptId())) {
|
||||
throw new FetchFailedException(
|
||||
null,
|
||||
handle.shuffleId,
|
||||
@ -158,7 +162,10 @@ class CelebornShuffleReader[K, C](
|
||||
} catch {
|
||||
case e @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
|
||||
if (handle.throwsFetchFailure &&
|
||||
shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
|
||||
shuffleClient.reportShuffleFetchFailure(
|
||||
handle.shuffleId,
|
||||
shuffleId,
|
||||
context.taskAttemptId())) {
|
||||
throw new FetchFailedException(
|
||||
null,
|
||||
handle.shuffleId,
|
||||
|
||||
@ -0,0 +1,28 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.shuffle.celeborn
|
||||
|
||||
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
|
||||
|
||||
class ShuffleFetchFailureReportTaskCleanListener extends SparkListener {
|
||||
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
|
||||
SparkUtils.removeStageReportedShuffleFetchFailureTaskIds(
|
||||
stageCompleted.stageInfo.stageId,
|
||||
stageCompleted.stageInfo.attemptNumber())
|
||||
}
|
||||
}
|
||||
@ -144,6 +144,10 @@ public class SparkShuffleManager implements ShuffleManager {
|
||||
MapOutputTrackerMaster mapOutputTracker =
|
||||
(MapOutputTrackerMaster) SparkEnv.get().mapOutputTracker();
|
||||
|
||||
lifecycleManager.registerReportTaskShuffleFetchFailurePreCheck(
|
||||
taskId -> !SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId));
|
||||
SparkUtils.addSparkListener(new ShuffleFetchFailureReportTaskCleanListener());
|
||||
|
||||
lifecycleManager.registerShuffleTrackerCallback(
|
||||
shuffleId -> SparkUtils.unregisterAllMapOutput(mapOutputTracker, shuffleId));
|
||||
}
|
||||
|
||||
@ -17,12 +17,19 @@
|
||||
|
||||
package org.apache.spark.shuffle.celeborn;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.LongAdder;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import scala.Option;
|
||||
import scala.Some;
|
||||
import scala.Tuple2;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import org.apache.spark.BarrierTaskContext;
|
||||
import org.apache.spark.MapOutputTrackerMaster;
|
||||
import org.apache.spark.SparkConf;
|
||||
@ -33,6 +40,10 @@ import org.apache.spark.scheduler.DAGScheduler;
|
||||
import org.apache.spark.scheduler.MapStatus;
|
||||
import org.apache.spark.scheduler.MapStatus$;
|
||||
import org.apache.spark.scheduler.ShuffleMapStage;
|
||||
import org.apache.spark.scheduler.SparkListener;
|
||||
import org.apache.spark.scheduler.TaskInfo;
|
||||
import org.apache.spark.scheduler.TaskSchedulerImpl;
|
||||
import org.apache.spark.scheduler.TaskSetManager;
|
||||
import org.apache.spark.shuffle.ShuffleHandle;
|
||||
import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
|
||||
import org.apache.spark.shuffle.ShuffleReader;
|
||||
@ -46,6 +57,7 @@ import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient;
|
||||
import org.apache.celeborn.common.CelebornConf;
|
||||
import org.apache.celeborn.common.util.JavaUtils;
|
||||
import org.apache.celeborn.reflect.DynConstructors;
|
||||
import org.apache.celeborn.reflect.DynFields;
|
||||
import org.apache.celeborn.reflect.DynMethods;
|
||||
@ -319,4 +331,135 @@ public class SparkUtils {
|
||||
LOG.error("Can not get active SparkContext, skip cancelShuffle.");
|
||||
}
|
||||
}
|
||||
|
||||
private static final DynFields.UnboundField<ConcurrentHashMap<Long, TaskSetManager>>
|
||||
TASK_ID_TO_TASK_SET_MANAGER_FIELD =
|
||||
DynFields.builder()
|
||||
.hiddenImpl(TaskSchedulerImpl.class, "taskIdToTaskSetManager")
|
||||
.defaultAlwaysNull()
|
||||
.build();
|
||||
private static final DynFields.UnboundField<scala.collection.mutable.HashMap<Long, TaskInfo>>
|
||||
TASK_INFOS_FIELD =
|
||||
DynFields.builder()
|
||||
.hiddenImpl(TaskSetManager.class, "taskInfos")
|
||||
.defaultAlwaysNull()
|
||||
.build();
|
||||
|
||||
/**
|
||||
* TaskSetManager - it is not designed to be used outside the spark scheduler. Please be careful.
|
||||
*/
|
||||
@VisibleForTesting
|
||||
protected static TaskSetManager getTaskSetManager(TaskSchedulerImpl taskScheduler, long taskId) {
|
||||
synchronized (taskScheduler) {
|
||||
ConcurrentHashMap<Long, TaskSetManager> taskIdToTaskSetManager =
|
||||
TASK_ID_TO_TASK_SET_MANAGER_FIELD.bind(taskScheduler).get();
|
||||
return taskIdToTaskSetManager.get(taskId);
|
||||
}
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
protected static Tuple2<TaskInfo, List<TaskInfo>> getTaskAttempts(
|
||||
TaskSetManager taskSetManager, long taskId) {
|
||||
if (taskSetManager != null) {
|
||||
scala.Option<TaskInfo> taskInfoOption =
|
||||
TASK_INFOS_FIELD.bind(taskSetManager).get().get(taskId);
|
||||
if (taskInfoOption.isDefined()) {
|
||||
TaskInfo taskInfo = taskInfoOption.get();
|
||||
List<TaskInfo> taskAttempts =
|
||||
scala.collection.JavaConverters.asJavaCollectionConverter(
|
||||
taskSetManager.taskAttempts()[taskInfo.index()])
|
||||
.asJavaCollection().stream()
|
||||
.collect(Collectors.toList());
|
||||
return Tuple2.apply(taskInfo, taskAttempts);
|
||||
} else {
|
||||
LOG.error("Can not get TaskInfo for taskId: {}", taskId);
|
||||
return null;
|
||||
}
|
||||
} else {
|
||||
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
protected static Map<String, Set<Long>> reportedStageShuffleFetchFailureTaskIds =
|
||||
JavaUtils.newConcurrentHashMap();
|
||||
|
||||
protected static void removeStageReportedShuffleFetchFailureTaskIds(
|
||||
int stageId, int stageAttemptId) {
|
||||
reportedStageShuffleFetchFailureTaskIds.remove(stageId + "-" + stageAttemptId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Only used to check for the shuffle fetch failure task whether another attempt is running or
|
||||
* successful. If another attempt(excluding the reported shuffle fetch failure tasks in current
|
||||
* stage) is running or successful, return true. Otherwise, return false.
|
||||
*/
|
||||
public static boolean taskAnotherAttemptRunningOrSuccessful(long taskId) {
|
||||
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
|
||||
if (sparkContext == null) {
|
||||
LOG.error("Can not get active SparkContext.");
|
||||
return false;
|
||||
}
|
||||
TaskSchedulerImpl taskScheduler = (TaskSchedulerImpl) sparkContext.taskScheduler();
|
||||
synchronized (taskScheduler) {
|
||||
TaskSetManager taskSetManager = getTaskSetManager(taskScheduler, taskId);
|
||||
if (taskSetManager != null) {
|
||||
int stageId = taskSetManager.stageId();
|
||||
int stageAttemptId = taskSetManager.taskSet().stageAttemptId();
|
||||
String stageUniqId = stageId + "-" + stageAttemptId;
|
||||
Set<Long> reportedStageTaskIds =
|
||||
reportedStageShuffleFetchFailureTaskIds.computeIfAbsent(
|
||||
stageUniqId, k -> new HashSet<>());
|
||||
reportedStageTaskIds.add(taskId);
|
||||
|
||||
Tuple2<TaskInfo, List<TaskInfo>> taskAttempts = getTaskAttempts(taskSetManager, taskId);
|
||||
|
||||
if (taskAttempts == null) return false;
|
||||
|
||||
TaskInfo taskInfo = taskAttempts._1();
|
||||
for (TaskInfo ti : taskAttempts._2()) {
|
||||
if (ti.taskId() != taskId) {
|
||||
if (reportedStageTaskIds.contains(ti.taskId())) {
|
||||
LOG.info(
|
||||
"StageId={} index={} taskId={} attempt={} another attempt {} has reported shuffle fetch failure, ignore it.",
|
||||
stageId,
|
||||
taskInfo.index(),
|
||||
taskId,
|
||||
taskInfo.attemptNumber(),
|
||||
ti.attemptNumber());
|
||||
} else if (ti.successful()) {
|
||||
LOG.info(
|
||||
"StageId={} index={} taskId={} attempt={} another attempt {} is successful.",
|
||||
stageId,
|
||||
taskInfo.index(),
|
||||
taskId,
|
||||
taskInfo.attemptNumber(),
|
||||
ti.attemptNumber());
|
||||
return true;
|
||||
} else if (ti.running()) {
|
||||
LOG.info(
|
||||
"StageId={} index={} taskId={} attempt={} another attempt {} is running.",
|
||||
stageId,
|
||||
taskInfo.index(),
|
||||
taskId,
|
||||
taskInfo.attemptNumber(),
|
||||
ti.attemptNumber());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
} else {
|
||||
LOG.error("Can not get TaskSetManager for taskId: {}", taskId);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void addSparkListener(SparkListener listener) {
|
||||
SparkContext sparkContext = SparkContext$.MODULE$.getActive().getOrElse(null);
|
||||
if (sparkContext != null) {
|
||||
sparkContext.addSparkListener(listener);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -215,6 +215,7 @@ class CelebornShuffleReader[K, C](
|
||||
handle.shuffleId,
|
||||
partitionId,
|
||||
encodedAttemptId,
|
||||
context.taskAttemptId(),
|
||||
startMapIndex,
|
||||
endMapIndex,
|
||||
if (throwsFetchFailure) ExceptionMakerHelper.SHUFFLE_FETCH_FAILURE_EXCEPTION_MAKER
|
||||
@ -375,7 +376,7 @@ class CelebornShuffleReader[K, C](
|
||||
partitionId: Int,
|
||||
ce: Throwable) = {
|
||||
if (throwsFetchFailure &&
|
||||
shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId)) {
|
||||
shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId, context.taskAttemptId())) {
|
||||
logWarning(s"Handle fetch exceptions for ${shuffleId}-${partitionId}", ce)
|
||||
throw new FetchFailedException(
|
||||
null,
|
||||
|
||||
@ -0,0 +1,28 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.shuffle.celeborn
|
||||
|
||||
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
|
||||
|
||||
class ShuffleFetchFailureReportTaskCleanListener extends SparkListener {
|
||||
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
|
||||
SparkUtils.removeStageReportedShuffleFetchFailureTaskIds(
|
||||
stageCompleted.stageInfo.stageId,
|
||||
stageCompleted.stageInfo.attemptNumber())
|
||||
}
|
||||
}
|
||||
@ -224,6 +224,7 @@ public abstract class ShuffleClient {
|
||||
int shuffleId,
|
||||
int partitionId,
|
||||
int attemptNumber,
|
||||
long taskId,
|
||||
int startMapIndex,
|
||||
int endMapIndex,
|
||||
MetricsCallback metricsCallback)
|
||||
@ -233,6 +234,7 @@ public abstract class ShuffleClient {
|
||||
shuffleId,
|
||||
partitionId,
|
||||
attemptNumber,
|
||||
taskId,
|
||||
startMapIndex,
|
||||
endMapIndex,
|
||||
null,
|
||||
@ -247,6 +249,7 @@ public abstract class ShuffleClient {
|
||||
int appShuffleId,
|
||||
int partitionId,
|
||||
int attemptNumber,
|
||||
long taskId,
|
||||
int startMapIndex,
|
||||
int endMapIndex,
|
||||
ExceptionMaker exceptionMaker,
|
||||
@ -276,7 +279,7 @@ public abstract class ShuffleClient {
|
||||
* cleanup for spark app. It must be a sync call and make sure the cleanup is done, otherwise,
|
||||
* incorrect shuffle data can be fetched in re-run tasks
|
||||
*/
|
||||
public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId);
|
||||
public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId);
|
||||
|
||||
/**
|
||||
* Report barrier task failure. When any barrier task fails, all (shuffle) output for that stage
|
||||
|
||||
@ -630,11 +630,12 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) {
|
||||
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) {
|
||||
PbReportShuffleFetchFailure pbReportShuffleFetchFailure =
|
||||
PbReportShuffleFetchFailure.newBuilder()
|
||||
.setAppShuffleId(appShuffleId)
|
||||
.setShuffleId(shuffleId)
|
||||
.setTaskId(taskId)
|
||||
.build();
|
||||
PbReportShuffleFetchFailureResponse pbReportShuffleFetchFailureResponse =
|
||||
lifecycleManagerRef.askSync(
|
||||
@ -1845,6 +1846,7 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
int appShuffleId,
|
||||
int partitionId,
|
||||
int attemptNumber,
|
||||
long taskId,
|
||||
int startMapIndex,
|
||||
int endMapIndex,
|
||||
ExceptionMaker exceptionMaker,
|
||||
@ -1883,6 +1885,7 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
streamHandlers,
|
||||
mapAttempts,
|
||||
attemptNumber,
|
||||
taskId,
|
||||
startMapIndex,
|
||||
endMapIndex,
|
||||
fetchExcludedWorkers,
|
||||
|
||||
@ -55,6 +55,7 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
ArrayList<PbStreamHandler> streamHandlers,
|
||||
int[] attempts,
|
||||
int attemptNumber,
|
||||
long taskId,
|
||||
int startMapIndex,
|
||||
int endMapIndex,
|
||||
ConcurrentHashMap<String, Long> fetchExcludedWorkers,
|
||||
@ -76,6 +77,7 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
streamHandlers,
|
||||
attempts,
|
||||
attemptNumber,
|
||||
taskId,
|
||||
startMapIndex,
|
||||
endMapIndex,
|
||||
fetchExcludedWorkers,
|
||||
@ -129,6 +131,7 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
private ArrayList<PbStreamHandler> streamHandlers;
|
||||
private int[] attempts;
|
||||
private final int attemptNumber;
|
||||
private final long taskId;
|
||||
private final int startMapIndex;
|
||||
private final int endMapIndex;
|
||||
|
||||
@ -178,6 +181,7 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
ArrayList<PbStreamHandler> streamHandlers,
|
||||
int[] attempts,
|
||||
int attemptNumber,
|
||||
long taskId,
|
||||
int startMapIndex,
|
||||
int endMapIndex,
|
||||
ConcurrentHashMap<String, Long> fetchExcludedWorkers,
|
||||
@ -197,6 +201,7 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
}
|
||||
this.attempts = attempts;
|
||||
this.attemptNumber = attemptNumber;
|
||||
this.taskId = taskId;
|
||||
this.startMapIndex = startMapIndex;
|
||||
this.endMapIndex = endMapIndex;
|
||||
this.rangeReadFilter = conf.shuffleRangeReadFilterEnabled();
|
||||
@ -654,7 +659,7 @@ public abstract class CelebornInputStream extends InputStream {
|
||||
ioe = new IOException(e);
|
||||
}
|
||||
if (exceptionMaker != null) {
|
||||
if (shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId)) {
|
||||
if (shuffleClient.reportShuffleFetchFailure(appShuffleId, shuffleId, taskId)) {
|
||||
/*
|
||||
* [[ExceptionMaker.makeException]], for spark applications with celeborn.client.spark.fetch.throwsFetchFailure enabled will result in creating
|
||||
* a FetchFailedException; and that will make the TaskContext as failed with shuffle fetch issues - see SPARK-19276 for more.
|
||||
|
||||
@ -445,8 +445,9 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
case pb: PbReportShuffleFetchFailure =>
|
||||
val appShuffleId = pb.getAppShuffleId
|
||||
val shuffleId = pb.getShuffleId
|
||||
logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId")
|
||||
handleReportShuffleFetchFailure(context, appShuffleId, shuffleId)
|
||||
val taskId = pb.getTaskId
|
||||
logDebug(s"Received ReportShuffleFetchFailure request, appShuffleId $appShuffleId shuffleId $shuffleId taskId $taskId")
|
||||
handleReportShuffleFetchFailure(context, appShuffleId, shuffleId, taskId)
|
||||
|
||||
case pb: PbReportBarrierStageAttemptFailure =>
|
||||
val appShuffleId = pb.getAppShuffleId
|
||||
@ -935,7 +936,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
private def handleReportShuffleFetchFailure(
|
||||
context: RpcCallContext,
|
||||
appShuffleId: Int,
|
||||
shuffleId: Int): Unit = {
|
||||
shuffleId: Int,
|
||||
taskId: Long): Unit = {
|
||||
|
||||
val shuffleIds = shuffleIdMapping.get(appShuffleId)
|
||||
if (shuffleIds == null) {
|
||||
@ -945,9 +947,15 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
shuffleIds.synchronized {
|
||||
shuffleIds.find(e => e._2._1 == shuffleId) match {
|
||||
case Some((appShuffleIdentifier, (shuffleId, true))) =>
|
||||
logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId")
|
||||
ret = invokeAppShuffleTrackerCallback(appShuffleId)
|
||||
shuffleIds.put(appShuffleIdentifier, (shuffleId, false))
|
||||
if (invokeReportTaskShuffleFetchFailurePreCheck(taskId)) {
|
||||
logInfo(s"handle fetch failure for appShuffleId $appShuffleId shuffleId $shuffleId")
|
||||
ret = invokeAppShuffleTrackerCallback(appShuffleId)
|
||||
shuffleIds.put(appShuffleIdentifier, (shuffleId, false))
|
||||
} else {
|
||||
logInfo(
|
||||
s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId taskId $taskId")
|
||||
ret = false
|
||||
}
|
||||
case Some((appShuffleIdentifier, (shuffleId, false))) =>
|
||||
logInfo(
|
||||
s"Ignoring fetch failure from appShuffleIdentifier $appShuffleIdentifier shuffleId $shuffleId, " +
|
||||
@ -1010,6 +1018,20 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
}
|
||||
}
|
||||
|
||||
private def invokeReportTaskShuffleFetchFailurePreCheck(taskId: Long): Boolean = {
|
||||
reportTaskShuffleFetchFailurePreCheck match {
|
||||
case Some(preCheck) =>
|
||||
try {
|
||||
preCheck.apply(taskId)
|
||||
} catch {
|
||||
case t: Throwable =>
|
||||
logError(s"Error preChecking the shuffle fetch failure reported by task: $taskId", t)
|
||||
false
|
||||
}
|
||||
case None => true
|
||||
}
|
||||
}
|
||||
|
||||
private def handleStageEnd(shuffleId: Int): Unit = {
|
||||
// check whether shuffle has registered
|
||||
if (!registeredShuffle.contains(shuffleId)) {
|
||||
@ -1770,6 +1792,14 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
|
||||
workerStatusTracker.registerWorkerStatusListener(workerStatusListener)
|
||||
}
|
||||
|
||||
@volatile private var reportTaskShuffleFetchFailurePreCheck
|
||||
: Option[java.util.function.Function[java.lang.Long, Boolean]] = None
|
||||
def registerReportTaskShuffleFetchFailurePreCheck(preCheck: java.util.function.Function[
|
||||
java.lang.Long,
|
||||
Boolean]): Unit = {
|
||||
reportTaskShuffleFetchFailurePreCheck = Some(preCheck)
|
||||
}
|
||||
|
||||
@volatile private var appShuffleTrackerCallback: Option[Consumer[Integer]] = None
|
||||
def registerShuffleTrackerCallback(callback: Consumer[Integer]): Unit = {
|
||||
appShuffleTrackerCallback = Some(callback)
|
||||
|
||||
@ -130,6 +130,7 @@ public class DummyShuffleClient extends ShuffleClient {
|
||||
int appShuffleId,
|
||||
int partitionId,
|
||||
int attemptNumber,
|
||||
long taskId,
|
||||
int startMapIndex,
|
||||
int endMapIndex,
|
||||
ExceptionMaker exceptionMaker,
|
||||
@ -179,7 +180,7 @@ public class DummyShuffleClient extends ShuffleClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId) {
|
||||
public boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId, long taskId) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@ -158,6 +158,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
Integer.MAX_VALUE,
|
||||
null,
|
||||
null,
|
||||
@ -173,6 +174,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
|
||||
3,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
Integer.MAX_VALUE,
|
||||
null,
|
||||
null,
|
||||
|
||||
@ -391,6 +391,7 @@ message PbGetShuffleIdResponse {
|
||||
message PbReportShuffleFetchFailure {
|
||||
int32 appShuffleId = 1;
|
||||
int32 shuffleId = 2;
|
||||
int64 taskId = 3;
|
||||
}
|
||||
|
||||
message PbReportShuffleFetchFailureResponse {
|
||||
|
||||
@ -17,22 +17,19 @@
|
||||
|
||||
package org.apache.celeborn.tests.spark
|
||||
|
||||
import java.io.{File, IOException}
|
||||
import java.io.IOException
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
import org.apache.spark.{BarrierTaskContext, ShuffleDependency, SparkConf, SparkContextHelper, SparkException, TaskContext}
|
||||
import org.apache.spark.celeborn.ExceptionMakerHelper
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.shuffle.ShuffleHandle
|
||||
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
|
||||
import org.apache.spark.shuffle.celeborn.{SparkShuffleManager, SparkUtils, TestCelebornShuffleManager}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.protocol.ShuffleMode
|
||||
import org.apache.celeborn.service.deploy.worker.Worker
|
||||
|
||||
class CelebornFetchFailureSuite extends AnyFunSuite
|
||||
with SparkTestBase
|
||||
@ -46,57 +43,6 @@ class CelebornFetchFailureSuite extends AnyFunSuite
|
||||
System.gc()
|
||||
}
|
||||
|
||||
var workerDirs: Seq[String] = Seq.empty
|
||||
|
||||
override def createWorker(map: Map[String, String]): Worker = {
|
||||
val storageDir = createTmpDir()
|
||||
this.synchronized {
|
||||
workerDirs = workerDirs :+ storageDir
|
||||
}
|
||||
super.createWorker(map, storageDir)
|
||||
}
|
||||
|
||||
class ShuffleReaderGetHook(conf: CelebornConf) extends ShuffleManagerHook {
|
||||
var executed: AtomicBoolean = new AtomicBoolean(false)
|
||||
val lock = new Object
|
||||
|
||||
override def exec(
|
||||
handle: ShuffleHandle,
|
||||
startPartition: Int,
|
||||
endPartition: Int,
|
||||
context: TaskContext): Unit = {
|
||||
if (executed.get() == true) return
|
||||
|
||||
lock.synchronized {
|
||||
handle match {
|
||||
case h: CelebornShuffleHandle[_, _, _] => {
|
||||
val appUniqueId = h.appUniqueId
|
||||
val shuffleClient = ShuffleClient.get(
|
||||
h.appUniqueId,
|
||||
h.lifecycleManagerHost,
|
||||
h.lifecycleManagerPort,
|
||||
conf,
|
||||
h.userIdentifier,
|
||||
h.extension)
|
||||
val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
|
||||
val allFiles = workerDirs.map(dir => {
|
||||
new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
|
||||
})
|
||||
val datafile = allFiles.filter(_.exists())
|
||||
.flatMap(_.listFiles().iterator).headOption
|
||||
datafile match {
|
||||
case Some(file) => file.delete()
|
||||
case None => throw new RuntimeException("unexpected, there must be some data file" +
|
||||
s" under ${workerDirs.mkString(",")}")
|
||||
}
|
||||
}
|
||||
case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here")
|
||||
}
|
||||
executed.set(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("celeborn spark integration test - Fetch Failure") {
|
||||
if (Spark3OrNewer) {
|
||||
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
|
||||
@ -111,7 +57,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
|
||||
.getOrCreate()
|
||||
|
||||
val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
|
||||
val hook = new ShuffleReaderGetHook(celebornConf)
|
||||
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
|
||||
TestCelebornShuffleManager.registerReaderGetHook(hook)
|
||||
|
||||
val value = Range(1, 10000).mkString(",")
|
||||
@ -184,7 +130,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
|
||||
.getOrCreate()
|
||||
|
||||
val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
|
||||
val hook = new ShuffleReaderGetHook(celebornConf)
|
||||
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
|
||||
TestCelebornShuffleManager.registerReaderGetHook(hook)
|
||||
|
||||
import sparkSession.implicits._
|
||||
@ -215,7 +161,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
|
||||
.getOrCreate()
|
||||
|
||||
val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
|
||||
val hook = new ShuffleReaderGetHook(celebornConf)
|
||||
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
|
||||
TestCelebornShuffleManager.registerReaderGetHook(hook)
|
||||
|
||||
val sc = sparkSession.sparkContext
|
||||
@ -255,7 +201,7 @@ class CelebornFetchFailureSuite extends AnyFunSuite
|
||||
.getOrCreate()
|
||||
|
||||
val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
|
||||
val hook = new ShuffleReaderGetHook(celebornConf)
|
||||
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
|
||||
TestCelebornShuffleManager.registerReaderGetHook(hook)
|
||||
|
||||
val sc = sparkSession.sparkContext
|
||||
|
||||
@ -17,19 +17,26 @@
|
||||
|
||||
package org.apache.celeborn.tests.spark
|
||||
|
||||
import java.io.File
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark.SPARK_VERSION
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.{SPARK_VERSION, SparkConf, TaskContext}
|
||||
import org.apache.spark.shuffle.ShuffleHandle
|
||||
import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ShuffleManagerHook, SparkUtils}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient
|
||||
import org.apache.celeborn.common.CelebornConf
|
||||
import org.apache.celeborn.common.CelebornConf._
|
||||
import org.apache.celeborn.common.internal.Logging
|
||||
import org.apache.celeborn.common.protocol.ShuffleMode
|
||||
import org.apache.celeborn.service.deploy.MiniClusterFeature
|
||||
import org.apache.celeborn.service.deploy.worker.Worker
|
||||
|
||||
trait SparkTestBase extends AnyFunSuite
|
||||
with Logging with MiniClusterFeature with BeforeAndAfterAll with BeforeAndAfterEach {
|
||||
@ -52,6 +59,16 @@ trait SparkTestBase extends AnyFunSuite
|
||||
shutdownMiniCluster()
|
||||
}
|
||||
|
||||
var workerDirs: Seq[String] = Seq.empty
|
||||
|
||||
override def createWorker(map: Map[String, String]): Worker = {
|
||||
val storageDir = createTmpDir()
|
||||
this.synchronized {
|
||||
workerDirs = workerDirs :+ storageDir
|
||||
}
|
||||
super.createWorker(map, storageDir)
|
||||
}
|
||||
|
||||
def updateSparkConf(sparkConf: SparkConf, mode: ShuffleMode): SparkConf = {
|
||||
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||
sparkConf.set(
|
||||
@ -98,4 +115,45 @@ trait SparkTestBase extends AnyFunSuite
|
||||
val outMap = result.collect().map(row => row.getString(0) -> row.getLong(1)).toMap
|
||||
outMap
|
||||
}
|
||||
|
||||
class ShuffleReaderFetchFailureGetHook(conf: CelebornConf) extends ShuffleManagerHook {
|
||||
var executed: AtomicBoolean = new AtomicBoolean(false)
|
||||
val lock = new Object
|
||||
|
||||
override def exec(
|
||||
handle: ShuffleHandle,
|
||||
startPartition: Int,
|
||||
endPartition: Int,
|
||||
context: TaskContext): Unit = {
|
||||
if (executed.get() == true) return
|
||||
|
||||
lock.synchronized {
|
||||
handle match {
|
||||
case h: CelebornShuffleHandle[_, _, _] => {
|
||||
val appUniqueId = h.appUniqueId
|
||||
val shuffleClient = ShuffleClient.get(
|
||||
h.appUniqueId,
|
||||
h.lifecycleManagerHost,
|
||||
h.lifecycleManagerPort,
|
||||
conf,
|
||||
h.userIdentifier,
|
||||
h.extension)
|
||||
val celebornShuffleId = SparkUtils.celebornShuffleId(shuffleClient, h, context, false)
|
||||
val allFiles = workerDirs.map(dir => {
|
||||
new File(s"$dir/celeborn-worker/shuffle_data/$appUniqueId/$celebornShuffleId")
|
||||
})
|
||||
val datafile = allFiles.filter(_.exists())
|
||||
.flatMap(_.listFiles().iterator).sortBy(_.getName).headOption
|
||||
datafile match {
|
||||
case Some(file) => file.delete()
|
||||
case None => throw new RuntimeException("unexpected, there must be some data file" +
|
||||
s" under ${workerDirs.mkString(",")}")
|
||||
}
|
||||
}
|
||||
case _ => throw new RuntimeException("unexpected, only support RssShuffleHandle here")
|
||||
}
|
||||
executed.set(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,160 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.shuffle.celeborn
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.scheduler.TaskSchedulerImpl
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.scalatest.BeforeAndAfterEach
|
||||
import org.scalatest.concurrent.Eventually.eventually
|
||||
import org.scalatest.concurrent.Futures.{interval, timeout}
|
||||
import org.scalatest.funsuite.AnyFunSuite
|
||||
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient
|
||||
import org.apache.celeborn.common.protocol.ShuffleMode
|
||||
import org.apache.celeborn.tests.spark.SparkTestBase
|
||||
|
||||
class SparkUtilsSuite extends AnyFunSuite
|
||||
with SparkTestBase
|
||||
with BeforeAndAfterEach {
|
||||
|
||||
override def beforeEach(): Unit = {
|
||||
ShuffleClient.reset()
|
||||
}
|
||||
|
||||
override def afterEach(): Unit = {
|
||||
System.gc()
|
||||
}
|
||||
|
||||
test("check if fetch failure task another attempt is running or successful") {
|
||||
if (Spark3OrNewer) {
|
||||
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
|
||||
val sparkSession = SparkSession.builder()
|
||||
.config(updateSparkConf(sparkConf, ShuffleMode.HASH))
|
||||
.config("spark.sql.shuffle.partitions", 2)
|
||||
.config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
|
||||
.config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
|
||||
.config(
|
||||
"spark.shuffle.manager",
|
||||
"org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
|
||||
.getOrCreate()
|
||||
|
||||
val celebornConf = SparkUtils.fromSparkConf(sparkSession.sparkContext.getConf)
|
||||
val hook = new ShuffleReaderFetchFailureGetHook(celebornConf)
|
||||
TestCelebornShuffleManager.registerReaderGetHook(hook)
|
||||
|
||||
try {
|
||||
val sc = sparkSession.sparkContext
|
||||
val jobThread = new Thread {
|
||||
override def run(): Unit = {
|
||||
try {
|
||||
val value = Range(1, 10000).mkString(",")
|
||||
sc.parallelize(1 to 10000, 2)
|
||||
.map { i => (i, value) }
|
||||
.groupByKey(10)
|
||||
.mapPartitions { iter =>
|
||||
Thread.sleep(3000)
|
||||
iter
|
||||
}.collect()
|
||||
} catch {
|
||||
case _: InterruptedException =>
|
||||
}
|
||||
}
|
||||
}
|
||||
jobThread.start()
|
||||
|
||||
val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
|
||||
eventually(timeout(30.seconds), interval(0.milliseconds)) {
|
||||
assert(hook.executed.get() == true)
|
||||
val reportedTaskId =
|
||||
SparkUtils.reportedStageShuffleFetchFailureTaskIds.values().asScala.flatMap(
|
||||
_.asScala).head
|
||||
val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, reportedTaskId)
|
||||
assert(taskSetManager != null)
|
||||
assert(SparkUtils.getTaskAttempts(taskSetManager, reportedTaskId)._2.size() == 1)
|
||||
assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(reportedTaskId))
|
||||
}
|
||||
|
||||
sparkSession.sparkContext.cancelAllJobs()
|
||||
|
||||
jobThread.interrupt()
|
||||
|
||||
eventually(timeout(3.seconds), interval(100.milliseconds)) {
|
||||
assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 0)
|
||||
}
|
||||
} finally {
|
||||
sparkSession.stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("getTaskSetManager/getTaskAttempts test") {
|
||||
val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[2,3]")
|
||||
val sparkSession = SparkSession.builder()
|
||||
.config(updateSparkConf(sparkConf, ShuffleMode.HASH))
|
||||
.config("spark.sql.shuffle.partitions", 2)
|
||||
.config("spark.celeborn.shuffle.forceFallback.partition.enabled", false)
|
||||
.config("spark.celeborn.client.spark.fetch.throwsFetchFailure", "true")
|
||||
.config(
|
||||
"spark.shuffle.manager",
|
||||
"org.apache.spark.shuffle.celeborn.TestCelebornShuffleManager")
|
||||
.getOrCreate()
|
||||
|
||||
try {
|
||||
val sc = sparkSession.sparkContext
|
||||
val jobThread = new Thread {
|
||||
override def run(): Unit = {
|
||||
try {
|
||||
sc.parallelize(1 to 100, 2)
|
||||
.repartition(1)
|
||||
.mapPartitions { iter =>
|
||||
Thread.sleep(3000)
|
||||
iter
|
||||
}.collect()
|
||||
} catch {
|
||||
case _: InterruptedException =>
|
||||
}
|
||||
}
|
||||
}
|
||||
jobThread.start()
|
||||
|
||||
val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]
|
||||
eventually(timeout(3.seconds), interval(100.milliseconds)) {
|
||||
val taskId = 0
|
||||
val taskSetManager = SparkUtils.getTaskSetManager(taskScheduler, taskId)
|
||||
assert(taskSetManager != null)
|
||||
assert(SparkUtils.getTaskAttempts(taskSetManager, taskId)._2.size() == 1)
|
||||
assert(!SparkUtils.taskAnotherAttemptRunningOrSuccessful(taskId))
|
||||
assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 1)
|
||||
}
|
||||
|
||||
sparkSession.sparkContext.cancelAllJobs()
|
||||
|
||||
jobThread.interrupt()
|
||||
|
||||
eventually(timeout(3.seconds), interval(100.milliseconds)) {
|
||||
assert(SparkUtils.reportedStageShuffleFetchFailureTaskIds.size() == 0)
|
||||
}
|
||||
} finally {
|
||||
sparkSession.stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -110,6 +110,7 @@ trait ReadWriteTestBase extends AnyFunSuite
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
Integer.MAX_VALUE,
|
||||
null,
|
||||
null,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user