[BUG] Fix fetch incorrect data chunk (#926)

This commit is contained in:
Ethan Feng 2022-11-09 22:31:39 +08:00 committed by GitHub
parent 1b2ad16b94
commit 6f043f8ae9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 165 additions and 516 deletions

View File

@ -50,19 +50,25 @@ import org.apache.celeborn.common.util.Utils;
* file will not be too many. Each retry is actually a switch between Master and Slave. Therefore,
* each retry needs to create a new connection and reopen the file to generate the stream id.
*/
public class RetryingChunkClient {
private static final Logger logger = LoggerFactory.getLogger(RetryingChunkClient.class);
public class ChunkClient {
private static final Logger logger = LoggerFactory.getLogger(ChunkClient.class);
private static final ExecutorService executorService =
Executors.newCachedThreadPool(NettyUtils.createThreadFactory("fetch-chuck"));
private final ChunkReceivedCallback callback;
private final Replica[] replicas;
private final Replica replica;
private final long retryWaitMs;
private final int maxTries;
private volatile int numTries = 0;
private PartitionLocation location;
private int fetchFailedChunkIndex;
public RetryingChunkClient(
/** Returns the partition location this client was constructed with. */
public PartitionLocation getLocation() {
return location;
}
public ChunkClient(
CelebornConf conf,
String shuffleKey,
PartitionLocation location,
@ -71,7 +77,7 @@ public class RetryingChunkClient {
this(conf, shuffleKey, location, callback, clientFactory, 0, Integer.MAX_VALUE);
}
public RetryingChunkClient(
public ChunkClient(
CelebornConf conf,
String shuffleKey,
PartitionLocation location,
@ -81,30 +87,22 @@ public class RetryingChunkClient {
int endMapIndex) {
TransportConf transportConf =
Utils.fromCelebornConf(conf, TransportModuleConstants.DATA_MODULE, 0);
this.fetchFailedChunkIndex = conf.testFetchFailedChunkIndex();
this.callback = callback;
this.retryWaitMs = transportConf.ioRetryWaitTimeMs();
long fetchTimeoutMs = conf.fetchTimeoutMs();
this.location = location;
if (location == null) {
throw new IllegalArgumentException("Must contain at least one available PartitionLocation.");
} else {
Replica main =
replica =
new Replica(
fetchTimeoutMs, shuffleKey, location, clientFactory, startMapIndex, endMapIndex);
PartitionLocation peerLoc = location.getPeer();
if (peerLoc == null) {
replicas = new Replica[] {main};
} else {
Replica peer =
new Replica(
fetchTimeoutMs, shuffleKey, peerLoc, clientFactory, startMapIndex, endMapIndex);
replicas = new Replica[] {main, peer};
}
}
this.maxTries = (transportConf.maxIORetries() + 1) * replicas.length;
this.maxTries = (transportConf.maxIORetries() + 1);
}
/**
@ -115,29 +113,26 @@ public class RetryingChunkClient {
*/
public synchronized int openChunks() throws IOException {
int numChunks = -1;
Replica currentReplica = null;
Exception currentException = null;
while (numChunks == -1 && hasRemainingRetries()) {
// Do not wait before the first request to a replica; wait before every retry.
currentReplica = getCurrentReplica();
if (numTries >= replicas.length) {
if (numTries != 0) {
logger.info(
"Retrying openChunk ({}/{}) for chunk from {} after {} ms.",
numTries,
maxTries,
currentReplica,
replica,
retryWaitMs);
Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS);
}
try {
currentReplica.getOrOpenStream();
numChunks = currentReplica.getNumChunks();
replica.getOrOpenStream();
numChunks = replica.getNumChunks();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IOException(e);
} catch (Exception e) {
logger.error(
"Exception raised while sending open chunks message to " + currentReplica + ".", e);
logger.error("Exception raised while sending open chunks message to " + replica + ".", e);
currentException = e;
if (shouldRetry(e)) {
numTries += 1;
@ -148,16 +143,21 @@ public class RetryingChunkClient {
}
if (numChunks == -1) {
if (currentException != null) {
throw new IOException(
String.format(
"Could not open chunks from %s after %d tries.", currentReplica, numTries),
currentException);
callback.onFailure(
0,
location,
new IOException(
String.format("Could not open chunks from %s after %d tries.", replica, numTries),
currentException));
} else {
throw new IOException(
String.format(
"Could not open chunks from %s after %d tries.", currentReplica, numTries));
callback.onFailure(
0,
location,
new IOException(
String.format("Could not open chunks from %s after %d tries.", replica, numTries)));
}
}
numTries = 0;
return numChunks;
}
@ -169,11 +169,17 @@ public class RetryingChunkClient {
* @param chunkIndex the index of the chunk to be fetched.
*/
public void fetchChunk(int chunkIndex) {
Replica replica;
RetryingChunkReceiveCallback callback;
synchronized (this) {
replica = getCurrentReplica();
callback = new RetryingChunkReceiveCallback(numTries);
if (fetchFailedChunkIndex != 0
&& location.getPeer() != null
&& chunkIndex == fetchFailedChunkIndex
&& location.getMode() == PartitionLocation.Mode.MASTER) {
RuntimeException manualTriggeredFailure =
new RuntimeException("Manual triggered fetch failure");
callback.onFailure(chunkIndex, location, manualTriggeredFailure);
}
}
try {
TransportClient client = replica.getOrOpenStream();
@ -188,17 +194,11 @@ public class RetryingChunkClient {
if (shouldRetry(e)) {
initiateRetry(chunkIndex, callback.currentNumTries);
} else {
callback.onFailure(chunkIndex, e);
callback.onFailure(chunkIndex, location, e);
}
}
}
@VisibleForTesting
Replica getCurrentReplica() {
int currentReplicaIndex = numTries % replicas.length;
return replicas[currentReplicaIndex];
}
@VisibleForTesting
int getNumTries() {
return numTries;
@ -228,7 +228,7 @@ public class RetryingChunkClient {
currentNumTries,
maxTries,
chunkIndex,
getCurrentReplica(),
replica,
retryWaitMs);
executorService.submit(
@ -246,17 +246,17 @@ public class RetryingChunkClient {
}
@Override
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
callback.onSuccess(chunkIndex, buffer);
public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) {
callback.onSuccess(chunkIndex, buffer, ChunkClient.this.location);
}
@Override
public void onFailure(int chunkIndex, Throwable e) {
public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) {
if (shouldRetry(e)) {
initiateRetry(chunkIndex, this.currentNumTries);
} else {
logger.error("Abandon to fetch chunk {} after {} tries.", chunkIndex, this.currentNumTries);
callback.onFailure(chunkIndex, e);
callback.onFailure(chunkIndex, ChunkClient.this.location, e);
}
}
}

View File

@ -58,7 +58,8 @@ class Replica {
public synchronized TransportClient getOrOpenStream() throws IOException, InterruptedException {
if (client == null || !client.isActive()) {
client = clientFactory.createClient(location.getHost(), location.getFetchPort());
}
if (streamHandle == null) {
OpenStream openBlocks =
new OpenStream(shuffleKey, location.getFileName(), startMapIndex, endMapIndex);
ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), timeoutMs);

View File

@ -18,12 +18,14 @@
package org.apache.celeborn.client.read;
import java.io.IOException;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCounted;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -36,17 +38,19 @@ import org.apache.celeborn.common.protocol.PartitionLocation;
public class WorkerPartitionReader implements PartitionReader {
private final Logger logger = LoggerFactory.getLogger(WorkerPartitionReader.class);
private final RetryingChunkClient client;
private final int numChunks;
private ChunkClient client;
private int numChunks;
private int returnedChunks;
private int chunkIndex;
private int currentChunkIndex;
private final LinkedBlockingQueue<ByteBuf> results;
private final LinkedBlockingQueue<ChunkData> results;
private final AtomicReference<IOException> exception = new AtomicReference<>();
private final int fetchMaxReqsInFlight;
private boolean closed = false;
private AtomicBoolean closed = new AtomicBoolean(false);
private Set<PartitionLocation> readableLocations = ConcurrentHashMap.newKeySet();
private Set<PartitionLocation> failedLocations = ConcurrentHashMap.newKeySet();
WorkerPartitionReader(
CelebornConf conf,
@ -58,48 +62,88 @@ public class WorkerPartitionReader implements PartitionReader {
throws IOException {
fetchMaxReqsInFlight = conf.fetchMaxReqsInFlight();
results = new LinkedBlockingQueue<>();
readableLocations.add(location);
if (location.getPeer() != null) {
readableLocations.add(location.getPeer());
}
// only add the buffer to results queue if this reader is not closed.
ChunkReceivedCallback callback =
new ChunkReceivedCallback() {
@Override
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) {
// only add the buffer to results queue if this reader is not closed.
synchronized (this) {
ByteBuf buf = ((NettyManagedBuffer) buffer).getBuf();
if (!closed) {
buf.retain();
results.add(buf);
}
ByteBuf buf = ((NettyManagedBuffer) buffer).getBuf();
if (!closed.get() && !failedLocations.contains(location)) {
buf.retain();
results.add(new ChunkData(buf, location));
}
}
@Override
public void onFailure(int chunkIndex, Throwable e) {
String errorMsg = "Fetch chunk " + chunkIndex + " failed.";
logger.error(errorMsg, e);
exception.set(new IOException(errorMsg, e));
public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) {
readableLocations.remove(location);
if (readableLocations.isEmpty()) {
String errorMsg = "Fetch chunk " + chunkIndex + " failed.";
logger.error(errorMsg, e);
exception.set(new IOException(errorMsg, e));
} else {
try {
synchronized (WorkerPartitionReader.this) {
if (!failedLocations.contains(location)) {
failedLocations.add(location);
client =
new ChunkClient(
conf,
shuffleKey,
location.getPeer(),
this,
clientFactory,
startMapIndex,
endMapIndex);
currentChunkIndex = 0;
returnedChunks = 0;
numChunks = client.openChunks();
}
}
} catch (IOException e1) {
logger.error(e1.getMessage(), e1);
exception.set(new IOException(e1.getMessage(), e1));
}
}
}
};
client =
new RetryingChunkClient(
new ChunkClient(
conf, shuffleKey, location, callback, clientFactory, startMapIndex, endMapIndex);
numChunks = client.openChunks();
}
public boolean hasNext() {
public synchronized boolean hasNext() {
return returnedChunks < numChunks;
}
public ByteBuf next() throws IOException {
checkException();
if (chunkIndex < numChunks) {
fetchChunks();
synchronized (this) {
if (currentChunkIndex < numChunks) {
fetchChunks();
}
}
ByteBuf chunk = null;
try {
while (chunk == null) {
checkException();
chunk = results.poll(500, TimeUnit.MILLISECONDS);
ChunkData chunkData = results.poll(500, TimeUnit.MILLISECONDS);
if (chunkData != null) {
synchronized (this) {
if (failedLocations.contains(chunkData.location)) {
chunkData.release();
} else {
chunk = chunkData.buf;
returnedChunks++;
}
}
}
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
@ -107,26 +151,24 @@ public class WorkerPartitionReader implements PartitionReader {
exception.set(ioe);
throw ioe;
}
returnedChunks++;
return chunk;
}
public void close() {
synchronized (this) {
closed = true;
}
closed.set(true);
if (results.size() > 0) {
results.forEach(ReferenceCounted::release);
results.forEach(ChunkData::release);
}
results.clear();
}
private void fetchChunks() {
final int inFlight = chunkIndex - returnedChunks;
final int inFlight = currentChunkIndex - returnedChunks;
if (inFlight < fetchMaxReqsInFlight) {
final int toFetch = Math.min(fetchMaxReqsInFlight - inFlight + 1, numChunks - chunkIndex);
final int toFetch =
Math.min(fetchMaxReqsInFlight - inFlight + 1, numChunks - currentChunkIndex);
for (int i = 0; i < toFetch; i++) {
client.fetchChunk(chunkIndex++);
client.fetchChunk(currentChunkIndex++);
}
}
}
@ -137,4 +179,18 @@ public class WorkerPartitionReader implements PartitionReader {
throw e;
}
}
/**
 * Immutable pairing of a fetched chunk buffer with the partition location it was fetched from.
 *
 * <p>The location travels with the buffer so the reader can discard chunks that came from a
 * location which has since been marked as failed.
 */
private static class ChunkData {
  /** The chunk payload; the code that enqueues a ChunkData retains it before wrapping. */
  final ByteBuf buf;

  /** The partition location that served {@link #buf}. */
  final PartitionLocation location;

  /**
   * @param buf the chunk payload
   * @param location the partition location the payload was fetched from
   */
  public ChunkData(ByteBuf buf, PartitionLocation location) {
    this.buf = buf;
    this.location = location;
  }

  /** Drops one reference on the wrapped buffer. */
  public void release() {
    buf.release();
  }
}
}

View File

@ -1,423 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.celeborn.client.read;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyObject;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import io.netty.channel.Channel;
import org.junit.Test;
import org.mockito.stubbing.Answer;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.network.buffer.ManagedBuffer;
import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.client.TransportResponseHandler;
import org.apache.celeborn.common.network.protocol.StreamHandle;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.util.ThreadUtils;
public class RetryingChunkClientSuiteJ {
// Ports for the master and slave fixtures; the fetch ports are what the client factory is
// asked to connect to (the factory is mocked, so nothing actually listens on them).
private static final int MASTER_RPC_PORT = 1234;
private static final int MASTER_PUSH_PORT = 1235;
private static final int MASTER_FETCH_PORT = 1236;
private static final int MASTER_REPLICATE_PORT = 1237;
private static final int SLAVE_RPC_PORT = 4321;
private static final int SLAVE_PUSH_PORT = 4322;
private static final int SLAVE_FETCH_PORT = 4323;
private static final int SLAVE_REPLICATE_PORT = 4324;
// Master-mode location handed to the client under test.
private static final PartitionLocation masterLocation =
new PartitionLocation(
0,
1,
"localhost",
MASTER_RPC_PORT,
MASTER_PUSH_PORT,
MASTER_FETCH_PORT,
MASTER_REPLICATE_PORT,
PartitionLocation.Mode.MASTER);
// Slave-mode counterpart of masterLocation (same first two constructor arguments).
private static final PartitionLocation slaveLocation =
new PartitionLocation(
0,
1,
"localhost",
SLAVE_RPC_PORT,
SLAVE_PUSH_PORT,
SLAVE_FETCH_PORT,
SLAVE_REPLICATE_PORT,
PartitionLocation.Mode.SLAVE);
// Link the two locations as each other's peer.
static {
masterLocation.setPeer(slaveLocation);
slaveLocation.setPeer(masterLocation);
}
// Three distinct payloads (13, 7 and 19 bytes) used as the scripted replies for chunks 0..2.
ManagedBuffer chunk0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
ManagedBuffer chunk1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
ManagedBuffer chunk2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
/**
 * Every chunk's only scripted reply is a buffer. Expects exactly one onSuccess per chunk, no
 * other callback interaction, no retries recorded, and the current replica still pointing at
 * the master location.
 */
@Test
public void testNoFailures() throws IOException, InterruptedException {
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
Map<Integer, List<Object>> interactions =
ImmutableMap.<Integer, List<Object>>builder()
.put(0, Arrays.asList(chunk0))
.put(1, Arrays.asList(chunk1))
.put(2, Arrays.asList(chunk2))
.build();
RetryingChunkClient client = performInteractions(interactions, callback);
// Replies arrive asynchronously, so each verification waits up to 5 seconds.
verify(callback, timeout(5000)).onSuccess(eq(0), eq(chunk0));
verify(callback, timeout(5000)).onSuccess(eq(1), eq(chunk1));
verify(callback, timeout(5000)).onSuccess(eq(2), eq(chunk2));
verifyNoMoreInteractions(callback);
assertEquals(0, client.getNumTries());
assertEquals(masterLocation, client.getCurrentReplica().getLocation());
}
/**
 * Chunk 0 is answered with a RuntimeException while chunks 1 and 2 succeed. Expects the failure
 * to be reported for chunk 0 only, the other chunks to be delivered, and no retry to be recorded
 * (getNumTries stays 0 and the current replica is still the master location).
 */
@Test
public void testUnrecoverableFailure() throws IOException, InterruptedException {
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
Map<Integer, List<Object>> interactions =
ImmutableMap.<Integer, List<Object>>builder()
.put(0, Arrays.asList(new RuntimeException("Ouch!")))
.put(1, Arrays.asList(chunk1))
.put(2, Arrays.asList(chunk2))
.build();
RetryingChunkClient client = performInteractions(interactions, callback);
verify(callback, timeout(5000)).onFailure(eq(0), any());
verify(callback, timeout(5000)).onSuccess(eq(1), eq(chunk1));
verify(callback, timeout(5000)).onSuccess(eq(2), eq(chunk2));
verifyNoMoreInteractions(callback);
assertEquals(0, client.getNumTries());
assertEquals(masterLocation, client.getCurrentReplica().getLocation());
}
/**
 * Chunk 0 has two scripted buffers but is fetched once. Expects a single onSuccess carrying the
 * first buffer and no further callback interaction, i.e. the second scripted reply is never
 * delivered.
 */
@Test
public void testDuplicateSuccess() throws IOException, InterruptedException {
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
Map<Integer, List<Object>> interactions =
ImmutableMap.<Integer, List<Object>>builder().put(0, Arrays.asList(chunk0, chunk1)).build();
RetryingChunkClient client = performInteractions(interactions, callback);
verify(callback, timeout(5000)).onSuccess(eq(0), eq(chunk0));
verifyNoMoreInteractions(callback);
assertEquals(0, client.getNumTries());
assertEquals(masterLocation, client.getCurrentReplica().getLocation());
}
/**
 * Chunk 0 first fails with an IOException and then succeeds; chunks 1 and 2 succeed directly.
 * Expects every chunk to end up delivered, one try recorded, and the current replica switched
 * to the slave location.
 */
@Test
public void testSingleIOException() throws IOException, InterruptedException {
Map<Integer, Object> result = new ConcurrentHashMap<>();
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
// Drain all three permits up front; each callback invocation hands one back.
Semaphore signal = new Semaphore(3);
signal.acquire(3);
Answer<Void> answer =
invocation -> {
synchronized (signal) {
int chunkIndex = (Integer) invocation.getArguments()[0];
// Each chunk index must reach the user callback at most once.
assertFalse(result.containsKey(chunkIndex));
Object value = invocation.getArguments()[1];
result.put(chunkIndex, value);
signal.release();
}
return null;
};
doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
Map<Integer, List<Object>> interactions =
ImmutableMap.<Integer, List<Object>>builder()
.put(0, Arrays.asList(new IOException(), chunk0))
.put(1, Arrays.asList(chunk1))
.put(2, Arrays.asList(chunk2))
.build();
RetryingChunkClient client = performInteractions(interactions, callback);
// Block until all three chunks have produced a callback.
while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ;
assertEquals(1, client.getNumTries());
assertEquals(slaveLocation, client.getCurrentReplica().getLocation());
assertEquals(chunk0, result.get(0));
assertEquals(chunk1, result.get(1));
assertEquals(chunk2, result.get(2));
}
/**
 * Chunks 0 and 1 each fail once with an IOException before succeeding; chunk 2 succeeds
 * directly. Expects every chunk to end up delivered, with getNumTries at 1 and the current
 * replica being the slave location.
 */
@Test
public void testTwoIOExceptions() throws IOException, InterruptedException {
Map<Integer, Object> result = new ConcurrentHashMap<>();
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
// Drain all three permits up front; each callback invocation hands one back.
Semaphore signal = new Semaphore(3);
signal.acquire(3);
Answer<Void> answer =
invocation -> {
synchronized (signal) {
int chunkIndex = (Integer) invocation.getArguments()[0];
// Each chunk index must reach the user callback at most once.
assertFalse(result.containsKey(chunkIndex));
Object value = invocation.getArguments()[1];
result.put(chunkIndex, value);
signal.release();
}
return null;
};
doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
Map<Integer, List<Object>> interactions =
ImmutableMap.<Integer, List<Object>>builder()
.put(0, Arrays.asList(new IOException("first ioexception"), chunk0))
.put(1, Arrays.asList(new IOException("second ioexception"), chunk1))
.put(2, Arrays.asList(chunk2))
.build();
RetryingChunkClient client = performInteractions(interactions, callback);
// Block until all three chunks have produced a callback.
while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ;
assertEquals(1, client.getNumTries());
assertEquals(slaveLocation, client.getCurrentReplica().getLocation());
assertEquals(chunk0, result.get(0));
assertEquals(chunk1, result.get(1));
assertEquals(chunk2, result.get(2));
}
/**
 * All three chunks fail once with an IOException before succeeding. Expects every chunk to end
 * up delivered, with getNumTries at 1 and the current replica being the slave location.
 */
@Test
public void testThreeIOExceptions() throws IOException, InterruptedException {
Map<Integer, Object> result = new ConcurrentHashMap<>();
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
// Drain all three permits up front; each callback invocation hands one back.
Semaphore signal = new Semaphore(3);
signal.acquire(3);
Answer<Void> answer =
invocation -> {
synchronized (signal) {
int chunkIndex = (Integer) invocation.getArguments()[0];
// Each chunk index must reach the user callback at most once.
assertFalse(result.containsKey(chunkIndex));
Object value = invocation.getArguments()[1];
result.put(chunkIndex, value);
signal.release();
}
return null;
};
doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
Map<Integer, List<Object>> interactions =
ImmutableMap.<Integer, List<Object>>builder()
.put(0, Arrays.asList(new IOException("first ioexception"), chunk0))
.put(1, Arrays.asList(new IOException("second ioexception"), chunk1))
.put(2, Arrays.asList(new IOException("third ioexception"), chunk2))
.build();
RetryingChunkClient client = performInteractions(interactions, callback);
// Block until all three chunks have produced a callback.
while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ;
assertEquals(1, client.getNumTries());
assertEquals(slaveLocation, client.getCurrentReplica().getLocation());
assertEquals(chunk0, result.get(0));
assertEquals(chunk1, result.get(1));
assertEquals(chunk2, result.get(2));
}
/**
 * Every chunk is scripted with five IOExceptions ahead of its buffer. Expects the client to
 * give up before reaching the buffers: each chunk is reported with the IOException,
 * getNumTries ends at 4, and the current replica is the master location.
 */
@Test
public void testFailedWithIOExceptions() throws IOException, InterruptedException {
Map<Integer, Object> result = new ConcurrentHashMap<>();
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
// Drain all three permits up front; each callback invocation hands one back.
Semaphore signal = new Semaphore(3);
signal.acquire(3);
Answer<Void> answer =
invocation -> {
synchronized (signal) {
int chunkIndex = (Integer) invocation.getArguments()[0];
// Each chunk index must reach the user callback at most once.
assertFalse(result.containsKey(chunkIndex));
Object value = invocation.getArguments()[1];
result.put(chunkIndex, value);
signal.release();
}
return null;
};
doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
IOException ioe = new IOException("failed exception");
Map<Integer, List<Object>> interactions =
ImmutableMap.<Integer, List<Object>>builder()
.put(0, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk0))
.put(1, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk1))
.put(2, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk2))
.build();
RetryingChunkClient client = performInteractions(interactions, callback);
// Block until all three chunks have produced a callback.
while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ;
// Note: this may exceed the max retries we want, but it doesn't matter.
assertEquals(4, client.getNumTries());
assertEquals(masterLocation, client.getCurrentReplica().getLocation());
assertEquals(ioe, result.get(0));
assertEquals(ioe, result.get(1));
assertEquals(ioe, result.get(2));
}
/**
 * Mixes retryable and non-retryable replies: chunk 0 gets an IOException followed by a
 * RuntimeException, chunk 1 succeeds directly, and chunk 2 succeeds after one IOException.
 * Expects chunk 0 to be reported with the RuntimeException and chunks 1 and 2 to be delivered.
 */
@Test
public void testRetryAndUnrecoverable() throws IOException, InterruptedException {
Map<Integer, Object> result = new ConcurrentHashMap<>();
ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
// Drain all three permits up front; each callback invocation hands one back.
Semaphore signal = new Semaphore(3);
signal.acquire(3);
Answer<Void> answer =
invocation -> {
synchronized (signal) {
int chunkIndex = (Integer) invocation.getArguments()[0];
// Each chunk index must reach the user callback at most once.
assertFalse(result.containsKey(chunkIndex));
Object value = invocation.getArguments()[1];
result.put(chunkIndex, value);
signal.release();
}
return null;
};
doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject());
doAnswer(answer).when(callback).onFailure(anyInt(), anyObject());
Exception re = new RuntimeException("failed exception");
Map<Integer, List<Object>> interactions =
ImmutableMap.<Integer, List<Object>>builder()
.put(0, Arrays.asList(new IOException("first ioexception"), re, chunk0))
.put(1, Arrays.asList(chunk1))
.put(2, Arrays.asList(new IOException("second ioexception"), chunk2))
.build();
performInteractions(interactions, callback);
// Block until all three chunks have produced a callback.
while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ;
assertEquals(re, result.get(0));
assertEquals(chunk1, result.get(1));
assertEquals(chunk2, result.get(2));
}
/**
 * Builds a {@code RetryingChunkClient} on the master location whose transport is a
 * {@code DummyTransportClient} replaying the given script, then requests every scripted chunk
 * in ascending index order.
 *
 * @param interactions for each chunk index, the ordered replies (buffers or exceptions) the
 *     dummy transport should produce
 * @param callback the user callback that receives the final outcome of each chunk
 * @return the client, so callers can inspect its retry state
 */
private static RetryingChunkClient performInteractions(
    Map<Integer, List<Object>> interactions, ChunkReceivedCallback callback)
    throws IOException, InterruptedException {
  // All chunk ids that appear anywhere in the script.
  LinkedHashSet<Integer> scriptedIds = Sets.newLinkedHashSet(interactions.keySet());

  // The mocked factory hands out the same dummy transport for every host/port.
  final TransportClient dummy = new DummyTransportClient(scriptedIds.size(), interactions);
  final TransportClientFactory factory = mock(TransportClientFactory.class);
  doAnswer(invocation -> dummy).when(factory).createClient(anyString(), anyInt());

  CelebornConf conf = new CelebornConf();
  conf.set("celeborn.data.io.maxRetries", "1");
  conf.set("celeborn.data.io.retryWait", "0");

  RetryingChunkClient chunkClient =
      new RetryingChunkClient(conf, "test", masterLocation, callback, factory);
  scriptedIds.stream().sorted().forEach(id -> chunkClient.fetchChunk(id));
  return chunkClient;
}
/**
 * Transport stub that replays a per-chunk script instead of talking to a server. Each fetch of
 * a chunk consumes the next scripted entry for that chunk and reports it to the callback after
 * a 500 ms delay: buffers via onSuccess, exceptions via onFailure.
 */
private static class DummyTransportClient extends TransportClient {
  private static final Channel channel = mock(Channel.class);
  private static final TransportResponseHandler handler = mock(TransportResponseHandler.class);

  private final long streamId = new Random().nextInt(Integer.MAX_VALUE) * 1000L;
  private final int numChunks;
  private final Map<Integer, List<Object>> interactions;
  // Position of the next unconsumed scripted reply, per chunk id.
  private final Map<Integer, Integer> chunkIdToInterActionIndex;
  private final ScheduledExecutorService schedule =
      ThreadUtils.newDaemonThreadPoolScheduledExecutor("test-fetch-chunk", 3);

  DummyTransportClient(int numChunks, Map<Integer, List<Object>> interactions) {
    super(channel, handler);
    this.numChunks = numChunks;
    this.interactions = interactions;
    this.chunkIdToInterActionIndex = new ConcurrentHashMap<>();
    for (Integer chunkId : interactions.keySet()) {
      chunkIdToInterActionIndex.putIfAbsent(chunkId, 0);
    }
  }

  @Override
  public void fetchChunk(long streamId, int chunkId, ChunkReceivedCallback callback) {
    Runnable reply = () -> deliver(chunkId, nextAction(chunkId), callback);
    schedule.schedule(reply, 500, TimeUnit.MILLISECONDS);
  }

  /** Takes the next scripted reply for the chunk and advances that chunk's cursor. */
  private Object nextAction(int chunkId) {
    List<Object> scripted = interactions.get(chunkId);
    synchronized (chunkIdToInterActionIndex) {
      int cursor = chunkIdToInterActionIndex.get(chunkId);
      assertTrue(cursor < scripted.size());
      Object next = scripted.get(cursor);
      chunkIdToInterActionIndex.put(chunkId, cursor + 1);
      return next;
    }
  }

  /** Routes a scripted reply to the matching callback method. */
  private void deliver(int chunkId, Object action, ChunkReceivedCallback callback) {
    if (action instanceof ManagedBuffer) {
      callback.onSuccess(chunkId, (ManagedBuffer) action);
      return;
    }
    if (action instanceof Exception) {
      callback.onFailure(chunkId, (Exception) action);
      return;
    }
    fail("Can only handle ManagedBuffers and Exceptions, got " + action);
  }

  @Override
  public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
    // Any RPC is answered with a handle describing this stub's stream and chunk count.
    return new StreamHandle(streamId, numChunks).toByteBuffer();
  }

  @Override
  public void close() {
    super.close();
    schedule.shutdownNow();
  }
}
}

View File

@ -18,6 +18,7 @@
package org.apache.celeborn.common.network.client;
import org.apache.celeborn.common.network.buffer.ManagedBuffer;
import org.apache.celeborn.common.protocol.PartitionLocation;
/**
* Callback for the result of fetching a single chunk. For a single stream, the callbacks are
@ -34,7 +35,7 @@ public interface ChunkReceivedCallback {
* this call returns. You must therefore either retain() the buffer or copy its contents before
* returning.
*/
void onSuccess(int chunkIndex, ManagedBuffer buffer);
void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location);
/**
* Called upon failure to fetch a particular chunk. Note that this may actually be called due to
@ -43,5 +44,5 @@ public interface ChunkReceivedCallback {
* <p>After receiving a failure, the stream may or may not be valid. The client should not assume
* that the server's side of the stream has been closed.
*/
void onFailure(int chunkIndex, Throwable e);
void onFailure(int chunkIndex, PartitionLocation location, Throwable e);
}

View File

@ -126,7 +126,7 @@ public class TransportClient implements Closeable {
@Override
protected void handleFailure(String errorMsg, Throwable cause) {
handler.removeFetchRequest(streamChunkSlice);
callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
callback.onFailure(chunkIndex, null, new IOException(errorMsg, cause));
}
};
handler.addFetchRequest(streamChunkSlice, callback);

View File

@ -86,7 +86,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private void failOutstandingRequests(Throwable cause) {
for (Map.Entry<StreamChunkSlice, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
try {
entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
entry.getValue().onFailure(entry.getKey().chunkIndex, null, cause);
} catch (Exception e) {
logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
}
@ -144,7 +144,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
resp.body().release();
} else {
outstandingFetches.remove(resp.streamChunkSlice);
listener.onSuccess(resp.streamChunkSlice.chunkIndex, resp.body());
listener.onSuccess(resp.streamChunkSlice.chunkIndex, resp.body(), null);
resp.body().release();
}
} else if (message instanceof ChunkFetchFailure) {
@ -161,6 +161,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
logger.warn("Receive ChunkFetchFailure, errorMsg {}", resp.errorString);
listener.onFailure(
resp.streamChunkSlice.chunkIndex,
null,
new ChunkFetchFailureException(
"Failure while fetching " + resp.streamChunkSlice + ": " + resp.errorString));
}

View File

@ -662,6 +662,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def rpcCacheSize: Int = get(RPC_CACHE_SIZE)
def rpcCacheConcurrencyLevel: Int = get(RPC_CACHE_CONCURRENCY_LEVEL)
def rpcCacheExpireTime: Long = get(RPC_CACHE_EXPIRE_TIME)
def testFetchFailedChunkIndex: Int = get(TEST_FETCH_FAILED_CHUNK_INDEX)
// //////////////////////////////////////////////////////
// Graceful Shutdown & Recover //
@ -2539,4 +2540,13 @@ object CelebornConf extends Logging {
.doc("The time before a cache item is removed.")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("15s")
// Internal, test-only setting exposed through `testFetchFailedChunkIndex`. The client treats
// the default of 0 as "no injected failure", so chunk index 0 itself cannot be targeted.
// NOTE(review): the key is spelled "fetchFailedChuckIndex" ("Chuck", not "Chunk") — presumably
// a typo; confirm before renaming, since the key string is what callers set.
val TEST_FETCH_FAILED_CHUNK_INDEX: ConfigEntry[Int] =
buildConf("celeborn.test.client.fetchFailedChuckIndex")
.categories("client")
.version("0.2.0")
.internal
.doc("The chunk index to trigger fetch chunk failure for testing purpose only.")
.intConf
.createWithDefault(0)
}

View File

@ -47,6 +47,7 @@ import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.network.server.StreamManager;
import org.apache.celeborn.common.network.server.TransportServer;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.PartitionLocation;
public class ChunkFetchIntegrationSuiteJ {
static final long STREAM_ID = 1;
@ -153,7 +154,7 @@ public class ChunkFetchIntegrationSuiteJ {
ChunkReceivedCallback callback =
new ChunkReceivedCallback() {
@Override
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) {
buffer.retain();
res.successChunks.add(chunkIndex);
res.buffers.add(buffer);
@ -161,7 +162,7 @@ public class ChunkFetchIntegrationSuiteJ {
}
@Override
public void onFailure(int chunkIndex, Throwable e) {
public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) {
res.failedChunks.add(chunkIndex);
sem.release();
}

View File

@ -42,6 +42,7 @@ import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.network.server.StreamManager;
import org.apache.celeborn.common.network.server.TransportServer;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.PartitionLocation;
/**
* Suite which ensures that requests that go without a response for the network timeout period are
@ -267,7 +268,7 @@ public class RequestTimeoutIntegrationSuiteJ {
}
@Override
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) {
try {
successLength = buffer.nioByteBuffer().remaining();
} catch (IOException e) {
@ -278,7 +279,7 @@ public class RequestTimeoutIntegrationSuiteJ {
}
@Override
public void onFailure(int chunkIndex, Throwable e) {
public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) {
failure = e;
latch.countDown();
}

View File

@ -42,7 +42,7 @@ public class TransportResponseHandlerSuiteJ {
assertEquals(1, handler.numOutstandingRequests());
handler.handle(new ChunkFetchSuccess(streamChunkSlice, new TestManagedBuffer(123)));
verify(callback, times(1)).onSuccess(eq(0), any());
verify(callback, times(1)).onSuccess(eq(0), any(), any());
assertEquals(0, handler.numOutstandingRequests());
}
@ -55,7 +55,7 @@ public class TransportResponseHandlerSuiteJ {
assertEquals(1, handler.numOutstandingRequests());
handler.handle(new ChunkFetchFailure(streamChunkSlice, "some error msg"));
verify(callback, times(1)).onFailure(eq(0), any());
verify(callback, times(1)).onFailure(eq(0), any(), any());
assertEquals(0, handler.numOutstandingRequests());
}
@ -72,9 +72,9 @@ public class TransportResponseHandlerSuiteJ {
handler.exceptionCaught(new Exception("duh duh duhhhh"));
// should fail both b2 and b3
verify(callback, times(1)).onSuccess(eq(0), any());
verify(callback, times(1)).onFailure(eq(1), any());
verify(callback, times(1)).onFailure(eq(2), any());
verify(callback, times(1)).onSuccess(eq(0), any(), any());
verify(callback, times(1)).onFailure(eq(1), any(), any());
verify(callback, times(1)).onFailure(eq(2), any(), any());
assertEquals(0, handler.numOutstandingRequests());
}

View File

@ -71,6 +71,7 @@ import org.apache.celeborn.common.network.server.MemoryTracker;
import org.apache.celeborn.common.network.server.TransportServer;
import org.apache.celeborn.common.network.util.JavaUtils;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PartitionSplitMode;
import org.apache.celeborn.common.protocol.PartitionType;
import org.apache.celeborn.common.protocol.StorageInfo;
@ -209,7 +210,7 @@ public class FileWriterSuiteJ {
ChunkReceivedCallback callback =
new ChunkReceivedCallback() {
@Override
public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) {
buffer.retain();
res.successChunks.add(chunkIndex);
res.buffers.add(buffer);
@ -217,7 +218,7 @@ public class FileWriterSuiteJ {
}
@Override
public void onFailure(int chunkIndex, Throwable e) {
public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) {
res.failedChunks.add(chunkIndex);
sem.release();
}