diff --git a/client/src/main/java/org/apache/celeborn/client/read/RetryingChunkClient.java b/client/src/main/java/org/apache/celeborn/client/read/ChunkClient.java similarity index 79% rename from client/src/main/java/org/apache/celeborn/client/read/RetryingChunkClient.java rename to client/src/main/java/org/apache/celeborn/client/read/ChunkClient.java index c044b2b5f..e48153754 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/RetryingChunkClient.java +++ b/client/src/main/java/org/apache/celeborn/client/read/ChunkClient.java @@ -50,19 +50,25 @@ import org.apache.celeborn.common.util.Utils; * file will not be too many. Each retry is actually a switch between Master and Slave. Therefore, * each retry needs to create a new connection and reopen the file to generate the stream id. */ -public class RetryingChunkClient { - private static final Logger logger = LoggerFactory.getLogger(RetryingChunkClient.class); +public class ChunkClient { + private static final Logger logger = LoggerFactory.getLogger(ChunkClient.class); private static final ExecutorService executorService = Executors.newCachedThreadPool(NettyUtils.createThreadFactory("fetch-chuck")); private final ChunkReceivedCallback callback; - private final Replica[] replicas; + private final Replica replica; private final long retryWaitMs; private final int maxTries; private volatile int numTries = 0; + private PartitionLocation location; + private int fetchFailedChunkIndex; - public RetryingChunkClient( + public PartitionLocation getLocation() { + return location; + } + + public ChunkClient( CelebornConf conf, String shuffleKey, PartitionLocation location, @@ -71,7 +77,7 @@ public class RetryingChunkClient { this(conf, shuffleKey, location, callback, clientFactory, 0, Integer.MAX_VALUE); } - public RetryingChunkClient( + public ChunkClient( CelebornConf conf, String shuffleKey, PartitionLocation location, @@ -81,30 +87,22 @@ public class RetryingChunkClient { int endMapIndex) { TransportConf transportConf = Utils.fromCelebornConf(conf, TransportModuleConstants.DATA_MODULE, 0); - + this.fetchFailedChunkIndex = conf.testFetchFailedChunkIndex(); this.callback = callback; this.retryWaitMs = transportConf.ioRetryWaitTimeMs(); long fetchTimeoutMs = conf.fetchTimeoutMs(); + this.location = location; if (location == null) { throw new IllegalArgumentException("Must contain at least one available PartitionLocation."); } else { - Replica main = + replica = new Replica( fetchTimeoutMs, shuffleKey, location, clientFactory, startMapIndex, endMapIndex); - PartitionLocation peerLoc = location.getPeer(); - if (peerLoc == null) { - replicas = new Replica[] {main}; - } else { - Replica peer = - new Replica( - fetchTimeoutMs, shuffleKey, peerLoc, clientFactory, startMapIndex, endMapIndex); - replicas = new Replica[] {main, peer}; - } } - this.maxTries = (transportConf.maxIORetries() + 1) * replicas.length; + this.maxTries = (transportConf.maxIORetries() + 1); } /** @@ -115,29 +113,26 @@ public class RetryingChunkClient { */ public synchronized int openChunks() throws IOException { int numChunks = -1; - Replica currentReplica = null; Exception currentException = null; while (numChunks == -1 && hasRemainingRetries()) { // Only not wait for first request to each replicate. - currentReplica = getCurrentReplica(); - if (numTries >= replicas.length) { + if (numTries != 0) { logger.info( "Retrying openChunk ({}/{}) for chunk from {} after {} ms.", numTries, maxTries, - currentReplica, + replica, retryWaitMs); Uninterruptibles.sleepUninterruptibly(retryWaitMs, TimeUnit.MILLISECONDS); } try { - currentReplica.getOrOpenStream(); - numChunks = currentReplica.getNumChunks(); + replica.getOrOpenStream(); + numChunks = replica.getNumChunks(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException(e); } catch (Exception e) { - logger.error( - "Exception raised while sending open chunks message to " + currentReplica + ".", e); + logger.error("Exception raised while sending open chunks message to " + replica + ".", e); currentException = e; if (shouldRetry(e)) { numTries += 1; @@ -148,16 +143,21 @@ public class RetryingChunkClient { } if (numChunks == -1) { if (currentException != null) { - throw new IOException( - String.format( - "Could not open chunks from %s after %d tries.", currentReplica, numTries), - currentException); + callback.onFailure( + 0, + location, + new IOException( + String.format("Could not open chunks from %s after %d tries.", replica, numTries), + currentException)); } else { - throw new IOException( - String.format( - "Could not open chunks from %s after %d tries.", currentReplica, numTries)); + callback.onFailure( + 0, + location, + new IOException( + String.format("Could not open chunks from %s after %d tries.", replica, numTries))); } } + numTries = 0; return numChunks; } @@ -169,11 +169,17 @@ public class RetryingChunkClient { * @param chunkIndex the index of the chunk to be fetched. */ public void fetchChunk(int chunkIndex) { - Replica replica; RetryingChunkReceiveCallback callback; synchronized (this) { - replica = getCurrentReplica(); callback = new RetryingChunkReceiveCallback(numTries); + if (fetchFailedChunkIndex != 0 + && location.getPeer() != null + && chunkIndex == fetchFailedChunkIndex + && location.getMode() == PartitionLocation.Mode.MASTER) { + RuntimeException manualTriggeredFailure = + new RuntimeException("Manual triggered fetch failure"); + callback.onFailure(chunkIndex, location, manualTriggeredFailure); + } } try { TransportClient client = replica.getOrOpenStream(); @@ -188,17 +194,11 @@ public class RetryingChunkClient { if (shouldRetry(e)) { initiateRetry(chunkIndex, callback.currentNumTries); } else { - callback.onFailure(chunkIndex, e); + callback.onFailure(chunkIndex, location, e); } } } - @VisibleForTesting - Replica getCurrentReplica() { - int currentReplicaIndex = numTries % replicas.length; - return replicas[currentReplicaIndex]; - } - @VisibleForTesting int getNumTries() { return numTries; @@ -228,7 +228,7 @@ public class RetryingChunkClient { currentNumTries, maxTries, chunkIndex, - getCurrentReplica(), + replica, retryWaitMs); executorService.submit( @@ -246,17 +246,17 @@ public class RetryingChunkClient { } @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { - callback.onSuccess(chunkIndex, buffer); + public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) { + callback.onSuccess(chunkIndex, buffer, ChunkClient.this.location); } @Override - public void onFailure(int chunkIndex, Throwable e) { + public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) { if (shouldRetry(e)) { initiateRetry(chunkIndex, this.currentNumTries); } else { logger.error("Abandon to fetch chunk {} after {} tries.", chunkIndex, this.currentNumTries); - callback.onFailure(chunkIndex, e); + callback.onFailure(chunkIndex, ChunkClient.this.location, e); } } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/Replica.java b/client/src/main/java/org/apache/celeborn/client/read/Replica.java index 1fe7d3b35..40e60b88d 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/Replica.java +++ b/client/src/main/java/org/apache/celeborn/client/read/Replica.java @@ -58,7 +58,8 @@ class Replica { public synchronized TransportClient getOrOpenStream() throws IOException, InterruptedException { if (client == null || !client.isActive()) { client = clientFactory.createClient(location.getHost(), location.getFetchPort()); - + } + if (streamHandle == null) { OpenStream openBlocks = new OpenStream(shuffleKey, location.getFileName(), startMapIndex, endMapIndex); ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), timeoutMs); diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 01e35b4db..bb17f31fd 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -18,12 +18,14 @@ package org.apache.celeborn.client.read; import java.io.IOException; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import io.netty.buffer.ByteBuf; -import io.netty.util.ReferenceCounted; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,17 +38,19 @@ import org.apache.celeborn.common.protocol.PartitionLocation; public class WorkerPartitionReader implements PartitionReader { private final Logger logger = LoggerFactory.getLogger(WorkerPartitionReader.class); - private final RetryingChunkClient client; - private final int numChunks; + private ChunkClient client; + private int numChunks; private int returnedChunks; - private int chunkIndex; + private int currentChunkIndex; - private final LinkedBlockingQueue results; + private final LinkedBlockingQueue results; private final AtomicReference exception = new AtomicReference<>(); private final int fetchMaxReqsInFlight; - private boolean closed = false; + private AtomicBoolean closed = new AtomicBoolean(false); + private Set readableLocations = ConcurrentHashMap.newKeySet(); + private Set failedLocations = ConcurrentHashMap.newKeySet(); WorkerPartitionReader( CelebornConf conf, @@ -58,48 +62,88 @@ public class WorkerPartitionReader implements PartitionReader { throws IOException { fetchMaxReqsInFlight = conf.fetchMaxReqsInFlight(); results = new LinkedBlockingQueue<>(); + readableLocations.add(location); + if (location.getPeer() != null) { + readableLocations.add(location.getPeer()); + } // only add the buffer to results queue if this reader is not closed. ChunkReceivedCallback callback = new ChunkReceivedCallback() { @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) { // only add the buffer to results queue if this reader is not closed. - synchronized (this) { - ByteBuf buf = ((NettyManagedBuffer) buffer).getBuf(); - if (!closed) { - buf.retain(); - results.add(buf); - } + ByteBuf buf = ((NettyManagedBuffer) buffer).getBuf(); + if (!closed.get() && !failedLocations.contains(location)) { + buf.retain(); + results.add(new ChunkData(buf, location)); } } @Override - public void onFailure(int chunkIndex, Throwable e) { - String errorMsg = "Fetch chunk " + chunkIndex + " failed."; - logger.error(errorMsg, e); - exception.set(new IOException(errorMsg, e)); + public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) { + readableLocations.remove(location); + if (readableLocations.isEmpty()) { + String errorMsg = "Fetch chunk " + chunkIndex + " failed."; + logger.error(errorMsg, e); + exception.set(new IOException(errorMsg, e)); + } else { + try { + synchronized (WorkerPartitionReader.this) { + if (!failedLocations.contains(location)) { + failedLocations.add(location); + client = + new ChunkClient( + conf, + shuffleKey, + location.getPeer(), + this, + clientFactory, + startMapIndex, + endMapIndex); + currentChunkIndex = 0; + returnedChunks = 0; + numChunks = client.openChunks(); + } + } + } catch (IOException e1) { + logger.error(e1.getMessage(), e1); + exception.set(new IOException(e1.getMessage(), e1)); + } + } } }; client = - new RetryingChunkClient( + new ChunkClient( conf, shuffleKey, location, callback, clientFactory, startMapIndex, endMapIndex); numChunks = client.openChunks(); } - public boolean hasNext() { + public synchronized boolean hasNext() { return returnedChunks < numChunks; } public ByteBuf next() throws IOException { checkException(); - if (chunkIndex < numChunks) { - fetchChunks(); + synchronized (this) { + if (currentChunkIndex < numChunks) { + fetchChunks(); + } } ByteBuf chunk = null; try { while (chunk == null) { checkException(); - chunk = results.poll(500, TimeUnit.MILLISECONDS); + ChunkData chunkData = results.poll(500, TimeUnit.MILLISECONDS); + if (chunkData != null) { + synchronized (this) { + if (failedLocations.contains(chunkData.location)) { + chunkData.release(); + } else { + chunk = chunkData.buf; + returnedChunks++; + } + } + } } } catch (InterruptedException e) { Thread.currentThread().interrupt(); @@ -107,26 +151,24 @@ public class WorkerPartitionReader implements PartitionReader { exception.set(ioe); throw ioe; } - returnedChunks++; return chunk; } public void close() { - synchronized (this) { - closed = true; - } + closed.set(true); if (results.size() > 0) { - results.forEach(ReferenceCounted::release); + results.forEach(ChunkData::release); } results.clear(); } private void fetchChunks() { - final int inFlight = chunkIndex - returnedChunks; + final int inFlight = currentChunkIndex - returnedChunks; if (inFlight < fetchMaxReqsInFlight) { - final int toFetch = Math.min(fetchMaxReqsInFlight - inFlight + 1, numChunks - chunkIndex); + final int toFetch = + Math.min(fetchMaxReqsInFlight - inFlight + 1, numChunks - currentChunkIndex); for (int i = 0; i < toFetch; i++) { - client.fetchChunk(chunkIndex++); + client.fetchChunk(currentChunkIndex++); } } } @@ -137,4 +179,18 @@ public class WorkerPartitionReader implements PartitionReader { throw e; } } + + private static class ChunkData { + ByteBuf buf; + PartitionLocation location; + + public ChunkData(ByteBuf buf, PartitionLocation location) { + this.buf = buf; + this.location = location; + } + + public void release() { + buf.release(); + } + } } diff --git a/client/src/test/java/org/apache/celeborn/client/read/RetryingChunkClientSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/read/RetryingChunkClientSuiteJ.java deleted file mode 100644 index 0b85349bf..000000000 --- a/client/src/test/java/org/apache/celeborn/client/read/RetryingChunkClientSuiteJ.java +++ /dev/null @@ -1,423 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.celeborn.client.read; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.anyInt; -import static org.mockito.Mockito.anyObject; -import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.timeout; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Random; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; - -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Sets; -import io.netty.channel.Channel; -import org.junit.Test; -import org.mockito.stubbing.Answer; - -import org.apache.celeborn.common.CelebornConf; -import org.apache.celeborn.common.network.buffer.ManagedBuffer; -import org.apache.celeborn.common.network.buffer.NioManagedBuffer; -import org.apache.celeborn.common.network.client.ChunkReceivedCallback; -import org.apache.celeborn.common.network.client.TransportClient; -import org.apache.celeborn.common.network.client.TransportClientFactory; -import org.apache.celeborn.common.network.client.TransportResponseHandler; -import org.apache.celeborn.common.network.protocol.StreamHandle; -import org.apache.celeborn.common.protocol.PartitionLocation; -import org.apache.celeborn.common.util.ThreadUtils; - -public class RetryingChunkClientSuiteJ { - - private static final int MASTER_RPC_PORT = 1234; - private static final int MASTER_PUSH_PORT = 1235; - private static final int MASTER_FETCH_PORT = 1236; - private static final int MASTER_REPLICATE_PORT = 1237; - private static final int SLAVE_RPC_PORT = 4321; - private static final int SLAVE_PUSH_PORT = 4322; - private static final int SLAVE_FETCH_PORT = 4323; - private static final int SLAVE_REPLICATE_PORT = 4324; - private static final PartitionLocation masterLocation = - new PartitionLocation( - 0, - 1, - "localhost", - MASTER_RPC_PORT, - MASTER_PUSH_PORT, - MASTER_FETCH_PORT, - MASTER_REPLICATE_PORT, - PartitionLocation.Mode.MASTER); - private static final PartitionLocation slaveLocation = - new PartitionLocation( - 0, - 1, - "localhost", - SLAVE_RPC_PORT, - SLAVE_PUSH_PORT, - SLAVE_FETCH_PORT, - SLAVE_REPLICATE_PORT, - PartitionLocation.Mode.SLAVE); - - static { - masterLocation.setPeer(slaveLocation); - slaveLocation.setPeer(masterLocation); - } - - ManagedBuffer chunk0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); - ManagedBuffer chunk1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); - ManagedBuffer chunk2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); - - @Test - public void testNoFailures() throws IOException, InterruptedException { - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - Map> interactions = - ImmutableMap.>builder() - .put(0, Arrays.asList(chunk0)) - .put(1, Arrays.asList(chunk1)) - .put(2, Arrays.asList(chunk2)) - .build(); - - RetryingChunkClient client = performInteractions(interactions, callback); - - verify(callback, timeout(5000)).onSuccess(eq(0), eq(chunk0)); - verify(callback, timeout(5000)).onSuccess(eq(1), eq(chunk1)); - verify(callback, timeout(5000)).onSuccess(eq(2), eq(chunk2)); - verifyNoMoreInteractions(callback); - - assertEquals(0, client.getNumTries()); - assertEquals(masterLocation, client.getCurrentReplica().getLocation()); - } - - @Test - public void testUnrecoverableFailure() throws IOException, InterruptedException { - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - Map> interactions = - ImmutableMap.>builder() - .put(0, Arrays.asList(new RuntimeException("Ouch!"))) - .put(1, Arrays.asList(chunk1)) - .put(2, Arrays.asList(chunk2)) - .build(); - RetryingChunkClient client = performInteractions(interactions, callback); - - verify(callback, timeout(5000)).onFailure(eq(0), any()); - verify(callback, timeout(5000)).onSuccess(eq(1), eq(chunk1)); - verify(callback, timeout(5000)).onSuccess(eq(2), eq(chunk2)); - verifyNoMoreInteractions(callback); - - assertEquals(0, client.getNumTries()); - assertEquals(masterLocation, client.getCurrentReplica().getLocation()); - } - - @Test - public void testDuplicateSuccess() throws IOException, InterruptedException { - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - - Map> interactions = - ImmutableMap.>builder().put(0, Arrays.asList(chunk0, chunk1)).build(); - RetryingChunkClient client = performInteractions(interactions, callback); - verify(callback, timeout(5000)).onSuccess(eq(0), eq(chunk0)); - verifyNoMoreInteractions(callback); - - assertEquals(0, client.getNumTries()); - assertEquals(masterLocation, client.getCurrentReplica().getLocation()); - } - - @Test - public void testSingleIOException() throws IOException, InterruptedException { - Map result = new ConcurrentHashMap<>(); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - Semaphore signal = new Semaphore(3); - signal.acquire(3); - - Answer answer = - invocation -> { - synchronized (signal) { - int chunkIndex = (Integer) invocation.getArguments()[0]; - assertFalse(result.containsKey(chunkIndex)); - Object value = invocation.getArguments()[1]; - result.put(chunkIndex, value); - signal.release(); - } - return null; - }; - doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject()); - doAnswer(answer).when(callback).onFailure(anyInt(), anyObject()); - - Map> interactions = - ImmutableMap.>builder() - .put(0, Arrays.asList(new IOException(), chunk0)) - .put(1, Arrays.asList(chunk1)) - .put(2, Arrays.asList(chunk2)) - .build(); - RetryingChunkClient client = performInteractions(interactions, callback); - - while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ; - assertEquals(1, client.getNumTries()); - assertEquals(slaveLocation, client.getCurrentReplica().getLocation()); - assertEquals(chunk0, result.get(0)); - assertEquals(chunk1, result.get(1)); - assertEquals(chunk2, result.get(2)); - } - - @Test - public void testTwoIOExceptions() throws IOException, InterruptedException { - Map result = new ConcurrentHashMap<>(); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - Semaphore signal = new Semaphore(3); - signal.acquire(3); - - Answer answer = - invocation -> { - synchronized (signal) { - int chunkIndex = (Integer) invocation.getArguments()[0]; - assertFalse(result.containsKey(chunkIndex)); - Object value = invocation.getArguments()[1]; - result.put(chunkIndex, value); - signal.release(); - } - return null; - }; - doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject()); - doAnswer(answer).when(callback).onFailure(anyInt(), anyObject()); - - Map> interactions = - ImmutableMap.>builder() - .put(0, Arrays.asList(new IOException("first ioexception"), chunk0)) - .put(1, Arrays.asList(new IOException("second ioexception"), chunk1)) - .put(2, Arrays.asList(chunk2)) - .build(); - RetryingChunkClient client = performInteractions(interactions, callback); - - while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ; - assertEquals(1, client.getNumTries()); - assertEquals(slaveLocation, client.getCurrentReplica().getLocation()); - assertEquals(chunk0, result.get(0)); - assertEquals(chunk1, result.get(1)); - assertEquals(chunk2, result.get(2)); - } - - @Test - public void testThreeIOExceptions() throws IOException, InterruptedException { - Map result = new ConcurrentHashMap<>(); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - Semaphore signal = new Semaphore(3); - signal.acquire(3); - - Answer answer = - invocation -> { - synchronized (signal) { - int chunkIndex = (Integer) invocation.getArguments()[0]; - assertFalse(result.containsKey(chunkIndex)); - Object value = invocation.getArguments()[1]; - result.put(chunkIndex, value); - signal.release(); - } - return null; - }; - doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject()); - doAnswer(answer).when(callback).onFailure(anyInt(), anyObject()); - - Map> interactions = - ImmutableMap.>builder() - .put(0, Arrays.asList(new IOException("first ioexception"), chunk0)) - .put(1, Arrays.asList(new IOException("second ioexception"), chunk1)) - .put(2, Arrays.asList(new IOException("third ioexception"), chunk2)) - .build(); - - RetryingChunkClient client = performInteractions(interactions, callback); - while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ; - assertEquals(1, client.getNumTries()); - assertEquals(slaveLocation, client.getCurrentReplica().getLocation()); - assertEquals(chunk0, result.get(0)); - assertEquals(chunk1, result.get(1)); - assertEquals(chunk2, result.get(2)); - } - - @Test - public void testFailedWithIOExceptions() throws IOException, InterruptedException { - Map result = new ConcurrentHashMap<>(); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - Semaphore signal = new Semaphore(3); - signal.acquire(3); - - Answer answer = - invocation -> { - synchronized (signal) { - int chunkIndex = (Integer) invocation.getArguments()[0]; - assertFalse(result.containsKey(chunkIndex)); - Object value = invocation.getArguments()[1]; - result.put(chunkIndex, value); - signal.release(); - } - return null; - }; - doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject()); - doAnswer(answer).when(callback).onFailure(anyInt(), anyObject()); - - IOException ioe = new IOException("failed exception"); - Map> interactions = - ImmutableMap.>builder() - .put(0, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk0)) - .put(1, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk1)) - .put(2, Arrays.asList(ioe, ioe, ioe, ioe, ioe, chunk2)) - .build(); - - RetryingChunkClient client = performInteractions(interactions, callback); - while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ; - // Note: this may exceeds the max retries we want, but it doesn't master. - assertEquals(4, client.getNumTries()); - assertEquals(masterLocation, client.getCurrentReplica().getLocation()); - assertEquals(ioe, result.get(0)); - assertEquals(ioe, result.get(1)); - assertEquals(ioe, result.get(2)); - } - - @Test - public void testRetryAndUnrecoverable() throws IOException, InterruptedException { - Map result = new ConcurrentHashMap<>(); - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - Semaphore signal = new Semaphore(3); - signal.acquire(3); - - Answer answer = - invocation -> { - synchronized (signal) { - int chunkIndex = (Integer) invocation.getArguments()[0]; - assertFalse(result.containsKey(chunkIndex)); - Object value = invocation.getArguments()[1]; - result.put(chunkIndex, value); - signal.release(); - } - return null; - }; - doAnswer(answer).when(callback).onSuccess(anyInt(), anyObject()); - doAnswer(answer).when(callback).onFailure(anyInt(), anyObject()); - - Exception re = new RuntimeException("failed exception"); - Map> interactions = - ImmutableMap.>builder() - .put(0, Arrays.asList(new IOException("first ioexception"), re, chunk0)) - .put(1, Arrays.asList(chunk1)) - .put(2, Arrays.asList(new IOException("second ioexception"), chunk2)) - .build(); - - performInteractions(interactions, callback); - while (!signal.tryAcquire(3, 500, TimeUnit.MILLISECONDS)) ; - assertEquals(re, result.get(0)); - assertEquals(chunk1, result.get(1)); - assertEquals(chunk2, result.get(2)); - } - - private static RetryingChunkClient performInteractions( - Map> interactions, ChunkReceivedCallback callback) - throws IOException, InterruptedException { - CelebornConf conf = new CelebornConf(); - conf.set("celeborn.data.io.maxRetries", "1"); - conf.set("celeborn.data.io.retryWait", "0"); - - // Contains all chunk ids that are referenced across all interactions. - LinkedHashSet chunkIds = Sets.newLinkedHashSet(interactions.keySet()); - - final TransportClient client = new DummyTransportClient(chunkIds.size(), interactions); - final TransportClientFactory clientFactory = mock(TransportClientFactory.class); - doAnswer(invocation -> client).when(clientFactory).createClient(anyString(), anyInt()); - - RetryingChunkClient retryingChunkClient = - new RetryingChunkClient(conf, "test", masterLocation, callback, clientFactory); - chunkIds.stream().sorted().forEach(retryingChunkClient::fetchChunk); - return retryingChunkClient; - } - - private static class DummyTransportClient extends TransportClient { - - private static final Channel channel = mock(Channel.class); - private static final TransportResponseHandler handler = mock(TransportResponseHandler.class); - - private final long streamId = new Random().nextInt(Integer.MAX_VALUE) * 1000L; - private final int numChunks; - private final Map> interactions; - private final Map chunkIdToInterActionIndex; - - private final ScheduledExecutorService schedule = - ThreadUtils.newDaemonThreadPoolScheduledExecutor("test-fetch-chunk", 3); - - DummyTransportClient(int numChunks, Map> interactions) { - super(channel, handler); - this.numChunks = numChunks; - this.interactions = interactions; - this.chunkIdToInterActionIndex = new ConcurrentHashMap<>(); - interactions.keySet().forEach((chunkId) -> chunkIdToInterActionIndex.putIfAbsent(chunkId, 0)); - } - - @Override - public void fetchChunk(long streamId, int chunkId, ChunkReceivedCallback callback) { - schedule.schedule( - () -> { - Object action; - List interaction = interactions.get(chunkId); - synchronized (chunkIdToInterActionIndex) { - int index = chunkIdToInterActionIndex.get(chunkId); - assertTrue(index < interaction.size()); - action = interaction.get(index); - chunkIdToInterActionIndex.put(chunkId, index + 1); - } - - if (action instanceof ManagedBuffer) { - callback.onSuccess(chunkId, (ManagedBuffer) action); - } else if (action instanceof Exception) { - callback.onFailure(chunkId, (Exception) action); - } else { - fail("Can only handle ManagedBuffers and Exceptions, got " + action); - } - }, - 500, - TimeUnit.MILLISECONDS); - } - - @Override - public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) { - StreamHandle handle = new StreamHandle(streamId, numChunks); - return handle.toByteBuffer(); - } - - @Override - public void close() { - super.close(); - schedule.shutdownNow(); - } - } -} diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/ChunkReceivedCallback.java b/common/src/main/java/org/apache/celeborn/common/network/client/ChunkReceivedCallback.java index 267a63df2..76017ad88 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/client/ChunkReceivedCallback.java +++ b/common/src/main/java/org/apache/celeborn/common/network/client/ChunkReceivedCallback.java @@ -18,6 +18,7 @@ package org.apache.celeborn.common.network.client; import org.apache.celeborn.common.network.buffer.ManagedBuffer; +import org.apache.celeborn.common.protocol.PartitionLocation; /** * Callback for the result of a single chunk result. For a single stream, the callbacks are @@ -34,7 +35,7 @@ public interface ChunkReceivedCallback { * this call returns. You must therefore either retain() the buffer or copy its contents before * returning. */ - void onSuccess(int chunkIndex, ManagedBuffer buffer); + void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location); /** * Called upon failure to fetch a particular chunk. Note that this may actually be called due to @@ -43,5 +44,5 @@ public interface ChunkReceivedCallback { *

After receiving a failure, the stream may or may not be valid. The client should not assume * that the server's side of the stream has been closed. */ - void onFailure(int chunkIndex, Throwable e); + void onFailure(int chunkIndex, PartitionLocation location, Throwable e); } diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java index e5e82db0b..e778449a1 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java +++ b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java @@ -126,7 +126,7 @@ public class TransportClient implements Closeable { @Override protected void handleFailure(String errorMsg, Throwable cause) { handler.removeFetchRequest(streamChunkSlice); - callback.onFailure(chunkIndex, new IOException(errorMsg, cause)); + callback.onFailure(chunkIndex, null, new IOException(errorMsg, cause)); } }; handler.addFetchRequest(streamChunkSlice, callback); diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/TransportResponseHandler.java b/common/src/main/java/org/apache/celeborn/common/network/client/TransportResponseHandler.java index c55fbb3bd..1a9ab600c 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/client/TransportResponseHandler.java +++ b/common/src/main/java/org/apache/celeborn/common/network/client/TransportResponseHandler.java @@ -86,7 +86,7 @@ public class TransportResponseHandler extends MessageHandler { private void failOutstandingRequests(Throwable cause) { for (Map.Entry entry : outstandingFetches.entrySet()) { try { - entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + entry.getValue().onFailure(entry.getKey().chunkIndex, null, cause); } catch (Exception e) { logger.warn("ChunkReceivedCallback.onFailure throws exception", e); } @@ -144,7 +144,7 @@ public class TransportResponseHandler extends MessageHandler { resp.body().release(); } else { outstandingFetches.remove(resp.streamChunkSlice); - listener.onSuccess(resp.streamChunkSlice.chunkIndex, resp.body()); + listener.onSuccess(resp.streamChunkSlice.chunkIndex, resp.body(), null); resp.body().release(); } } else if (message instanceof ChunkFetchFailure) { @@ -161,6 +161,7 @@ public class TransportResponseHandler extends MessageHandler { logger.warn("Receive ChunkFetchFailure, errorMsg {}", resp.errorString); listener.onFailure( resp.streamChunkSlice.chunkIndex, + null, new ChunkFetchFailureException( "Failure while fetching " + resp.streamChunkSlice + ": " + resp.errorString)); } diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index 7b31ab9b6..82b844115 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -662,6 +662,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def rpcCacheSize: Int = get(RPC_CACHE_SIZE) def rpcCacheConcurrencyLevel: Int = get(RPC_CACHE_CONCURRENCY_LEVEL) def rpcCacheExpireTime: Long = get(RPC_CACHE_EXPIRE_TIME) + def testFetchFailedChunkIndex: Int = get(TEST_FETCH_FAILED_CHUNK_INDEX) // ////////////////////////////////////////////////////// // Graceful Shutdown & Recover // @@ -2539,4 +2540,13 @@ object CelebornConf extends Logging { .doc("The time before a cache item is removed.") .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("15s") + + val TEST_FETCH_FAILED_CHUNK_INDEX: ConfigEntry[Int] = + buildConf("celeborn.test.client.fetchFailedChuckIndex") + .categories("client") + .version("0.2.0") + .internal + .doc("The chunk index to trigger fetch chunk failure for testing purpose only.") + .intConf + .createWithDefault(0) } diff --git a/common/src/test/java/org/apache/celeborn/common/network/ChunkFetchIntegrationSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/ChunkFetchIntegrationSuiteJ.java index dd8c90d5e..ec9722a8b 100644 --- a/common/src/test/java/org/apache/celeborn/common/network/ChunkFetchIntegrationSuiteJ.java +++ b/common/src/test/java/org/apache/celeborn/common/network/ChunkFetchIntegrationSuiteJ.java @@ -47,6 +47,7 @@ import org.apache.celeborn.common.network.server.BaseMessageHandler; import org.apache.celeborn.common.network.server.StreamManager; import org.apache.celeborn.common.network.server.TransportServer; import org.apache.celeborn.common.network.util.TransportConf; +import org.apache.celeborn.common.protocol.PartitionLocation; public class ChunkFetchIntegrationSuiteJ { static final long STREAM_ID = 1; @@ -153,7 +154,7 @@ public class ChunkFetchIntegrationSuiteJ { ChunkReceivedCallback callback = new ChunkReceivedCallback() { @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) { buffer.retain(); res.successChunks.add(chunkIndex); res.buffers.add(buffer); @@ -161,7 +162,7 @@ public class ChunkFetchIntegrationSuiteJ { } @Override - public void onFailure(int chunkIndex, Throwable e) { + public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) { res.failedChunks.add(chunkIndex); sem.release(); } diff --git a/common/src/test/java/org/apache/celeborn/common/network/RequestTimeoutIntegrationSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/RequestTimeoutIntegrationSuiteJ.java index 5933a5afb..2d66bbab0 100644 --- a/common/src/test/java/org/apache/celeborn/common/network/RequestTimeoutIntegrationSuiteJ.java +++ b/common/src/test/java/org/apache/celeborn/common/network/RequestTimeoutIntegrationSuiteJ.java @@ -42,6 +42,7 @@ import org.apache.celeborn.common.network.server.BaseMessageHandler; import org.apache.celeborn.common.network.server.StreamManager; import org.apache.celeborn.common.network.server.TransportServer; import org.apache.celeborn.common.network.util.TransportConf; +import org.apache.celeborn.common.protocol.PartitionLocation; /** * Suite which ensures that requests that go without a response for the network timeout period are @@ -267,7 +268,7 @@ public class RequestTimeoutIntegrationSuiteJ { } @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) { try { successLength = buffer.nioByteBuffer().remaining(); } catch (IOException e) { @@ -278,7 +279,7 @@ public class RequestTimeoutIntegrationSuiteJ { } @Override - public void onFailure(int chunkIndex, Throwable e) { + public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) { failure = e; latch.countDown(); } diff --git a/common/src/test/java/org/apache/celeborn/common/network/TransportResponseHandlerSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/TransportResponseHandlerSuiteJ.java index 16cacbd15..661c7e35d 100644 --- a/common/src/test/java/org/apache/celeborn/common/network/TransportResponseHandlerSuiteJ.java +++ b/common/src/test/java/org/apache/celeborn/common/network/TransportResponseHandlerSuiteJ.java @@ -42,7 +42,7 @@ public class TransportResponseHandlerSuiteJ { assertEquals(1, handler.numOutstandingRequests()); handler.handle(new ChunkFetchSuccess(streamChunkSlice, new TestManagedBuffer(123))); - verify(callback, times(1)).onSuccess(eq(0), any()); + verify(callback, times(1)).onSuccess(eq(0), any(), any()); assertEquals(0, handler.numOutstandingRequests()); } @@ -55,7 +55,7 @@ public class TransportResponseHandlerSuiteJ { assertEquals(1, handler.numOutstandingRequests()); handler.handle(new ChunkFetchFailure(streamChunkSlice, "some error msg")); - verify(callback, times(1)).onFailure(eq(0), any()); + verify(callback, times(1)).onFailure(eq(0), any(), any()); assertEquals(0, handler.numOutstandingRequests()); } @@ -72,9 +72,9 @@ public class TransportResponseHandlerSuiteJ { handler.exceptionCaught(new Exception("duh duh duhhhh")); // should fail both b2 and b3 - verify(callback, times(1)).onSuccess(eq(0), any()); - verify(callback, times(1)).onFailure(eq(1), any()); - verify(callback, times(1)).onFailure(eq(2), any()); + verify(callback, times(1)).onSuccess(eq(0), any(), any()); + verify(callback, times(1)).onFailure(eq(1), any(), any()); + verify(callback, times(1)).onFailure(eq(2), any(), any()); assertEquals(0, handler.numOutstandingRequests()); } diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java index 9f7fa994d..fc97f704d 100644 --- a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java +++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/FileWriterSuiteJ.java @@ -71,6 +71,7 @@ import org.apache.celeborn.common.network.server.MemoryTracker; import org.apache.celeborn.common.network.server.TransportServer; import org.apache.celeborn.common.network.util.JavaUtils; import org.apache.celeborn.common.network.util.TransportConf; +import org.apache.celeborn.common.protocol.PartitionLocation; import org.apache.celeborn.common.protocol.PartitionSplitMode; import org.apache.celeborn.common.protocol.PartitionType; import org.apache.celeborn.common.protocol.StorageInfo; @@ -209,7 +210,7 @@ public class FileWriterSuiteJ { ChunkReceivedCallback callback = new ChunkReceivedCallback() { @Override - public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + public void onSuccess(int chunkIndex, ManagedBuffer buffer, PartitionLocation location) { buffer.retain(); res.successChunks.add(chunkIndex); res.buffers.add(buffer); @@ -217,7 +218,7 @@ public class FileWriterSuiteJ { } @Override - public void onFailure(int chunkIndex, Throwable e) { + public void onFailure(int chunkIndex, PartitionLocation location, Throwable e) { res.failedChunks.add(chunkIndex); sem.release(); }