[CELEBORN-367] [FLINK] Move pushdata functions used by mappartition from ShuffleClientImpl to FlinkShuffleClientImpl (#1295)

This commit is contained in:
zhongqiangchen 2023-03-02 18:50:38 +08:00 committed by GitHub
parent dcedf7b0a9
commit 9dc1bc2b1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 473 additions and 540 deletions

View File

@ -44,5 +44,10 @@
<artifactId>flink-runtime</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -18,18 +18,41 @@
package org.apache.celeborn.plugin.flink.readclient;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import scala.reflect.ClassTag$;
import com.google.common.util.concurrent.Uninterruptibles;
import io.netty.buffer.ByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClientImpl;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.network.TransportContext;
import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.PushData;
import org.apache.celeborn.common.network.protocol.PushDataHandShake;
import org.apache.celeborn.common.network.protocol.RegionFinish;
import org.apache.celeborn.common.network.protocol.RegionStart;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbChangeLocationResponse;
import org.apache.celeborn.common.protocol.TransportModuleConstants;
import org.apache.celeborn.common.protocol.message.ControlMessages;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.common.util.PbSerDeUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.common.write.PushState;
import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;
import org.apache.celeborn.plugin.flink.network.ReadClientHandler;
@ -39,6 +62,7 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {
private static volatile boolean initialized = false;
private FlinkTransportClientFactory flinkTransportClientFactory;
private ReadClientHandler readClientHandler = new ReadClientHandler();
private ConcurrentHashMap<String, TransportClient> currentClient = new ConcurrentHashMap<>();
public static FlinkShuffleClientImpl get(
String driverHost, int port, CelebornConf conf, UserIdentifier userIdentifier) {
@ -137,4 +161,397 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {
public ReadClientHandler getReadClientHandler() {
return readClientHandler;
}
/**
 * Pushes one framed data {@code ByteBuf} to a single partition location (map-partition path).
 *
 * <p>The first {@code BATCH_HEADER_SIZE} (16) bytes of {@code data} are overwritten with the
 * batch header: partitionId, attemptId, batchId, and body length. The buffer is wrapped in a
 * {@code NettyManagedBuffer} without copying, so ownership passes to the transport layer.
 *
 * @param closeCallBack invoked by the transport client; presumably releases the buffer /
 *     connection resources — TODO confirm against TransportClient.pushData contract
 * @return the total number of bytes pushed (header included), or 0 if the mapper already ended
 * @throws IOException if the in-flight limit wait fails
 */
public int pushDataToLocation(
String applicationId,
int shuffleId,
int mapId,
int attemptId,
int partitionId,
ByteBuf data,
PartitionLocation location,
Runnable closeCallBack)
throws IOException {
// mapKey
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
// return if shuffle stage already ended
if (mapperEnded(shuffleId, mapId, attemptId)) {
logger.info(
"Push data byteBuf to location {} ignored because mapper already ended for shuffle {} map {} attempt {}.",
location.hostAndPushPort(),
shuffleId,
mapId,
attemptId);
// Release any push state left over for this ended mapper.
PushState pushState = pushStates.get(mapKey);
if (pushState != null) {
pushState.cleanup();
}
return 0;
}
PushState pushState = getPushState(mapKey);
// increment batchId
final int nextBatchId = pushState.nextBatchId();
int totalLength = data.readableBytes();
// Rewind the writer index to patch the 16-byte header at the front of the
// buffer, then restore it so readableBytes still covers header + payload.
data.markWriterIndex();
data.writerIndex(0);
data.writeInt(partitionId);
data.writeInt(attemptId);
data.writeInt(nextBatchId);
data.writeInt(totalLength - BATCH_HEADER_SIZE);
data.resetWriterIndex();
logger.debug(
"Do push data byteBuf size {} for app {} shuffle {} map {} attempt {} reduce {} batch {}.",
totalLength,
applicationId,
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId);
// check limit
limitMaxInFlight(mapKey, pushState, location.hostAndPushPort());
// add inFlight requests
pushState.addBatch(nextBatchId, location.hostAndPushPort());
// build PushData request
NettyManagedBuffer buffer = new NettyManagedBuffer(data);
PushData pushData = new PushData(MASTER_MODE, shuffleKey, location.getUniqueId(), buffer);
// build callback
RpcResponseCallback callback =
new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
pushState.removeBatch(nextBatchId, location.hostAndPushPort());
// A non-empty response may carry a status byte; STAGE_ENDED marks this
// mapper as finished so later pushes become no-ops.
if (response.remaining() > 0) {
byte reason = response.get();
if (reason == StatusCode.STAGE_ENDED.getValue()) {
mapperEndMap
.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
.add(mapKey);
}
}
logger.debug(
"Push data byteBuf to {} success for shuffle {} map {} attemptId {} batch {}.",
location.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
nextBatchId);
}
@Override
public void onFailure(Throwable e) {
pushState.removeBatch(nextBatchId, location.hostAndPushPort());
// Only record the first failure; later callbacks keep the original cause.
if (pushState.exception.get() != null) {
return;
}
if (!mapperEnded(shuffleId, mapId, attemptId)) {
String errorMsg =
String.format(
"Push data byteBuf to %s failed for shuffle %d map %d attempt %d batch %d.",
location.hostAndPushPort(), shuffleId, mapId, attemptId, nextBatchId);
pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e));
} else {
logger.warn(
"Push data to {} failed but mapper already ended for shuffle {} map {} attempt {} batch {}.",
location.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
nextBatchId);
}
}
};
// do push data
try {
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
client.pushData(pushData, pushDataTimeout, callback, closeCallBack);
} catch (Exception e) {
logger.error(
"Exception raised while pushing data byteBuf for shuffle {} map {} attempt {} partitionId {} batch {} location {}.",
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId,
location,
e);
// Route connection failures through the same callback path so push state
// bookkeeping (removeBatch / exception recording) stays consistent.
callback.onFailure(
new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_MASTER, e));
}
return totalLength;
}
/**
 * Returns the transport client to use for {@code location}, ensuring ordered delivery when the
 * underlying connection changes.
 *
 * <p>If the factory hands back a different client than the one cached for {@code mapKey}
 * (e.g. after a reconnect), this waits until all in-flight requests on the old client are
 * drained before switching, so the server keeps receiving this map task's data in order.
 *
 * @throws IOException if waiting for in-flight requests to drain fails
 * @throws InterruptedException if client creation is interrupted
 */
private TransportClient createClientWaitingInFlightRequest(
    PartitionLocation location, String mapKey, PushState pushState)
    throws IOException, InterruptedException {
  TransportClient client =
      dataClientFactory.createClient(
          location.getHost(), location.getPushPort(), location.getId());
  // Read the cached client once instead of re-querying the map for each check;
  // the original issued three separate lookups for the same key.
  TransportClient cached = currentClient.get(mapKey);
  if (cached != client) {
    // Make sure messages have been sent by the old client, in order to keep
    // receiving data orderly.
    if (cached != null) {
      limitZeroInFlight(mapKey, pushState);
    }
    currentClient.put(mapKey, client);
  }
  // Return the client we just validated/installed rather than re-reading the map.
  return client;
}
/**
 * Sends a synchronous push-data handshake to {@code location} for one map task, announcing the
 * partition count and buffer size before any region data is pushed.
 *
 * <p>Delivery goes through {@code sendMessageInternal}, which enforces zero in-flight requests
 * and retries on transient network failures.
 *
 * @throws IOException if the handshake ultimately fails after retries
 */
public void pushDataHandShake(
    String applicationId,
    int shuffleId,
    int mapId,
    int attemptId,
    int numPartitions,
    int bufferSize,
    PartitionLocation location)
    throws IOException {
  final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
  final PushState state = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
  sendMessageInternal(
      shuffleId,
      mapId,
      attemptId,
      location,
      state,
      () -> {
        final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
        logger.info(
            "PushDataHandShake shuffleKey {} attemptId {} locationId {}",
            shuffleKey,
            attemptId,
            location.getUniqueId());
        logger.debug("PushDataHandShake location {}", location.toString());
        // Drain any in-flight traffic on a stale connection before handshaking.
        TransportClient transportClient =
            createClientWaitingInFlightRequest(location, mapKey, state);
        PushDataHandShake handshake =
            new PushDataHandShake(
                MASTER_MODE,
                shuffleKey,
                location.getUniqueId(),
                attemptId,
                numPartitions,
                bufferSize);
        // Blocking RPC: the worker must acknowledge before regions start.
        transportClient.sendRpcSync(handshake.toByteBuffer(), conf.pushDataTimeoutMs());
        return null;
      });
}
/**
 * Notifies {@code location} that region {@code currentRegionIdx} is starting for this map task.
 *
 * <p>If the worker answers {@code HARD_SPLIT}, the driver is asked to revive the partition and
 * the new location (if any) is returned so the caller can redirect subsequent pushes.
 *
 * @return the revived {@code PartitionLocation} on a successful revive, otherwise
 *     {@code Optional.empty()} (normal start, or mapper already ended)
 * @throws IOException if the RPC fails after retries or the revive is rejected
 */
public Optional<PartitionLocation> regionStart(
String applicationId,
int shuffleId,
int mapId,
int attemptId,
PartitionLocation location,
int currentRegionIdx,
boolean isBroadcast)
throws IOException {
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
return sendMessageInternal(
shuffleId,
mapId,
attemptId,
location,
pushState,
() -> {
String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
logger.info(
"RegionStart for shuffle {} regionId {} attemptId {} locationId {}.",
shuffleId,
currentRegionIdx,
attemptId,
location.getUniqueId());
logger.debug("RegionStart for location {}.", location.toString());
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
RegionStart regionStart =
new RegionStart(
MASTER_MODE,
shuffleKey,
location.getUniqueId(),
attemptId,
currentRegionIdx,
isBroadcast);
ByteBuffer regionStartResponse =
client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataTimeoutMs());
// First status byte of the response signals whether the worker forced a split.
if (regionStartResponse.hasRemaining()
&& regionStartResponse.get() == StatusCode.HARD_SPLIT.getValue()) {
// if split then revive
PbChangeLocationResponse response =
driverRssMetaService.askSync(
ControlMessages.Revive$.MODULE$.apply(
applicationId,
shuffleId,
mapId,
attemptId,
location.getId(),
location.getEpoch(),
location,
StatusCode.HARD_SPLIT),
conf.requestPartitionLocationRpcAskTimeout(),
ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
// per partitionKey only serve single PartitionLocation in Client Cache.
StatusCode respStatus = Utils.toStatusCode(response.getStatus());
if (StatusCode.SUCCESS.equals(respStatus)) {
return Optional.of(PbSerDeUtils.fromPbPartitionLocation(response.getLocation()));
} else if (StatusCode.MAP_ENDED.equals(respStatus)) {
// Driver says the map already finished elsewhere; remember it so later
// pushes for this map task become no-ops.
mapperEndMap
.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
.add(mapKey);
return Optional.empty();
} else {
// throw exception
logger.error(
"Exception raised while reviving for shuffle {} map {} attemptId {} partition {} epoch {}.",
shuffleId,
mapId,
attemptId,
location.getId(),
location.getEpoch());
throw new CelebornIOException("RegionStart revive failed");
}
}
return Optional.empty();
});
}
/**
 * Tells {@code location} that the current region of this map task is complete, via a
 * synchronous RPC routed through the shared retry/in-flight machinery.
 *
 * @throws IOException if the RPC ultimately fails after retries
 */
public void regionFinish(
    String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location)
    throws IOException {
  final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
  final PushState state = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
  sendMessageInternal(
      shuffleId,
      mapId,
      attemptId,
      location,
      state,
      () -> {
        final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
        logger.info(
            "RegionFinish for shuffle {} map {} attemptId {} locationId {}.",
            shuffleId,
            mapId,
            attemptId,
            location.getUniqueId());
        logger.debug("RegionFinish for location {}.", location.toString());
        // Reuse (or re-establish) the per-map client, draining stale in-flight data.
        TransportClient transportClient =
            createClientWaitingInFlightRequest(location, mapKey, state);
        RegionFinish finishMsg =
            new RegionFinish(MASTER_MODE, shuffleKey, location.getUniqueId(), attemptId);
        transportClient.sendRpcSync(finishMsg.toByteBuffer(), conf.pushDataTimeoutMs());
        return null;
      });
}
/**
 * Common envelope for control messages (handshake / region start / region finish): skips ended
 * mappers, forces zero in-flight data, tracks the message as a batch, and delegates the actual
 * send to {@code retrySendMessage}.
 *
 * @param pushState caller-supplied state; NOTE(review): immediately overwritten below via
 *     {@code getPushState(mapKey)}, so the argument only matters for the finally block when
 *     {@code getPushState} is never reached — verify whether the parameter can be dropped
 * @return the supplier's result, or {@code null} if the mapper already ended
 * @throws IOException if the message cannot be delivered after retries
 */
private <R> R sendMessageInternal(
int shuffleId,
int mapId,
int attemptId,
PartitionLocation location,
PushState pushState,
ThrowingExceptionSupplier<R, Exception> supplier)
throws IOException {
int batchId = 0;
try {
// mapKey
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
// return if shuffle stage already ended
if (mapperEnded(shuffleId, mapId, attemptId)) {
logger.debug(
"Send message to {} ignored because mapper already ended for shuffle {} map {} attempt {}.",
location.hostAndPushPort(),
shuffleId,
mapId,
attemptId);
// NOTE(review): returning here still runs the finally block with batchId 0,
// which was never added — presumably removeBatch tolerates unknown ids; confirm.
return null;
}
pushState = getPushState(mapKey);
// force data has been send
limitZeroInFlight(mapKey, pushState);
// add inFlight requests
batchId = pushState.nextBatchId();
pushState.addBatch(batchId, location.hostAndPushPort());
return retrySendMessage(supplier);
} finally {
// Always release the batch slot, even when the send throws.
if (pushState != null) {
pushState.removeBatch(batchId, location.hostAndPushPort());
}
}
}
/**
 * A {@code Supplier} variant whose {@link #get()} may throw a checked exception, used to pass
 * RPC-sending lambdas into the shared retry machinery.
 *
 * @param <R> result type produced by the supplier
 * @param <E> checked exception type the supplier may throw
 */
@FunctionalInterface
interface ThrowingExceptionSupplier<R, E extends Exception> {
R get() throws E;
}
/**
 * Invokes {@code supplier} until it succeeds, the retry budget
 * ({@code celeborn.push.io.maxRetries}) is exhausted, the failure is non-retryable, or the
 * calling thread is interrupted.
 *
 * @return the supplier's result on success
 * @throws IOException the last failure (IOExceptions rethrown as-is, anything else wrapped in
 *     {@code CelebornIOException}); also thrown when no attempt could be made at all
 */
private <R> R retrySendMessage(ThrowingExceptionSupplier<R, Exception> supplier)
    throws IOException {
  int retryTimes = 0;
  boolean isSuccess = false;
  Exception currentException = null;
  R result = null;
  while (!Thread.currentThread().isInterrupted()
      && !isSuccess
      && retryTimes < conf.networkIoMaxRetries(TransportModuleConstants.PUSH_MODULE)) {
    logger.debug("RetrySendMessage retry times {}.", retryTimes);
    try {
      result = supplier.get();
      isSuccess = true;
    } catch (Exception e) {
      currentException = e;
      if (e instanceof InterruptedException) {
        // Restore the interrupt flag so the loop condition observes it and exits.
        Thread.currentThread().interrupt();
      }
      if (shouldRetry(e)) {
        retryTimes++;
        Uninterruptibles.sleepUninterruptibly(
            conf.networkIoRetryWaitMs(TransportModuleConstants.PUSH_MODULE),
            TimeUnit.MILLISECONDS);
      } else {
        // Non-retryable failure: stop immediately and surface it below.
        break;
      }
    }
  }
  if (!isSuccess) {
    if (currentException instanceof IOException) {
      throw (IOException) currentException;
    } else if (currentException != null) {
      throw new CelebornIOException(currentException.getMessage(), currentException);
    } else {
      // Bug fix: the loop can exit without a single attempt (thread already
      // interrupted, or max retries configured <= 0). The original dereferenced
      // the null currentException here and threw NullPointerException instead of
      // a meaningful IOException.
      throw new CelebornIOException(
          "Send message aborted before any attempt (interrupted or retries disabled).");
    }
  }
  return result;
}
/**
 * Decides whether a send failure is transient and worth retrying: any IO/timeout failure
 * (direct or as the immediate cause), or a RuntimeException whose message re-wraps an
 * IOException by class name.
 */
private boolean shouldRetry(Throwable e) {
  if (e instanceof IOException || e instanceof TimeoutException) {
    return true;
  }
  // instanceof on a null cause is simply false, so no explicit null check is needed.
  Throwable cause = e.getCause();
  if (cause instanceof TimeoutException || cause instanceof IOException) {
    return true;
  }
  // Some layers stringify the original IOException into a RuntimeException message.
  return e instanceof RuntimeException
      && e.getMessage() != null
      && e.getMessage().startsWith(IOException.class.getName());
}
/**
 * Releases resources for one map task: runs the base-class cleanup, then evicts this map
 * task's cached transport client.
 */
@Override
public void cleanup(String applicationId, int shuffleId, int mapId, int attemptId) {
  final String taskKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
  super.cleanup(applicationId, shuffleId, mapId, attemptId);
  if (currentClient != null) {
    currentClient.remove(taskKey);
  }
}
/**
 * Replaces the transport client factory; used by unit tests to inject a mocked factory.
 */
public void setDataClientFactory(TransportClientFactory dataClientFactory) {
this.dataClientFactory = dataClientFactory;
}
}

View File

@ -15,10 +15,9 @@
* limitations under the License.
*/
package org.apache.celeborn.client;
package org.apache.celeborn.plugin.flink;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.io.IOException;
@ -26,27 +25,48 @@ import java.nio.ByteBuffer;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.protocol.CompressionCodec;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
public class FlinkShuffleClientImplSuiteJ {
static int BufferSize = 64;
static byte[] TEST_BUF1 = new byte[BufferSize];
protected ChannelFuture mockedFuture = mock(ChannelFuture.class);
static CelebornConf conf;
static FlinkShuffleClientImpl shuffleClient;
protected static final TransportClientFactory clientFactory = mock(TransportClientFactory.class);
protected final TransportClient client = mock(TransportClient.class);
protected static final PartitionLocation masterLocation =
new PartitionLocation(0, 1, "localhost", 1, 1, 1, 1, PartitionLocation.Mode.MASTER);
@Before
public void setup() throws IOException, InterruptedException {
conf = setupEnv(CompressionCodec.LZ4);
conf = new CelebornConf();
shuffleClient =
new FlinkShuffleClientImpl("localhost", 1232, conf, null) {
@Override
public void setupMetaServiceRef(String host, int port) {}
};
when(clientFactory.createClient(masterLocation.getHost(), masterLocation.getPushPort(), 1))
.thenAnswer(t -> client);
shuffleClient.setDataClientFactory(clientFactory);
}
public ByteBuf createByteBuf() {
for (int i = BATCH_HEADER_SIZE; i < BufferSize; i++) {
for (int i = 16; i < BufferSize; i++) {
TEST_BUF1[i] = 1;
}
ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
@ -57,7 +77,7 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
@Test
public void testPushDataByteBufSuccess() throws IOException {
ByteBuf byteBuf = createByteBuf();
when(client.pushData(any(), anyLong(), any()))
Mockito.when(client.pushData(Matchers.any(), Matchers.anyLong(), Matchers.any()))
.thenAnswer(
t -> {
RpcResponseCallback rpcResponseCallback =
@ -68,22 +88,14 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
});
int pushDataLen =
shuffleClient.pushDataToLocation(
TEST_APPLICATION_ID,
TEST_SHUFFLE_ID,
TEST_ATTEMPT_ID,
TEST_ATTEMPT_ID,
TEST_REDUCRE_ID,
byteBuf,
masterLocation,
() -> {});
shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {});
Assert.assertEquals(BufferSize, pushDataLen);
}
@Test
public void testPushDataByteBufHardSplit() throws IOException {
ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
when(client.pushData(any(), anyLong(), any()))
Mockito.when(client.pushData(Matchers.any(), Matchers.anyLong(), Matchers.any()))
.thenAnswer(
t -> {
RpcResponseCallback rpcResponseCallback =
@ -94,21 +106,14 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
return mockedFuture;
});
int pushDataLen =
shuffleClient.pushDataToLocation(
TEST_APPLICATION_ID,
TEST_SHUFFLE_ID,
TEST_ATTEMPT_ID,
TEST_ATTEMPT_ID,
TEST_REDUCRE_ID,
byteBuf,
masterLocation,
() -> {});
shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {});
}
@Test
public void testPushDataByteBufFail() throws IOException {
ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
when(client.pushData(any(), anyLong(), any(), any()))
Mockito.when(
client.pushData(Matchers.any(), Matchers.anyLong(), Matchers.any(), Matchers.any()))
.thenAnswer(
t -> {
RpcResponseCallback rpcResponseCallback =
@ -117,28 +122,12 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
return mockedFuture;
});
// first push just set pushdata.exception
shuffleClient.pushDataToLocation(
TEST_APPLICATION_ID,
TEST_SHUFFLE_ID,
TEST_ATTEMPT_ID,
TEST_ATTEMPT_ID,
TEST_REDUCRE_ID,
byteBuf,
masterLocation,
() -> {});
shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {});
boolean isFailed = false;
// second push will throw exception
try {
shuffleClient.pushDataToLocation(
TEST_APPLICATION_ID,
TEST_SHUFFLE_ID,
TEST_ATTEMPT_ID,
TEST_ATTEMPT_ID,
TEST_REDUCRE_ID,
byteBuf,
masterLocation,
() -> {});
shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {});
} catch (IOException e) {
isFailed = true;
} finally {

View File

@ -29,11 +29,11 @@ import org.apache.flink.util.function.SupplierWithException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
import org.apache.celeborn.plugin.flink.utils.BufferUtils;
import org.apache.celeborn.plugin.flink.utils.Utils;
@ -58,7 +58,7 @@ public class RemoteShuffleOutputGate {
private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleOutputGate.class);
private final RemoteShuffleDescriptor shuffleDesc;
protected final int numSubs;
protected ShuffleClient shuffleWriteClient;
protected FlinkShuffleClientImpl shuffleWriteClient;
protected final SupplierWithException<BufferPool, IOException> bufferPoolFactory;
protected BufferPool bufferPool;
private CelebornConf celebornConf;
@ -212,8 +212,9 @@ public class RemoteShuffleOutputGate {
}
@VisibleForTesting
ShuffleClient createWriteClient() {
return ShuffleClient.get(rssMetaServiceHost, rssMetaServicePort, celebornConf, userIdentifier);
FlinkShuffleClientImpl createWriteClient() {
return FlinkShuffleClientImpl.get(
rssMetaServiceHost, rssMetaServicePort, celebornConf, userIdentifier);
}
/** Writes a piece of data to a subpartition. */

View File

@ -36,12 +36,12 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.apache.celeborn.client.ShuffleClientImpl;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
public class RemoteShuffleOutputGateSuiteJ {
private RemoteShuffleOutputGate remoteShuffleOutputGate = mock(RemoteShuffleOutputGate.class);
private ShuffleClientImpl shuffleClient = mock(ShuffleClientImpl.class);
private FlinkShuffleClientImpl shuffleClient = mock(FlinkShuffleClientImpl.class);
private static final int BUFFER_SIZE = 20;
private NetworkBufferPool networkBufferPool;
private BufferPool bufferPool;

View File

@ -18,10 +18,8 @@
package org.apache.celeborn.client;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import io.netty.buffer.ByteBuf;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.slf4j.Logger;
@ -196,44 +194,6 @@ public abstract class ShuffleClient {
public abstract void shutdown();
// Write data to a specific map partition, input data's type is Bytebuf.
// data's type is Bytebuf to avoid copy between application and netty
// closecallback will do some clean operations like memory release.
public abstract int pushDataToLocation(
String applicationId,
int shuffleId,
int mapId,
int attemptId,
int partitionId,
ByteBuf data,
PartitionLocation location,
Runnable closeCallBack)
throws IOException;
public abstract Optional<PartitionLocation> regionStart(
String applicationId,
int shuffleId,
int mapId,
int attemptId,
PartitionLocation location,
int currentRegionIdx,
boolean isBroadcast)
throws IOException;
public abstract void regionFinish(
String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location)
throws IOException;
public abstract void pushDataHandShake(
String applicationId,
int shuffleId,
int mapId,
int attemptId,
int numPartitions,
int bufferSize,
PartitionLocation location)
throws IOException;
public abstract PartitionLocation registerMapPartitionTask(
String appId, int shuffleId, int numMappers, int mapId, int attemptId) throws IOException;

View File

@ -24,13 +24,10 @@ import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import scala.reflect.ClassTag$;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.Uninterruptibles;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import org.slf4j.Logger;
@ -48,14 +45,10 @@ import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.PushData;
import org.apache.celeborn.common.network.protocol.PushDataHandShake;
import org.apache.celeborn.common.network.protocol.PushMergedData;
import org.apache.celeborn.common.network.protocol.RegionFinish;
import org.apache.celeborn.common.network.protocol.RegionStart;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.*;
import org.apache.celeborn.common.protocol.message.ControlMessages;
import org.apache.celeborn.common.protocol.message.ControlMessages.*;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.common.rpc.RpcAddress;
@ -72,7 +65,7 @@ import org.apache.celeborn.common.write.PushState;
public class ShuffleClientImpl extends ShuffleClient {
private static final Logger logger = LoggerFactory.getLogger(ShuffleClientImpl.class);
private static final byte MASTER_MODE = PartitionLocation.Mode.MASTER.mode();
protected static final byte MASTER_MODE = PartitionLocation.Mode.MASTER.mode();
private static final Random RND = new Random();
@ -85,24 +78,24 @@ public class ShuffleClientImpl extends ShuffleClient {
private int maxReviveTimes;
private boolean testRetryRevive;
private final int pushBufferMaxSize;
private final long pushDataTimeout;
protected final long pushDataTimeout;
private final RpcEnv rpcEnv;
private RpcEndpointRef driverRssMetaService;
protected RpcEndpointRef driverRssMetaService;
protected TransportClientFactory dataClientFactory;
final int BATCH_HEADER_SIZE = 4 * 4;
protected final int BATCH_HEADER_SIZE = 4 * 4;
// key: shuffleId, value: (partitionId, PartitionLocation)
private final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>> reducePartitionMap =
new ConcurrentHashMap<>();
private final ConcurrentHashMap<Integer, Set<String>> mapperEndMap = new ConcurrentHashMap<>();
protected final ConcurrentHashMap<Integer, Set<String>> mapperEndMap = new ConcurrentHashMap<>();
// key: shuffleId-mapId-attemptId
private final Map<String, PushState> pushStates = new ConcurrentHashMap<>();
protected final Map<String, PushState> pushStates = new ConcurrentHashMap<>();
private final ExecutorService pushDataRetryPool;
@ -141,8 +134,6 @@ public class ShuffleClientImpl extends ShuffleClient {
// key: shuffleId
protected final Map<Integer, ReduceFileGroups> reduceFileGroupsMap = new ConcurrentHashMap<>();
private ConcurrentHashMap<String, TransportClient> currentClient = new ConcurrentHashMap<>();
public ShuffleClientImpl(CelebornConf conf, UserIdentifier userIdentifier) {
super();
this.conf = conf;
@ -460,7 +451,7 @@ public class ShuffleClientImpl extends ShuffleClient {
return null;
}
private void limitMaxInFlight(String mapKey, PushState pushState, String hostAndPushPort)
protected void limitMaxInFlight(String mapKey, PushState pushState, String hostAndPushPort)
throws IOException {
boolean reachLimit = pushState.limitMaxInFlight(hostAndPushPort);
@ -470,7 +461,7 @@ public class ShuffleClientImpl extends ShuffleClient {
}
}
private void limitZeroInFlight(String mapKey, PushState pushState) throws IOException {
protected void limitZeroInFlight(String mapKey, PushState pushState) throws IOException {
boolean reachLimit = pushState.limitZeroInFlight();
if (reachLimit) {
@ -1296,9 +1287,6 @@ public class ShuffleClientImpl extends ShuffleClient {
pushState.exception.compareAndSet(null, new CelebornIOException("Cleaned Up"));
pushState.cleanup();
}
if (currentClient != null) {
currentClient.remove(mapKey);
}
}
@Override
@ -1462,7 +1450,7 @@ public class ShuffleClientImpl extends ShuffleClient {
driverRssMetaService = endpointRef;
}
private boolean mapperEnded(int shuffleId, int mapId, int attemptId) {
protected boolean mapperEnded(int shuffleId, int mapId, int attemptId) {
return mapperEndMap.containsKey(shuffleId)
&& mapperEndMap.get(shuffleId).contains(Utils.makeMapKey(shuffleId, mapId, attemptId));
}
@ -1502,387 +1490,4 @@ public class ShuffleClientImpl extends ShuffleClient {
|| (message.equals("Connection reset by peer"))
|| (message.startsWith("Failed to send RPC "));
}
public int pushDataToLocation(
String applicationId,
int shuffleId,
int mapId,
int attemptId,
int partitionId,
ByteBuf data,
PartitionLocation location,
Runnable closeCallBack)
throws IOException {
// mapKey
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
// return if shuffle stage already ended
if (mapperEnded(shuffleId, mapId, attemptId)) {
logger.info(
"Push data byteBuf to location {} ignored because mapper already ended for shuffle {} map {} attempt {}.",
location.hostAndPushPort(),
shuffleId,
mapId,
attemptId);
PushState pushState = pushStates.get(mapKey);
if (pushState != null) {
pushState.cleanup();
}
return 0;
}
PushState pushState = getPushState(mapKey);
// increment batchId
final int nextBatchId = pushState.nextBatchId();
int totalLength = data.readableBytes();
data.markWriterIndex();
data.writerIndex(0);
data.writeInt(partitionId);
data.writeInt(attemptId);
data.writeInt(nextBatchId);
data.writeInt(totalLength - BATCH_HEADER_SIZE);
data.resetWriterIndex();
logger.debug(
"Do push data byteBuf size {} for app {} shuffle {} map {} attempt {} reduce {} batch {}.",
totalLength,
applicationId,
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId);
// check limit
limitMaxInFlight(mapKey, pushState, location.hostAndPushPort());
// add inFlight requests
pushState.addBatch(nextBatchId, location.hostAndPushPort());
// build PushData request
NettyManagedBuffer buffer = new NettyManagedBuffer(data);
PushData pushData = new PushData(MASTER_MODE, shuffleKey, location.getUniqueId(), buffer);
// build callback
RpcResponseCallback callback =
new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
pushState.removeBatch(nextBatchId, location.hostAndPushPort());
if (response.remaining() > 0) {
byte reason = response.get();
if (reason == StatusCode.STAGE_ENDED.getValue()) {
mapperEndMap
.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
.add(mapKey);
}
}
logger.debug(
"Push data byteBuf to {} success for shuffle {} map {} attemptId {} batch {}.",
location.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
nextBatchId);
}
@Override
public void onFailure(Throwable e) {
pushState.removeBatch(nextBatchId, location.hostAndPushPort());
if (pushState.exception.get() != null) {
return;
}
if (!mapperEnded(shuffleId, mapId, attemptId)) {
String errorMsg =
String.format(
"Push data byteBuf to %s failed for shuffle %d map %d attempt %d batch %d.",
location.hostAndPushPort(), shuffleId, mapId, attemptId, nextBatchId);
pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e));
} else {
logger.warn(
"Push data to {} failed but mapper already ended for shuffle {} map {} attempt {} batch {}.",
location.hostAndPushPort(),
shuffleId,
mapId,
attemptId,
nextBatchId);
}
}
};
// do push data
try {
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
client.pushData(pushData, pushDataTimeout, callback, closeCallBack);
} catch (Exception e) {
logger.error(
"Exception raised while pushing data byteBuf for shuffle {} map {} attempt {} partitionId {} batch {} location {}.",
shuffleId,
mapId,
attemptId,
partitionId,
nextBatchId,
location,
e);
callback.onFailure(
new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_MASTER, e));
}
return totalLength;
}
private TransportClient createClientWaitingInFlightRequest(
PartitionLocation location, String mapKey, PushState pushState)
throws IOException, InterruptedException {
TransportClient client =
dataClientFactory.createClient(
location.getHost(), location.getPushPort(), location.getId());
if (currentClient.get(mapKey) != client) {
// makesure that messages have been sent by old client, in order to keep receiving data
// orderly
if (currentClient.get(mapKey) != null) {
limitZeroInFlight(mapKey, pushState);
}
currentClient.put(mapKey, client);
}
return currentClient.get(mapKey);
}
@Override
public void pushDataHandShake(
String applicationId,
int shuffleId,
int mapId,
int attemptId,
int numPartitions,
int bufferSize,
PartitionLocation location)
throws IOException {
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
sendMessageInternal(
shuffleId,
mapId,
attemptId,
location,
pushState,
() -> {
String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
logger.info(
"PushDataHandShake shuffleKey {} attemptId {} locationId {}",
shuffleKey,
attemptId,
location.getUniqueId());
logger.debug("PushDataHandShake location {}", location.toString());
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
PushDataHandShake handShake =
new PushDataHandShake(
MASTER_MODE,
shuffleKey,
location.getUniqueId(),
attemptId,
numPartitions,
bufferSize);
client.sendRpcSync(handShake.toByteBuffer(), conf.pushDataTimeoutMs());
return null;
});
}
@Override
public Optional<PartitionLocation> regionStart(
String applicationId,
int shuffleId,
int mapId,
int attemptId,
PartitionLocation location,
int currentRegionIdx,
boolean isBroadcast)
throws IOException {
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
return sendMessageInternal(
shuffleId,
mapId,
attemptId,
location,
pushState,
() -> {
String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
logger.info(
"RegionStart for shuffle {} regionId {} attemptId {} locationId {}.",
shuffleId,
currentRegionIdx,
attemptId,
location.getUniqueId());
logger.debug("RegionStart for location {}.", location.toString());
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
RegionStart regionStart =
new RegionStart(
MASTER_MODE,
shuffleKey,
location.getUniqueId(),
attemptId,
currentRegionIdx,
isBroadcast);
ByteBuffer regionStartResponse =
client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataTimeoutMs());
if (regionStartResponse.hasRemaining()
&& regionStartResponse.get() == StatusCode.HARD_SPLIT.getValue()) {
// if split then revive
PbChangeLocationResponse response =
driverRssMetaService.askSync(
ControlMessages.Revive$.MODULE$.apply(
applicationId,
shuffleId,
mapId,
attemptId,
location.getId(),
location.getEpoch(),
location,
StatusCode.HARD_SPLIT),
conf.requestPartitionLocationRpcAskTimeout(),
ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
// per partitionKey only serve single PartitionLocation in Client Cache.
StatusCode respStatus = Utils.toStatusCode(response.getStatus());
if (StatusCode.SUCCESS.equals(respStatus)) {
return Optional.of(PbSerDeUtils.fromPbPartitionLocation(response.getLocation()));
} else if (StatusCode.MAP_ENDED.equals(respStatus)) {
mapperEndMap
.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
.add(mapKey);
return Optional.empty();
} else {
// throw exception
logger.error(
"Exception raised while reviving for shuffle {} map {} attemptId {} partition {} epoch {}.",
shuffleId,
mapId,
attemptId,
location.getId(),
location.getEpoch());
throw new CelebornIOException("RegionStart revive failed");
}
}
return Optional.empty();
});
}
@Override
public void regionFinish(
String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location)
throws IOException {
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
sendMessageInternal(
shuffleId,
mapId,
attemptId,
location,
pushState,
() -> {
final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
logger.info(
"RegionFinish for shuffle {} map {} attemptId {} locationId {}.",
shuffleId,
mapId,
attemptId,
location.getUniqueId());
logger.debug("RegionFinish for location {}.", location.toString());
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
RegionFinish regionFinish =
new RegionFinish(MASTER_MODE, shuffleKey, location.getUniqueId(), attemptId);
client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataTimeoutMs());
return null;
});
}
  /**
   * Common wrapper for control messages (handshake / region start / region finish): returns
   * {@code null} immediately if the mapper already ended, waits until no pushes are in flight,
   * registers a batch id while the message is outstanding, and delegates the actual send (with
   * retries) to {@link #retrySendMessage}.
   *
   * @return the supplier's result, or {@code null} when the mapper has already ended
   * @throws IOException if the supplier keeps failing after the configured retries
   */
  private <R> R sendMessageInternal(
      int shuffleId,
      int mapId,
      int attemptId,
      PartitionLocation location,
      PushState pushState,
      ThrowingExceptionSupplier<R, Exception> supplier)
      throws IOException {
    int batchId = 0;
    try {
      // mapKey identifies this (shuffle, map, attempt) triple in the shared state maps.
      final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
      // Short-circuit: another attempt already completed this mapper, so sending is pointless.
      if (mapperEnded(shuffleId, mapId, attemptId)) {
        logger.debug(
            "Send message to {} ignored because mapper already ended for shuffle {} map {} attempt {}.",
            location.hostAndPushPort(),
            shuffleId,
            mapId,
            attemptId);
        return null;
      }
      // NOTE(review): the caller-supplied pushState parameter is discarded and re-resolved here —
      // presumably getPushState adds validation (e.g. rethrowing a recorded push failure); confirm
      // whether the parameter is still needed.
      pushState = getPushState(mapKey);
      // Drain: wait until every previously pushed batch has been acknowledged, so this control
      // message is ordered after all outstanding data.
      limitZeroInFlight(mapKey, pushState);
      // Track this message as an in-flight batch for the duration of the send.
      batchId = pushState.nextBatchId();
      pushState.addBatch(batchId, location.hostAndPushPort());
      return retrySendMessage(supplier);
    } finally {
      if (pushState != null) {
        // NOTE(review): on the early "mapper ended" return this removes batchId 0, which was never
        // added — assumed harmless if removeBatch tolerates unknown ids; verify.
        pushState.removeBatch(batchId, location.hostAndPushPort());
      }
    }
  }
@FunctionalInterface
interface ThrowingExceptionSupplier<R, E extends Exception> {
R get() throws E;
}
private <R> R retrySendMessage(ThrowingExceptionSupplier<R, Exception> supplier)
throws IOException {
int retryTimes = 0;
boolean isSuccess = false;
Exception currentException = null;
R result = null;
while (!Thread.currentThread().isInterrupted()
&& !isSuccess
&& retryTimes < conf.networkIoMaxRetries(TransportModuleConstants.PUSH_MODULE)) {
logger.debug("RetrySendMessage retry times {}.", retryTimes);
try {
result = supplier.get();
isSuccess = true;
} catch (Exception e) {
currentException = e;
if (e instanceof InterruptedException) {
Thread.currentThread().interrupt();
}
if (shouldRetry(e)) {
retryTimes++;
Uninterruptibles.sleepUninterruptibly(
conf.networkIoRetryWaitMs(TransportModuleConstants.PUSH_MODULE),
TimeUnit.MILLISECONDS);
} else {
break;
}
}
}
if (!isSuccess) {
if (currentException instanceof IOException) {
throw (IOException) currentException;
} else {
throw new CelebornIOException(currentException.getMessage(), currentException);
}
}
return result;
}
private boolean shouldRetry(Throwable e) {
boolean isIOException =
e instanceof IOException
|| e instanceof TimeoutException
|| (e.getCause() != null && e.getCause() instanceof TimeoutException)
|| (e.getCause() != null && e.getCause() instanceof IOException)
|| (e instanceof RuntimeException
&& e.getMessage() != null
&& e.getMessage().startsWith(IOException.class.getName()));
return isIOException;
}
}

View File

@ -28,10 +28,8 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import io.netty.buffer.ByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -150,48 +148,6 @@ public class DummyShuffleClient extends ShuffleClient {
}
}
  /** Test stub: accepts the buffer but sends nothing; always reports 0 bytes pushed. */
  @Override
  public int pushDataToLocation(
      String applicationId,
      int shuffleId,
      int mapId,
      int attemptId,
      int partitionId,
      ByteBuf data,
      PartitionLocation location,
      Runnable closeCallBack) {
    return 0;
  }
  /** Test stub: never requests a split, so no revived location is ever returned. */
  @Override
  public Optional<PartitionLocation> regionStart(
      String applicationId,
      int shuffleId,
      int mapId,
      int attemptId,
      PartitionLocation location,
      int currentRegionIdx,
      boolean isBroadcast)
      throws IOException {
    return Optional.empty();
  }
  /** Test stub: region completion is a no-op. */
  @Override
  public void regionFinish(
      String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location)
      throws IOException {}
  /** Test stub: the push-data handshake is a no-op. */
  @Override
  public void pushDataHandShake(
      String applicationId,
      int shuffleId,
      int mapId,
      int attemptId,
      int numPartitions,
      int bufferSize,
      PartitionLocation location)
      throws IOException {}
@Override
public PartitionLocation registerMapPartitionTask(
String appId, int shuffleId, int numMappers, int mapId, int attemptId) {