From 9dc1bc2b1c5198ab6f5efeddeffb341b33e50e8e Mon Sep 17 00:00:00 2001 From: zhongqiangchen Date: Thu, 2 Mar 2023 18:50:38 +0800 Subject: [PATCH] [CELEBORN-367] [FLINK] Move pushdata functions used by mappartition from ShuffleClientImpl to FlinkShuffleClientImpl (#1295) --- client-flink/common/pom.xml | 5 + .../readclient/FlinkShuffleClientImpl.java | 417 ++++++++++++++++++ .../flink/FlinkShuffleClientImplSuiteJ.java | 81 ++-- .../plugin/flink/RemoteShuffleOutputGate.java | 9 +- .../flink/RemoteShuffleOutputGateSuiteJ.java | 4 +- .../apache/celeborn/client/ShuffleClient.java | 40 -- .../celeborn/client/ShuffleClientImpl.java | 413 +---------------- .../celeborn/client/DummyShuffleClient.java | 44 -- 8 files changed, 473 insertions(+), 540 deletions(-) rename client/src/test/java/org/apache/celeborn/client/ShuffleClientImplSuiteJ.java => client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java (62%) diff --git a/client-flink/common/pom.xml b/client-flink/common/pom.xml index ec585e1dd..9cf960a2a 100644 --- a/client-flink/common/pom.xml +++ b/client-flink/common/pom.xml @@ -44,5 +44,10 @@ flink-runtime provided + + org.mockito + mockito-core + test + diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java index 3de2bb1f7..36e52937b 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java @@ -18,18 +18,41 @@ package org.apache.celeborn.plugin.flink.readclient; import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import scala.reflect.ClassTag$; + +import com.google.common.util.concurrent.Uninterruptibles; +import io.netty.buffer.ByteBuf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.celeborn.client.ShuffleClientImpl; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.exception.CelebornIOException; import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.network.TransportContext; +import org.apache.celeborn.common.network.buffer.NettyManagedBuffer; +import org.apache.celeborn.common.network.client.RpcResponseCallback; +import org.apache.celeborn.common.network.client.TransportClient; +import org.apache.celeborn.common.network.client.TransportClientFactory; +import org.apache.celeborn.common.network.protocol.PushData; +import org.apache.celeborn.common.network.protocol.PushDataHandShake; +import org.apache.celeborn.common.network.protocol.RegionFinish; +import org.apache.celeborn.common.network.protocol.RegionStart; import org.apache.celeborn.common.network.util.TransportConf; import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.common.protocol.PbChangeLocationResponse; import org.apache.celeborn.common.protocol.TransportModuleConstants; +import org.apache.celeborn.common.protocol.message.ControlMessages; +import org.apache.celeborn.common.protocol.message.StatusCode; +import org.apache.celeborn.common.util.PbSerDeUtils; import org.apache.celeborn.common.util.Utils; +import org.apache.celeborn.common.write.PushState; import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory; import org.apache.celeborn.plugin.flink.network.ReadClientHandler; @@ -39,6 +62,7 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl { private static volatile boolean initialized = false; private FlinkTransportClientFactory flinkTransportClientFactory; private ReadClientHandler readClientHandler = new ReadClientHandler(); + private ConcurrentHashMap currentClient = new ConcurrentHashMap<>(); public static FlinkShuffleClientImpl get( String driverHost, int port, CelebornConf conf, UserIdentifier userIdentifier) { @@ -137,4 +161,397 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl { public ReadClientHandler getReadClientHandler() { return readClientHandler; } + + public int pushDataToLocation( + String applicationId, + int shuffleId, + int mapId, + int attemptId, + int partitionId, + ByteBuf data, + PartitionLocation location, + Runnable closeCallBack) + throws IOException { + // mapKey + final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); + final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); + // return if shuffle stage already ended + if (mapperEnded(shuffleId, mapId, attemptId)) { + logger.info( + "Push data byteBuf to location {} ignored because mapper already ended for shuffle {} map {} attempt {}.", + location.hostAndPushPort(), + shuffleId, + mapId, + attemptId); + PushState pushState = pushStates.get(mapKey); + if (pushState != null) { + pushState.cleanup(); + } + return 0; + } + + PushState pushState = getPushState(mapKey); + + // increment batchId + final int nextBatchId = pushState.nextBatchId(); + int totalLength = data.readableBytes(); + data.markWriterIndex(); + data.writerIndex(0); + data.writeInt(partitionId); + data.writeInt(attemptId); + data.writeInt(nextBatchId); + data.writeInt(totalLength - BATCH_HEADER_SIZE); + data.resetWriterIndex(); + logger.debug( + "Do push data byteBuf size {} for app {} shuffle {} map {} attempt {} reduce {} batch {}.", + totalLength, + applicationId, + shuffleId, + mapId, + attemptId, + partitionId, + nextBatchId); + // check limit + limitMaxInFlight(mapKey, pushState, location.hostAndPushPort()); + + // add inFlight requests + pushState.addBatch(nextBatchId, location.hostAndPushPort()); + + // build PushData request + NettyManagedBuffer buffer = new NettyManagedBuffer(data); + PushData pushData = new PushData(MASTER_MODE, shuffleKey, location.getUniqueId(), buffer); + + // build callback + RpcResponseCallback callback = + new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + pushState.removeBatch(nextBatchId, location.hostAndPushPort()); + if (response.remaining() > 0) { + byte reason = response.get(); + if (reason == StatusCode.STAGE_ENDED.getValue()) { + mapperEndMap + .computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()) + .add(mapKey); + } + } + logger.debug( + "Push data byteBuf to {} success for shuffle {} map {} attemptId {} batch {}.", + location.hostAndPushPort(), + shuffleId, + mapId, + attemptId, + nextBatchId); + } + + @Override + public void onFailure(Throwable e) { + pushState.removeBatch(nextBatchId, location.hostAndPushPort()); + if (pushState.exception.get() != null) { + return; + } + if (!mapperEnded(shuffleId, mapId, attemptId)) { + String errorMsg = + String.format( + "Push data byteBuf to %s failed for shuffle %d map %d attempt %d batch %d.", + location.hostAndPushPort(), shuffleId, mapId, attemptId, nextBatchId); + pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e)); + } else { + logger.warn( + "Push data to {} failed but mapper already ended for shuffle {} map {} attempt {} batch {}.", + location.hostAndPushPort(), + shuffleId, + mapId, + attemptId, + nextBatchId); + } + } + }; + // do push data + try { + TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); + client.pushData(pushData, pushDataTimeout, callback, closeCallBack); + } catch (Exception e) { + logger.error( + "Exception raised while pushing data byteBuf for shuffle {} map {} attempt {} partitionId {} batch {} location {}.", + shuffleId, + mapId, + attemptId, + partitionId, + nextBatchId, + location, + e); + callback.onFailure( + new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_MASTER, e)); + } + return totalLength; + } + + private TransportClient createClientWaitingInFlightRequest( + PartitionLocation location, String mapKey, PushState pushState) + throws IOException, InterruptedException { + TransportClient client = + dataClientFactory.createClient( + location.getHost(), location.getPushPort(), location.getId()); + if (currentClient.get(mapKey) != client) { + // makesure that messages have been sent by old client, in order to keep receiving data + // orderly + if (currentClient.get(mapKey) != null) { + limitZeroInFlight(mapKey, pushState); + } + currentClient.put(mapKey, client); + } + return currentClient.get(mapKey); + } + + public void pushDataHandShake( + String applicationId, + int shuffleId, + int mapId, + int attemptId, + int numPartitions, + int bufferSize, + PartitionLocation location) + throws IOException { + final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); + final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf)); + sendMessageInternal( + shuffleId, + mapId, + attemptId, + location, + pushState, + () -> { + String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); + logger.info( + "PushDataHandShake shuffleKey {} attemptId {} locationId {}", + shuffleKey, + attemptId, + location.getUniqueId()); + logger.debug("PushDataHandShake location {}", location.toString()); + TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); + PushDataHandShake handShake = + new PushDataHandShake( + MASTER_MODE, + shuffleKey, + location.getUniqueId(), + attemptId, + numPartitions, + bufferSize); + client.sendRpcSync(handShake.toByteBuffer(), conf.pushDataTimeoutMs()); + return null; + }); + } + + public Optional regionStart( + String applicationId, + int shuffleId, + int mapId, + int attemptId, + PartitionLocation location, + int currentRegionIdx, + boolean isBroadcast) + throws IOException { + final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); + final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf)); + return sendMessageInternal( + shuffleId, + mapId, + attemptId, + location, + pushState, + () -> { + String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); + logger.info( + "RegionStart for shuffle {} regionId {} attemptId {} locationId {}.", + shuffleId, + currentRegionIdx, + attemptId, + location.getUniqueId()); + logger.debug("RegionStart for location {}.", location.toString()); + TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); + RegionStart regionStart = + new RegionStart( + MASTER_MODE, + shuffleKey, + location.getUniqueId(), + attemptId, + currentRegionIdx, + isBroadcast); + ByteBuffer regionStartResponse = + client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataTimeoutMs()); + if (regionStartResponse.hasRemaining() + && regionStartResponse.get() == StatusCode.HARD_SPLIT.getValue()) { + // if split then revive + PbChangeLocationResponse response = + driverRssMetaService.askSync( + ControlMessages.Revive$.MODULE$.apply( + applicationId, + shuffleId, + mapId, + attemptId, + location.getId(), + location.getEpoch(), + location, + StatusCode.HARD_SPLIT), + conf.requestPartitionLocationRpcAskTimeout(), + ClassTag$.MODULE$.apply(PbChangeLocationResponse.class)); + // per partitionKey only serve single PartitionLocation in Client Cache. + StatusCode respStatus = Utils.toStatusCode(response.getStatus()); + if (StatusCode.SUCCESS.equals(respStatus)) { + return Optional.of(PbSerDeUtils.fromPbPartitionLocation(response.getLocation())); + } else if (StatusCode.MAP_ENDED.equals(respStatus)) { + mapperEndMap + .computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()) + .add(mapKey); + return Optional.empty(); + } else { + // throw exception + logger.error( + "Exception raised while reviving for shuffle {} map {} attemptId {} partition {} epoch {}.", + shuffleId, + mapId, + attemptId, + location.getId(), + location.getEpoch()); + throw new CelebornIOException("RegionStart revive failed"); + } + } + return Optional.empty(); + }); + } + + public void regionFinish( + String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location) + throws IOException { + final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); + final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf)); + sendMessageInternal( + shuffleId, + mapId, + attemptId, + location, + pushState, + () -> { + final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); + logger.info( + "RegionFinish for shuffle {} map {} attemptId {} locationId {}.", + shuffleId, + mapId, + attemptId, + location.getUniqueId()); + logger.debug("RegionFinish for location {}.", location.toString()); + TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); + RegionFinish regionFinish = + new RegionFinish(MASTER_MODE, shuffleKey, location.getUniqueId(), attemptId); + client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataTimeoutMs()); + return null; + }); + } + + private R sendMessageInternal( + int shuffleId, + int mapId, + int attemptId, + PartitionLocation location, + PushState pushState, + ThrowingExceptionSupplier supplier) + throws IOException { + int batchId = 0; + try { + // mapKey + final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); + // return if shuffle stage already ended + if (mapperEnded(shuffleId, mapId, attemptId)) { + logger.debug( + "Send message to {} ignored because mapper already ended for shuffle {} map {} attempt {}.", + location.hostAndPushPort(), + shuffleId, + mapId, + attemptId); + return null; + } + pushState = getPushState(mapKey); + // force data has been send + limitZeroInFlight(mapKey, pushState); + + // add inFlight requests + batchId = pushState.nextBatchId(); + pushState.addBatch(batchId, location.hostAndPushPort()); + return retrySendMessage(supplier); + } finally { + if (pushState != null) { + pushState.removeBatch(batchId, location.hostAndPushPort()); + } + } + } + + @FunctionalInterface + interface ThrowingExceptionSupplier { + R get() throws E; + } + + private R retrySendMessage(ThrowingExceptionSupplier supplier) + throws IOException { + + int retryTimes = 0; + boolean isSuccess = false; + Exception currentException = null; + R result = null; + while (!Thread.currentThread().isInterrupted() + && !isSuccess + && retryTimes < conf.networkIoMaxRetries(TransportModuleConstants.PUSH_MODULE)) { + logger.debug("RetrySendMessage retry times {}.", retryTimes); + try { + result = supplier.get(); + isSuccess = true; + } catch (Exception e) { + currentException = e; + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + if (shouldRetry(e)) { + retryTimes++; + Uninterruptibles.sleepUninterruptibly( + conf.networkIoRetryWaitMs(TransportModuleConstants.PUSH_MODULE), + TimeUnit.MILLISECONDS); + } else { + break; + } + } + } + if (!isSuccess) { + if (currentException instanceof IOException) { + throw (IOException) currentException; + } else { + throw new CelebornIOException(currentException.getMessage(), currentException); + } + } + return result; + } + + private boolean shouldRetry(Throwable e) { + boolean isIOException = + e instanceof IOException + || e instanceof TimeoutException + || (e.getCause() != null && e.getCause() instanceof TimeoutException) + || (e.getCause() != null && e.getCause() instanceof IOException) + || (e instanceof RuntimeException + && e.getMessage() != null + && e.getMessage().startsWith(IOException.class.getName())); + return isIOException; + } + + @Override + public void cleanup(String applicationId, int shuffleId, int mapId, int attemptId) { + final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); + super.cleanup(applicationId, shuffleId, mapId, attemptId); + if (currentClient != null) { + currentClient.remove(mapKey); + } + } + + public void setDataClientFactory(TransportClientFactory dataClientFactory) { + this.dataClientFactory = dataClientFactory; + } } diff --git a/client/src/test/java/org/apache/celeborn/client/ShuffleClientImplSuiteJ.java b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java similarity index 62% rename from client/src/test/java/org/apache/celeborn/client/ShuffleClientImplSuiteJ.java rename to client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java index d3b227848..414c9a352 100644 --- a/client/src/test/java/org/apache/celeborn/client/ShuffleClientImplSuiteJ.java +++ b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/FlinkShuffleClientImplSuiteJ.java @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.celeborn.client; +package org.apache.celeborn.plugin.flink; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import java.io.IOException; @@ -26,27 +25,48 @@ import java.nio.ByteBuffer; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.mockito.Matchers; +import org.mockito.Mockito; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.network.client.RpcResponseCallback; -import org.apache.celeborn.common.protocol.CompressionCodec; +import org.apache.celeborn.common.network.client.TransportClient; +import org.apache.celeborn.common.network.client.TransportClientFactory; +import org.apache.celeborn.common.protocol.PartitionLocation; import org.apache.celeborn.common.protocol.message.StatusCode; +import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl; -public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ { +public class FlinkShuffleClientImplSuiteJ { static int BufferSize = 64; static byte[] TEST_BUF1 = new byte[BufferSize]; + protected ChannelFuture mockedFuture = mock(ChannelFuture.class); static CelebornConf conf; + static FlinkShuffleClientImpl shuffleClient; + protected static final TransportClientFactory clientFactory = mock(TransportClientFactory.class); + protected final TransportClient client = mock(TransportClient.class); + protected static final PartitionLocation masterLocation = + new PartitionLocation(0, 1, "localhost", 1, 1, 1, 1, PartitionLocation.Mode.MASTER); @Before public void setup() throws IOException, InterruptedException { - conf = setupEnv(CompressionCodec.LZ4); + conf = new CelebornConf(); + shuffleClient = + new FlinkShuffleClientImpl("localhost", 1232, conf, null) { + @Override + public void setupMetaServiceRef(String host, int port) {} + }; + when(clientFactory.createClient(masterLocation.getHost(), masterLocation.getPushPort(), 1)) + .thenAnswer(t -> client); + + shuffleClient.setDataClientFactory(clientFactory); } public ByteBuf createByteBuf() { - for (int i = BATCH_HEADER_SIZE; i < BufferSize; i++) { + for (int i = 16; i < BufferSize; i++) { TEST_BUF1[i] = 1; } ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1); @@ -57,7 +77,7 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ { @Test public void testPushDataByteBufSuccess() throws IOException { ByteBuf byteBuf = createByteBuf(); - when(client.pushData(any(), anyLong(), any())) + Mockito.when(client.pushData(Matchers.any(), Matchers.anyLong(), Matchers.any())) .thenAnswer( t -> { RpcResponseCallback rpcResponseCallback = @@ -68,22 +88,14 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ { }); int pushDataLen = - shuffleClient.pushDataToLocation( - TEST_APPLICATION_ID, - TEST_SHUFFLE_ID, - TEST_ATTEMPT_ID, - TEST_ATTEMPT_ID, - TEST_REDUCRE_ID, - byteBuf, - masterLocation, - () -> {}); + shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {}); Assert.assertEquals(BufferSize, pushDataLen); } @Test public void testPushDataByteBufHardSplit() throws IOException { ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1); - when(client.pushData(any(), anyLong(), any())) + Mockito.when(client.pushData(Matchers.any(), Matchers.anyLong(), Matchers.any())) .thenAnswer( t -> { RpcResponseCallback rpcResponseCallback = @@ -94,21 +106,14 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ { return mockedFuture; }); int pushDataLen = - shuffleClient.pushDataToLocation( - TEST_APPLICATION_ID, - TEST_SHUFFLE_ID, - TEST_ATTEMPT_ID, - TEST_ATTEMPT_ID, - TEST_REDUCRE_ID, - byteBuf, - masterLocation, - () -> {}); + shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {}); } @Test public void testPushDataByteBufFail() throws IOException { ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1); - when(client.pushData(any(), anyLong(), any(), any())) + Mockito.when( + client.pushData(Matchers.any(), Matchers.anyLong(), Matchers.any(), Matchers.any())) .thenAnswer( t -> { RpcResponseCallback rpcResponseCallback = @@ -117,28 +122,12 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ { return mockedFuture; }); // first push just set pushdata.exception - shuffleClient.pushDataToLocation( - TEST_APPLICATION_ID, - TEST_SHUFFLE_ID, - TEST_ATTEMPT_ID, - TEST_ATTEMPT_ID, - TEST_REDUCRE_ID, - byteBuf, - masterLocation, - () -> {}); + shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {}); boolean isFailed = false; // second push will throw exception try { - shuffleClient.pushDataToLocation( - TEST_APPLICATION_ID, - TEST_SHUFFLE_ID, - TEST_ATTEMPT_ID, - TEST_ATTEMPT_ID, - TEST_REDUCRE_ID, - byteBuf, - masterLocation, - () -> {}); + shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {}); } catch (IOException e) { isFailed = true; } finally { diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java index 247b22d9e..d0aa990d9 100644 --- a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java +++ b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java @@ -29,11 +29,11 @@ import org.apache.flink.util.function.SupplierWithException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.celeborn.client.ShuffleClient; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.common.identity.UserIdentifier; import org.apache.celeborn.common.protocol.PartitionLocation; import org.apache.celeborn.plugin.flink.buffer.BufferPacker; +import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl; import org.apache.celeborn.plugin.flink.utils.BufferUtils; import org.apache.celeborn.plugin.flink.utils.Utils; @@ -58,7 +58,7 @@ public class RemoteShuffleOutputGate { private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleOutputGate.class); private final RemoteShuffleDescriptor shuffleDesc; protected final int numSubs; - protected ShuffleClient shuffleWriteClient; + protected FlinkShuffleClientImpl shuffleWriteClient; protected final SupplierWithException bufferPoolFactory; protected BufferPool bufferPool; private CelebornConf celebornConf; @@ -212,8 +212,9 @@ public class RemoteShuffleOutputGate { } @VisibleForTesting - ShuffleClient createWriteClient() { - return ShuffleClient.get(rssMetaServiceHost, rssMetaServicePort, celebornConf, userIdentifier); + FlinkShuffleClientImpl createWriteClient() { + return FlinkShuffleClientImpl.get( + rssMetaServiceHost, rssMetaServicePort, celebornConf, userIdentifier); } /** Writes a piece of data to a subpartition. */ diff --git a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java index 1faabf500..3dd9ec64f 100644 --- a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java +++ b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java @@ -36,12 +36,12 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.celeborn.client.ShuffleClientImpl; import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl; public class RemoteShuffleOutputGateSuiteJ { private RemoteShuffleOutputGate remoteShuffleOutputGate = mock(RemoteShuffleOutputGate.class); - private ShuffleClientImpl shuffleClient = mock(ShuffleClientImpl.class); + private FlinkShuffleClientImpl shuffleClient = mock(FlinkShuffleClientImpl.class); private static final int BUFFER_SIZE = 20; private NetworkBufferPool networkBufferPool; private BufferPool bufferPool; diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 3d34fcf2f..ad99a9878 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -18,10 +18,8 @@ package org.apache.celeborn.client; import java.io.IOException; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; -import io.netty.buffer.ByteBuf; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.slf4j.Logger; @@ -196,44 +194,6 @@ public abstract class ShuffleClient { public abstract void shutdown(); - // Write data to a specific map partition, input data's type is Bytebuf. - // data's type is Bytebuf to avoid copy between application and netty - // closecallback will do some clean operations like memory release. - public abstract int pushDataToLocation( - String applicationId, - int shuffleId, - int mapId, - int attemptId, - int partitionId, - ByteBuf data, - PartitionLocation location, - Runnable closeCallBack) - throws IOException; - - public abstract Optional regionStart( - String applicationId, - int shuffleId, - int mapId, - int attemptId, - PartitionLocation location, - int currentRegionIdx, - boolean isBroadcast) - throws IOException; - - public abstract void regionFinish( - String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location) - throws IOException; - - public abstract void pushDataHandShake( - String applicationId, - int shuffleId, - int mapId, - int attemptId, - int numPartitions, - int bufferSize, - PartitionLocation location) - throws IOException; - public abstract PartitionLocation registerMapPartitionTask( String appId, int shuffleId, int numMappers, int mapId, int attemptId) throws IOException; diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 791794118..72f7ed8c0 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -24,13 +24,10 @@ import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import scala.reflect.ClassTag$; import com.google.common.annotations.VisibleForTesting; -import com.google.common.util.concurrent.Uninterruptibles; -import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import org.slf4j.Logger; @@ -48,14 +45,10 @@ import org.apache.celeborn.common.network.client.RpcResponseCallback; import org.apache.celeborn.common.network.client.TransportClient; import org.apache.celeborn.common.network.client.TransportClientFactory; import org.apache.celeborn.common.network.protocol.PushData; -import org.apache.celeborn.common.network.protocol.PushDataHandShake; import org.apache.celeborn.common.network.protocol.PushMergedData; -import org.apache.celeborn.common.network.protocol.RegionFinish; -import org.apache.celeborn.common.network.protocol.RegionStart; import org.apache.celeborn.common.network.server.BaseMessageHandler; import org.apache.celeborn.common.network.util.TransportConf; import org.apache.celeborn.common.protocol.*; -import org.apache.celeborn.common.protocol.message.ControlMessages; import org.apache.celeborn.common.protocol.message.ControlMessages.*; import org.apache.celeborn.common.protocol.message.StatusCode; import org.apache.celeborn.common.rpc.RpcAddress; @@ -72,7 +65,7 @@ import org.apache.celeborn.common.write.PushState; public class ShuffleClientImpl extends ShuffleClient { private static final Logger logger = LoggerFactory.getLogger(ShuffleClientImpl.class); - private static final byte MASTER_MODE = PartitionLocation.Mode.MASTER.mode(); + protected static final byte MASTER_MODE = PartitionLocation.Mode.MASTER.mode(); private static final Random RND = new Random(); @@ -85,24 +78,24 @@ public class ShuffleClientImpl extends ShuffleClient { private int maxReviveTimes; private boolean testRetryRevive; private final int pushBufferMaxSize; - private final long pushDataTimeout; + protected final long pushDataTimeout; private final RpcEnv rpcEnv; - private RpcEndpointRef driverRssMetaService; + protected RpcEndpointRef driverRssMetaService; protected TransportClientFactory dataClientFactory; - final int BATCH_HEADER_SIZE = 4 * 4; + protected final int BATCH_HEADER_SIZE = 4 * 4; // key: shuffleId, value: (partitionId, PartitionLocation) private final Map> reducePartitionMap = new ConcurrentHashMap<>(); - private final ConcurrentHashMap> mapperEndMap = new ConcurrentHashMap<>(); + protected final ConcurrentHashMap> mapperEndMap = new ConcurrentHashMap<>(); // key: shuffleId-mapId-attemptId - private final Map pushStates = new ConcurrentHashMap<>(); + protected final Map pushStates = new ConcurrentHashMap<>(); private final ExecutorService pushDataRetryPool; @@ -141,8 +134,6 @@ public class ShuffleClientImpl extends ShuffleClient { // key: shuffleId protected final Map reduceFileGroupsMap = new ConcurrentHashMap<>(); - private ConcurrentHashMap currentClient = new ConcurrentHashMap<>(); - public ShuffleClientImpl(CelebornConf conf, UserIdentifier userIdentifier) { super(); this.conf = conf; @@ -460,7 +451,7 @@ public class ShuffleClientImpl extends ShuffleClient { return null; } - private void limitMaxInFlight(String mapKey, PushState pushState, String hostAndPushPort) + protected void limitMaxInFlight(String mapKey, PushState pushState, String hostAndPushPort) throws IOException { boolean reachLimit = pushState.limitMaxInFlight(hostAndPushPort); @@ -470,7 +461,7 @@ public class ShuffleClientImpl extends ShuffleClient { } } - private void limitZeroInFlight(String mapKey, PushState pushState) throws IOException { + protected void limitZeroInFlight(String mapKey, PushState pushState) throws IOException { boolean reachLimit = pushState.limitZeroInFlight(); if (reachLimit) { @@ -1296,9 +1287,6 @@ public class ShuffleClientImpl extends ShuffleClient { pushState.exception.compareAndSet(null, new CelebornIOException("Cleaned Up")); pushState.cleanup(); } - if (currentClient != null) { - currentClient.remove(mapKey); - } } @Override @@ -1462,7 +1450,7 @@ public class ShuffleClientImpl extends ShuffleClient { driverRssMetaService = endpointRef; } - private boolean mapperEnded(int shuffleId, int mapId, int attemptId) { + protected boolean mapperEnded(int shuffleId, int mapId, int attemptId) { return mapperEndMap.containsKey(shuffleId) && mapperEndMap.get(shuffleId).contains(Utils.makeMapKey(shuffleId, mapId, attemptId)); } @@ -1502,387 +1490,4 @@ public class ShuffleClientImpl extends ShuffleClient { || (message.equals("Connection reset by peer")) || (message.startsWith("Failed to send RPC ")); } - - public int pushDataToLocation( - String applicationId, - int shuffleId, - int mapId, - int attemptId, - int partitionId, - ByteBuf data, - PartitionLocation location, - Runnable closeCallBack) - throws IOException { - // mapKey - final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); - final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); - // return if shuffle stage already ended - if (mapperEnded(shuffleId, mapId, attemptId)) { - logger.info( - "Push data byteBuf to location {} ignored because mapper already ended for shuffle {} map {} attempt {}.", - location.hostAndPushPort(), - shuffleId, - mapId, - attemptId); - PushState pushState = pushStates.get(mapKey); - if (pushState != null) { - pushState.cleanup(); - } - return 0; - } - - PushState pushState = getPushState(mapKey); - - // increment batchId - final int nextBatchId = pushState.nextBatchId(); - int totalLength = data.readableBytes(); - data.markWriterIndex(); - data.writerIndex(0); - data.writeInt(partitionId); - data.writeInt(attemptId); - data.writeInt(nextBatchId); - data.writeInt(totalLength - BATCH_HEADER_SIZE); - data.resetWriterIndex(); - logger.debug( - "Do push data byteBuf size {} for app {} shuffle {} map {} attempt {} reduce {} batch {}.", - totalLength, - applicationId, - shuffleId, - mapId, - attemptId, - partitionId, - nextBatchId); - // check limit - limitMaxInFlight(mapKey, pushState, location.hostAndPushPort()); - - // add inFlight requests - pushState.addBatch(nextBatchId, location.hostAndPushPort()); - - // build PushData request - NettyManagedBuffer buffer = new NettyManagedBuffer(data); - PushData pushData = new PushData(MASTER_MODE, shuffleKey, location.getUniqueId(), buffer); - - // build callback - RpcResponseCallback callback = - new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - pushState.removeBatch(nextBatchId, location.hostAndPushPort()); - if (response.remaining() > 0) { - byte reason = response.get(); - if (reason == StatusCode.STAGE_ENDED.getValue()) { - mapperEndMap - .computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()) - .add(mapKey); - } - } - logger.debug( - "Push data byteBuf to {} success for shuffle {} map {} attemptId {} batch {}.", - location.hostAndPushPort(), - shuffleId, - mapId, - attemptId, - nextBatchId); - } - - @Override - public void onFailure(Throwable e) { - pushState.removeBatch(nextBatchId, location.hostAndPushPort()); - if (pushState.exception.get() != null) { - return; - } - if (!mapperEnded(shuffleId, mapId, attemptId)) { - String errorMsg = - String.format( - "Push data byteBuf to %s failed for shuffle %d map %d attempt %d batch %d.", - location.hostAndPushPort(), shuffleId, mapId, attemptId, nextBatchId); - pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e)); - } else { - logger.warn( - "Push data to {} failed but mapper already ended for shuffle {} map {} attempt {} batch {}.", - location.hostAndPushPort(), - shuffleId, - mapId, - attemptId, - nextBatchId); - } - } - }; - // do push data - try { - TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); - client.pushData(pushData, pushDataTimeout, callback, closeCallBack); - } catch (Exception e) { - logger.error( - "Exception raised while pushing data byteBuf for shuffle {} map {} attempt {} partitionId {} batch {} location {}.", - shuffleId, - mapId, - attemptId, - partitionId, - nextBatchId, - location, - e); - callback.onFailure( - new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_MASTER, e)); - } - return totalLength; - } - - private TransportClient createClientWaitingInFlightRequest( - PartitionLocation location, String mapKey, PushState pushState) - throws IOException, InterruptedException { - TransportClient client = - dataClientFactory.createClient( - location.getHost(), location.getPushPort(), location.getId()); - if (currentClient.get(mapKey) != client) { - // makesure that messages have been sent by old client, in order to keep receiving data - // orderly - if (currentClient.get(mapKey) != null) { - limitZeroInFlight(mapKey, pushState); - } - currentClient.put(mapKey, client); - } - return currentClient.get(mapKey); - } - - @Override - public void pushDataHandShake( - String applicationId, - int shuffleId, - int mapId, - int attemptId, - int numPartitions, - int bufferSize, - PartitionLocation location) - throws IOException { - final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); - final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf)); - sendMessageInternal( - shuffleId, - mapId, - attemptId, - location, - pushState, - () -> { - String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); - logger.info( - "PushDataHandShake shuffleKey {} attemptId {} locationId {}", - shuffleKey, - attemptId, - location.getUniqueId()); - logger.debug("PushDataHandShake location {}", location.toString()); - TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); - PushDataHandShake handShake = - new PushDataHandShake( - MASTER_MODE, - shuffleKey, - location.getUniqueId(), - attemptId, - numPartitions, - bufferSize); - client.sendRpcSync(handShake.toByteBuffer(), conf.pushDataTimeoutMs()); - return null; - }); - } - - @Override - public Optional regionStart( - String applicationId, - int shuffleId, - int mapId, - int attemptId, - PartitionLocation location, - int currentRegionIdx, - boolean isBroadcast) - throws IOException { - final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); - final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf)); - return sendMessageInternal( - shuffleId, - mapId, - attemptId, - location, - pushState, - () -> { - String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); - logger.info( - "RegionStart for shuffle {} regionId {} attemptId {} locationId {}.", - shuffleId, - currentRegionIdx, - attemptId, - location.getUniqueId()); - logger.debug("RegionStart for location {}.", location.toString()); - TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); - RegionStart regionStart = - new RegionStart( - MASTER_MODE, - shuffleKey, - location.getUniqueId(), - attemptId, - currentRegionIdx, - isBroadcast); - ByteBuffer regionStartResponse = - client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataTimeoutMs()); - if (regionStartResponse.hasRemaining() - && regionStartResponse.get() == StatusCode.HARD_SPLIT.getValue()) { - // if split then revive - PbChangeLocationResponse response = - driverRssMetaService.askSync( - ControlMessages.Revive$.MODULE$.apply( - applicationId, - shuffleId, - mapId, - attemptId, - location.getId(), - location.getEpoch(), - location, - StatusCode.HARD_SPLIT), - conf.requestPartitionLocationRpcAskTimeout(), - ClassTag$.MODULE$.apply(PbChangeLocationResponse.class)); - // per partitionKey only serve single PartitionLocation in Client Cache. - StatusCode respStatus = Utils.toStatusCode(response.getStatus()); - if (StatusCode.SUCCESS.equals(respStatus)) { - return Optional.of(PbSerDeUtils.fromPbPartitionLocation(response.getLocation())); - } else if (StatusCode.MAP_ENDED.equals(respStatus)) { - mapperEndMap - .computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()) - .add(mapKey); - return Optional.empty(); - } else { - // throw exception - logger.error( - "Exception raised while reviving for shuffle {} map {} attemptId {} partition {} epoch {}.", - shuffleId, - mapId, - attemptId, - location.getId(), - location.getEpoch()); - throw new CelebornIOException("RegionStart revive failed"); - } - } - return Optional.empty(); - }); - } - - @Override - public void regionFinish( - String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location) - throws IOException { - final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); - final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf)); - sendMessageInternal( - shuffleId, - mapId, - attemptId, - location, - pushState, - () -> { - final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); - logger.info( - "RegionFinish for shuffle {} map {} attemptId {} locationId {}.", - shuffleId, - mapId, - attemptId, - location.getUniqueId()); - logger.debug("RegionFinish for location {}.", location.toString()); - TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState); - RegionFinish regionFinish = - new RegionFinish(MASTER_MODE, shuffleKey, location.getUniqueId(), attemptId); - client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataTimeoutMs()); - return null; - }); - } - - private R sendMessageInternal( - int shuffleId, - int mapId, - int attemptId, - PartitionLocation location, - PushState pushState, - ThrowingExceptionSupplier supplier) - throws IOException { - int batchId = 0; - try { - // mapKey - final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId); - // return if shuffle stage already ended - if (mapperEnded(shuffleId, mapId, attemptId)) { - logger.debug( - "Send message to {} ignored because mapper already ended for shuffle {} map {} attempt {}.", - location.hostAndPushPort(), - shuffleId, - mapId, - attemptId); - return null; - } - pushState = getPushState(mapKey); - // force data has been send - limitZeroInFlight(mapKey, pushState); - - // add inFlight requests - batchId = pushState.nextBatchId(); - pushState.addBatch(batchId, location.hostAndPushPort()); - return retrySendMessage(supplier); - } finally { - if (pushState != null) { - pushState.removeBatch(batchId, location.hostAndPushPort()); - } - } - } - - @FunctionalInterface - interface ThrowingExceptionSupplier { - R get() throws E; - } - - private R retrySendMessage(ThrowingExceptionSupplier supplier) - throws IOException { - - int retryTimes = 0; - boolean isSuccess = false; - Exception currentException = null; - R result = null; - while (!Thread.currentThread().isInterrupted() - && !isSuccess - && retryTimes < conf.networkIoMaxRetries(TransportModuleConstants.PUSH_MODULE)) { - logger.debug("RetrySendMessage retry times {}.", retryTimes); - try { - result = supplier.get(); - isSuccess = true; - } catch (Exception e) { - currentException = e; - if (e instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - if (shouldRetry(e)) { - retryTimes++; - Uninterruptibles.sleepUninterruptibly( - conf.networkIoRetryWaitMs(TransportModuleConstants.PUSH_MODULE), - TimeUnit.MILLISECONDS); - } else { - break; - } - } - } - if (!isSuccess) { - if (currentException instanceof IOException) { - throw (IOException) currentException; - } else { - throw new CelebornIOException(currentException.getMessage(), currentException); - } - } - return result; - } - - private boolean shouldRetry(Throwable e) { - boolean isIOException = - e instanceof IOException - || e instanceof TimeoutException - || (e.getCause() != null && e.getCause() instanceof TimeoutException) - || (e.getCause() != null && e.getCause() instanceof IOException) - || (e instanceof RuntimeException - && e.getMessage() != null - && e.getMessage().startsWith(IOException.class.getName())); - return isIOException; - } } diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java index 37b72be08..088dc5f21 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -28,10 +28,8 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; -import io.netty.buffer.ByteBuf; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -150,48 +148,6 @@ public class DummyShuffleClient extends ShuffleClient { } } - @Override - public int pushDataToLocation( - String applicationId, - int shuffleId, - int mapId, - int attemptId, - int partitionId, - ByteBuf data, - PartitionLocation location, - Runnable closeCallBack) { - return 0; - } - - @Override - public Optional regionStart( - String applicationId, - int shuffleId, - int mapId, - int attemptId, - PartitionLocation location, - int currentRegionIdx, - boolean isBroadcast) - throws IOException { - return Optional.empty(); - } - - @Override - public void regionFinish( - String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location) - throws IOException {} - - @Override - public void pushDataHandShake( - String applicationId, - int shuffleId, - int mapId, - int attemptId, - int numPartitions, - int bufferSize, - PartitionLocation location) - throws IOException {} - @Override public PartitionLocation registerMapPartitionTask( String appId, int shuffleId, int numMappers, int mapId, int attemptId) {