[CELEBORN-367] [FLINK] Move pushdata functions used by mappartition from ShuffleClientImpl to FlinkShuffleClientImpl (#1295)
This commit is contained in:
parent
dcedf7b0a9
commit
9dc1bc2b1c
@ -44,5 +44,10 @@
|
||||
<artifactId>flink-runtime</artifactId>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@ -18,18 +18,41 @@
|
||||
package org.apache.celeborn.plugin.flink.readclient;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
|
||||
import scala.reflect.ClassTag$;
|
||||
|
||||
import com.google.common.util.concurrent.Uninterruptibles;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClientImpl;
|
||||
import org.apache.celeborn.common.CelebornConf;
|
||||
import org.apache.celeborn.common.exception.CelebornIOException;
|
||||
import org.apache.celeborn.common.identity.UserIdentifier;
|
||||
import org.apache.celeborn.common.network.TransportContext;
|
||||
import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
|
||||
import org.apache.celeborn.common.network.client.RpcResponseCallback;
|
||||
import org.apache.celeborn.common.network.client.TransportClient;
|
||||
import org.apache.celeborn.common.network.client.TransportClientFactory;
|
||||
import org.apache.celeborn.common.network.protocol.PushData;
|
||||
import org.apache.celeborn.common.network.protocol.PushDataHandShake;
|
||||
import org.apache.celeborn.common.network.protocol.RegionFinish;
|
||||
import org.apache.celeborn.common.network.protocol.RegionStart;
|
||||
import org.apache.celeborn.common.network.util.TransportConf;
|
||||
import org.apache.celeborn.common.protocol.PartitionLocation;
|
||||
import org.apache.celeborn.common.protocol.PbChangeLocationResponse;
|
||||
import org.apache.celeborn.common.protocol.TransportModuleConstants;
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages;
|
||||
import org.apache.celeborn.common.protocol.message.StatusCode;
|
||||
import org.apache.celeborn.common.util.PbSerDeUtils;
|
||||
import org.apache.celeborn.common.util.Utils;
|
||||
import org.apache.celeborn.common.write.PushState;
|
||||
import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;
|
||||
import org.apache.celeborn.plugin.flink.network.ReadClientHandler;
|
||||
|
||||
@ -39,6 +62,7 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {
|
||||
private static volatile boolean initialized = false;
|
||||
private FlinkTransportClientFactory flinkTransportClientFactory;
|
||||
private ReadClientHandler readClientHandler = new ReadClientHandler();
|
||||
private ConcurrentHashMap<String, TransportClient> currentClient = new ConcurrentHashMap<>();
|
||||
|
||||
public static FlinkShuffleClientImpl get(
|
||||
String driverHost, int port, CelebornConf conf, UserIdentifier userIdentifier) {
|
||||
@ -137,4 +161,397 @@ public class FlinkShuffleClientImpl extends ShuffleClientImpl {
|
||||
public ReadClientHandler getReadClientHandler() {
|
||||
return readClientHandler;
|
||||
}
|
||||
|
||||
public int pushDataToLocation(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
int partitionId,
|
||||
ByteBuf data,
|
||||
PartitionLocation location,
|
||||
Runnable closeCallBack)
|
||||
throws IOException {
|
||||
// mapKey
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
|
||||
// return if shuffle stage already ended
|
||||
if (mapperEnded(shuffleId, mapId, attemptId)) {
|
||||
logger.info(
|
||||
"Push data byteBuf to location {} ignored because mapper already ended for shuffle {} map {} attempt {}.",
|
||||
location.hostAndPushPort(),
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId);
|
||||
PushState pushState = pushStates.get(mapKey);
|
||||
if (pushState != null) {
|
||||
pushState.cleanup();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
PushState pushState = getPushState(mapKey);
|
||||
|
||||
// increment batchId
|
||||
final int nextBatchId = pushState.nextBatchId();
|
||||
int totalLength = data.readableBytes();
|
||||
data.markWriterIndex();
|
||||
data.writerIndex(0);
|
||||
data.writeInt(partitionId);
|
||||
data.writeInt(attemptId);
|
||||
data.writeInt(nextBatchId);
|
||||
data.writeInt(totalLength - BATCH_HEADER_SIZE);
|
||||
data.resetWriterIndex();
|
||||
logger.debug(
|
||||
"Do push data byteBuf size {} for app {} shuffle {} map {} attempt {} reduce {} batch {}.",
|
||||
totalLength,
|
||||
applicationId,
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
partitionId,
|
||||
nextBatchId);
|
||||
// check limit
|
||||
limitMaxInFlight(mapKey, pushState, location.hostAndPushPort());
|
||||
|
||||
// add inFlight requests
|
||||
pushState.addBatch(nextBatchId, location.hostAndPushPort());
|
||||
|
||||
// build PushData request
|
||||
NettyManagedBuffer buffer = new NettyManagedBuffer(data);
|
||||
PushData pushData = new PushData(MASTER_MODE, shuffleKey, location.getUniqueId(), buffer);
|
||||
|
||||
// build callback
|
||||
RpcResponseCallback callback =
|
||||
new RpcResponseCallback() {
|
||||
@Override
|
||||
public void onSuccess(ByteBuffer response) {
|
||||
pushState.removeBatch(nextBatchId, location.hostAndPushPort());
|
||||
if (response.remaining() > 0) {
|
||||
byte reason = response.get();
|
||||
if (reason == StatusCode.STAGE_ENDED.getValue()) {
|
||||
mapperEndMap
|
||||
.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
|
||||
.add(mapKey);
|
||||
}
|
||||
}
|
||||
logger.debug(
|
||||
"Push data byteBuf to {} success for shuffle {} map {} attemptId {} batch {}.",
|
||||
location.hostAndPushPort(),
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
nextBatchId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Throwable e) {
|
||||
pushState.removeBatch(nextBatchId, location.hostAndPushPort());
|
||||
if (pushState.exception.get() != null) {
|
||||
return;
|
||||
}
|
||||
if (!mapperEnded(shuffleId, mapId, attemptId)) {
|
||||
String errorMsg =
|
||||
String.format(
|
||||
"Push data byteBuf to %s failed for shuffle %d map %d attempt %d batch %d.",
|
||||
location.hostAndPushPort(), shuffleId, mapId, attemptId, nextBatchId);
|
||||
pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e));
|
||||
} else {
|
||||
logger.warn(
|
||||
"Push data to {} failed but mapper already ended for shuffle {} map {} attempt {} batch {}.",
|
||||
location.hostAndPushPort(),
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
nextBatchId);
|
||||
}
|
||||
}
|
||||
};
|
||||
// do push data
|
||||
try {
|
||||
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
|
||||
client.pushData(pushData, pushDataTimeout, callback, closeCallBack);
|
||||
} catch (Exception e) {
|
||||
logger.error(
|
||||
"Exception raised while pushing data byteBuf for shuffle {} map {} attempt {} partitionId {} batch {} location {}.",
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
partitionId,
|
||||
nextBatchId,
|
||||
location,
|
||||
e);
|
||||
callback.onFailure(
|
||||
new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_MASTER, e));
|
||||
}
|
||||
return totalLength;
|
||||
}
|
||||
|
||||
private TransportClient createClientWaitingInFlightRequest(
|
||||
PartitionLocation location, String mapKey, PushState pushState)
|
||||
throws IOException, InterruptedException {
|
||||
TransportClient client =
|
||||
dataClientFactory.createClient(
|
||||
location.getHost(), location.getPushPort(), location.getId());
|
||||
if (currentClient.get(mapKey) != client) {
|
||||
// makesure that messages have been sent by old client, in order to keep receiving data
|
||||
// orderly
|
||||
if (currentClient.get(mapKey) != null) {
|
||||
limitZeroInFlight(mapKey, pushState);
|
||||
}
|
||||
currentClient.put(mapKey, client);
|
||||
}
|
||||
return currentClient.get(mapKey);
|
||||
}
|
||||
|
||||
public void pushDataHandShake(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
int numPartitions,
|
||||
int bufferSize,
|
||||
PartitionLocation location)
|
||||
throws IOException {
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
|
||||
sendMessageInternal(
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location,
|
||||
pushState,
|
||||
() -> {
|
||||
String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
|
||||
logger.info(
|
||||
"PushDataHandShake shuffleKey {} attemptId {} locationId {}",
|
||||
shuffleKey,
|
||||
attemptId,
|
||||
location.getUniqueId());
|
||||
logger.debug("PushDataHandShake location {}", location.toString());
|
||||
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
|
||||
PushDataHandShake handShake =
|
||||
new PushDataHandShake(
|
||||
MASTER_MODE,
|
||||
shuffleKey,
|
||||
location.getUniqueId(),
|
||||
attemptId,
|
||||
numPartitions,
|
||||
bufferSize);
|
||||
client.sendRpcSync(handShake.toByteBuffer(), conf.pushDataTimeoutMs());
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
public Optional<PartitionLocation> regionStart(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
PartitionLocation location,
|
||||
int currentRegionIdx,
|
||||
boolean isBroadcast)
|
||||
throws IOException {
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
|
||||
return sendMessageInternal(
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location,
|
||||
pushState,
|
||||
() -> {
|
||||
String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
|
||||
logger.info(
|
||||
"RegionStart for shuffle {} regionId {} attemptId {} locationId {}.",
|
||||
shuffleId,
|
||||
currentRegionIdx,
|
||||
attemptId,
|
||||
location.getUniqueId());
|
||||
logger.debug("RegionStart for location {}.", location.toString());
|
||||
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
|
||||
RegionStart regionStart =
|
||||
new RegionStart(
|
||||
MASTER_MODE,
|
||||
shuffleKey,
|
||||
location.getUniqueId(),
|
||||
attemptId,
|
||||
currentRegionIdx,
|
||||
isBroadcast);
|
||||
ByteBuffer regionStartResponse =
|
||||
client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataTimeoutMs());
|
||||
if (regionStartResponse.hasRemaining()
|
||||
&& regionStartResponse.get() == StatusCode.HARD_SPLIT.getValue()) {
|
||||
// if split then revive
|
||||
PbChangeLocationResponse response =
|
||||
driverRssMetaService.askSync(
|
||||
ControlMessages.Revive$.MODULE$.apply(
|
||||
applicationId,
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location.getId(),
|
||||
location.getEpoch(),
|
||||
location,
|
||||
StatusCode.HARD_SPLIT),
|
||||
conf.requestPartitionLocationRpcAskTimeout(),
|
||||
ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
|
||||
// per partitionKey only serve single PartitionLocation in Client Cache.
|
||||
StatusCode respStatus = Utils.toStatusCode(response.getStatus());
|
||||
if (StatusCode.SUCCESS.equals(respStatus)) {
|
||||
return Optional.of(PbSerDeUtils.fromPbPartitionLocation(response.getLocation()));
|
||||
} else if (StatusCode.MAP_ENDED.equals(respStatus)) {
|
||||
mapperEndMap
|
||||
.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
|
||||
.add(mapKey);
|
||||
return Optional.empty();
|
||||
} else {
|
||||
// throw exception
|
||||
logger.error(
|
||||
"Exception raised while reviving for shuffle {} map {} attemptId {} partition {} epoch {}.",
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location.getId(),
|
||||
location.getEpoch());
|
||||
throw new CelebornIOException("RegionStart revive failed");
|
||||
}
|
||||
}
|
||||
return Optional.empty();
|
||||
});
|
||||
}
|
||||
|
||||
public void regionFinish(
|
||||
String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location)
|
||||
throws IOException {
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
|
||||
sendMessageInternal(
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location,
|
||||
pushState,
|
||||
() -> {
|
||||
final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
|
||||
logger.info(
|
||||
"RegionFinish for shuffle {} map {} attemptId {} locationId {}.",
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location.getUniqueId());
|
||||
logger.debug("RegionFinish for location {}.", location.toString());
|
||||
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
|
||||
RegionFinish regionFinish =
|
||||
new RegionFinish(MASTER_MODE, shuffleKey, location.getUniqueId(), attemptId);
|
||||
client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataTimeoutMs());
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
private <R> R sendMessageInternal(
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
PartitionLocation location,
|
||||
PushState pushState,
|
||||
ThrowingExceptionSupplier<R, Exception> supplier)
|
||||
throws IOException {
|
||||
int batchId = 0;
|
||||
try {
|
||||
// mapKey
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
// return if shuffle stage already ended
|
||||
if (mapperEnded(shuffleId, mapId, attemptId)) {
|
||||
logger.debug(
|
||||
"Send message to {} ignored because mapper already ended for shuffle {} map {} attempt {}.",
|
||||
location.hostAndPushPort(),
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId);
|
||||
return null;
|
||||
}
|
||||
pushState = getPushState(mapKey);
|
||||
// force data has been send
|
||||
limitZeroInFlight(mapKey, pushState);
|
||||
|
||||
// add inFlight requests
|
||||
batchId = pushState.nextBatchId();
|
||||
pushState.addBatch(batchId, location.hostAndPushPort());
|
||||
return retrySendMessage(supplier);
|
||||
} finally {
|
||||
if (pushState != null) {
|
||||
pushState.removeBatch(batchId, location.hostAndPushPort());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * A supplier whose {@link #get()} may throw a checked exception.
 *
 * @param <T> type of the supplied value
 * @param <X> type of exception that {@link #get()} is allowed to throw
 */
@FunctionalInterface
interface ThrowingExceptionSupplier<T, X extends Exception> {

  /**
   * Produces the value.
   *
   * @return the supplied value
   * @throws X if the value cannot be produced
   */
  T get() throws X;
}
|
||||
|
||||
private <R> R retrySendMessage(ThrowingExceptionSupplier<R, Exception> supplier)
|
||||
throws IOException {
|
||||
|
||||
int retryTimes = 0;
|
||||
boolean isSuccess = false;
|
||||
Exception currentException = null;
|
||||
R result = null;
|
||||
while (!Thread.currentThread().isInterrupted()
|
||||
&& !isSuccess
|
||||
&& retryTimes < conf.networkIoMaxRetries(TransportModuleConstants.PUSH_MODULE)) {
|
||||
logger.debug("RetrySendMessage retry times {}.", retryTimes);
|
||||
try {
|
||||
result = supplier.get();
|
||||
isSuccess = true;
|
||||
} catch (Exception e) {
|
||||
currentException = e;
|
||||
if (e instanceof InterruptedException) {
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
if (shouldRetry(e)) {
|
||||
retryTimes++;
|
||||
Uninterruptibles.sleepUninterruptibly(
|
||||
conf.networkIoRetryWaitMs(TransportModuleConstants.PUSH_MODULE),
|
||||
TimeUnit.MILLISECONDS);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!isSuccess) {
|
||||
if (currentException instanceof IOException) {
|
||||
throw (IOException) currentException;
|
||||
} else {
|
||||
throw new CelebornIOException(currentException.getMessage(), currentException);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private boolean shouldRetry(Throwable e) {
|
||||
boolean isIOException =
|
||||
e instanceof IOException
|
||||
|| e instanceof TimeoutException
|
||||
|| (e.getCause() != null && e.getCause() instanceof TimeoutException)
|
||||
|| (e.getCause() != null && e.getCause() instanceof IOException)
|
||||
|| (e instanceof RuntimeException
|
||||
&& e.getMessage() != null
|
||||
&& e.getMessage().startsWith(IOException.class.getName()));
|
||||
return isIOException;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void cleanup(String applicationId, int shuffleId, int mapId, int attemptId) {
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
super.cleanup(applicationId, shuffleId, mapId, attemptId);
|
||||
if (currentClient != null) {
|
||||
currentClient.remove(mapKey);
|
||||
}
|
||||
}
|
||||
|
||||
public void setDataClientFactory(TransportClientFactory dataClientFactory) {
|
||||
this.dataClientFactory = dataClientFactory;
|
||||
}
|
||||
}
|
||||
|
||||
@ -15,10 +15,9 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.celeborn.client;
|
||||
package org.apache.celeborn.plugin.flink;
|
||||
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.anyLong;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import java.io.IOException;
|
||||
@ -26,27 +25,48 @@ import java.nio.ByteBuffer;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.mockito.Matchers;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import org.apache.celeborn.common.CelebornConf;
|
||||
import org.apache.celeborn.common.network.client.RpcResponseCallback;
|
||||
import org.apache.celeborn.common.protocol.CompressionCodec;
|
||||
import org.apache.celeborn.common.network.client.TransportClient;
|
||||
import org.apache.celeborn.common.network.client.TransportClientFactory;
|
||||
import org.apache.celeborn.common.protocol.PartitionLocation;
|
||||
import org.apache.celeborn.common.protocol.message.StatusCode;
|
||||
import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
|
||||
|
||||
public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
|
||||
public class FlinkShuffleClientImplSuiteJ {
|
||||
static int BufferSize = 64;
|
||||
static byte[] TEST_BUF1 = new byte[BufferSize];
|
||||
protected ChannelFuture mockedFuture = mock(ChannelFuture.class);
|
||||
static CelebornConf conf;
|
||||
static FlinkShuffleClientImpl shuffleClient;
|
||||
protected static final TransportClientFactory clientFactory = mock(TransportClientFactory.class);
|
||||
protected final TransportClient client = mock(TransportClient.class);
|
||||
protected static final PartitionLocation masterLocation =
|
||||
new PartitionLocation(0, 1, "localhost", 1, 1, 1, 1, PartitionLocation.Mode.MASTER);
|
||||
|
||||
@Before
|
||||
public void setup() throws IOException, InterruptedException {
|
||||
conf = setupEnv(CompressionCodec.LZ4);
|
||||
conf = new CelebornConf();
|
||||
shuffleClient =
|
||||
new FlinkShuffleClientImpl("localhost", 1232, conf, null) {
|
||||
@Override
|
||||
public void setupMetaServiceRef(String host, int port) {}
|
||||
};
|
||||
when(clientFactory.createClient(masterLocation.getHost(), masterLocation.getPushPort(), 1))
|
||||
.thenAnswer(t -> client);
|
||||
|
||||
shuffleClient.setDataClientFactory(clientFactory);
|
||||
}
|
||||
|
||||
public ByteBuf createByteBuf() {
|
||||
for (int i = BATCH_HEADER_SIZE; i < BufferSize; i++) {
|
||||
for (int i = 16; i < BufferSize; i++) {
|
||||
TEST_BUF1[i] = 1;
|
||||
}
|
||||
ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
|
||||
@ -57,7 +77,7 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
|
||||
@Test
|
||||
public void testPushDataByteBufSuccess() throws IOException {
|
||||
ByteBuf byteBuf = createByteBuf();
|
||||
when(client.pushData(any(), anyLong(), any()))
|
||||
Mockito.when(client.pushData(Matchers.any(), Matchers.anyLong(), Matchers.any()))
|
||||
.thenAnswer(
|
||||
t -> {
|
||||
RpcResponseCallback rpcResponseCallback =
|
||||
@ -68,22 +88,14 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
|
||||
});
|
||||
|
||||
int pushDataLen =
|
||||
shuffleClient.pushDataToLocation(
|
||||
TEST_APPLICATION_ID,
|
||||
TEST_SHUFFLE_ID,
|
||||
TEST_ATTEMPT_ID,
|
||||
TEST_ATTEMPT_ID,
|
||||
TEST_REDUCRE_ID,
|
||||
byteBuf,
|
||||
masterLocation,
|
||||
() -> {});
|
||||
shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {});
|
||||
Assert.assertEquals(BufferSize, pushDataLen);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPushDataByteBufHardSplit() throws IOException {
|
||||
ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
|
||||
when(client.pushData(any(), anyLong(), any()))
|
||||
Mockito.when(client.pushData(Matchers.any(), Matchers.anyLong(), Matchers.any()))
|
||||
.thenAnswer(
|
||||
t -> {
|
||||
RpcResponseCallback rpcResponseCallback =
|
||||
@ -94,21 +106,14 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
|
||||
return mockedFuture;
|
||||
});
|
||||
int pushDataLen =
|
||||
shuffleClient.pushDataToLocation(
|
||||
TEST_APPLICATION_ID,
|
||||
TEST_SHUFFLE_ID,
|
||||
TEST_ATTEMPT_ID,
|
||||
TEST_ATTEMPT_ID,
|
||||
TEST_REDUCRE_ID,
|
||||
byteBuf,
|
||||
masterLocation,
|
||||
() -> {});
|
||||
shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPushDataByteBufFail() throws IOException {
|
||||
ByteBuf byteBuf = Unpooled.wrappedBuffer(TEST_BUF1);
|
||||
when(client.pushData(any(), anyLong(), any(), any()))
|
||||
Mockito.when(
|
||||
client.pushData(Matchers.any(), Matchers.anyLong(), Matchers.any(), Matchers.any()))
|
||||
.thenAnswer(
|
||||
t -> {
|
||||
RpcResponseCallback rpcResponseCallback =
|
||||
@ -117,28 +122,12 @@ public class ShuffleClientImplSuiteJ extends ShuffleClientBaseSuiteJ {
|
||||
return mockedFuture;
|
||||
});
|
||||
// first push just set pushdata.exception
|
||||
shuffleClient.pushDataToLocation(
|
||||
TEST_APPLICATION_ID,
|
||||
TEST_SHUFFLE_ID,
|
||||
TEST_ATTEMPT_ID,
|
||||
TEST_ATTEMPT_ID,
|
||||
TEST_REDUCRE_ID,
|
||||
byteBuf,
|
||||
masterLocation,
|
||||
() -> {});
|
||||
shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {});
|
||||
|
||||
boolean isFailed = false;
|
||||
// second push will throw exception
|
||||
try {
|
||||
shuffleClient.pushDataToLocation(
|
||||
TEST_APPLICATION_ID,
|
||||
TEST_SHUFFLE_ID,
|
||||
TEST_ATTEMPT_ID,
|
||||
TEST_ATTEMPT_ID,
|
||||
TEST_REDUCRE_ID,
|
||||
byteBuf,
|
||||
masterLocation,
|
||||
() -> {});
|
||||
shuffleClient.pushDataToLocation("1", 2, 3, 4, 5, byteBuf, masterLocation, () -> {});
|
||||
} catch (IOException e) {
|
||||
isFailed = true;
|
||||
} finally {
|
||||
@ -29,11 +29,11 @@ import org.apache.flink.util.function.SupplierWithException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClient;
|
||||
import org.apache.celeborn.common.CelebornConf;
|
||||
import org.apache.celeborn.common.identity.UserIdentifier;
|
||||
import org.apache.celeborn.common.protocol.PartitionLocation;
|
||||
import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
|
||||
import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
|
||||
import org.apache.celeborn.plugin.flink.utils.BufferUtils;
|
||||
import org.apache.celeborn.plugin.flink.utils.Utils;
|
||||
|
||||
@ -58,7 +58,7 @@ public class RemoteShuffleOutputGate {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleOutputGate.class);
|
||||
private final RemoteShuffleDescriptor shuffleDesc;
|
||||
protected final int numSubs;
|
||||
protected ShuffleClient shuffleWriteClient;
|
||||
protected FlinkShuffleClientImpl shuffleWriteClient;
|
||||
protected final SupplierWithException<BufferPool, IOException> bufferPoolFactory;
|
||||
protected BufferPool bufferPool;
|
||||
private CelebornConf celebornConf;
|
||||
@ -212,8 +212,9 @@ public class RemoteShuffleOutputGate {
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
ShuffleClient createWriteClient() {
|
||||
return ShuffleClient.get(rssMetaServiceHost, rssMetaServicePort, celebornConf, userIdentifier);
|
||||
FlinkShuffleClientImpl createWriteClient() {
|
||||
return FlinkShuffleClientImpl.get(
|
||||
rssMetaServiceHost, rssMetaServicePort, celebornConf, userIdentifier);
|
||||
}
|
||||
|
||||
/** Writes a piece of data to a subpartition. */
|
||||
|
||||
@ -36,12 +36,12 @@ import org.junit.Assert;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.apache.celeborn.client.ShuffleClientImpl;
|
||||
import org.apache.celeborn.common.protocol.PartitionLocation;
|
||||
import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
|
||||
|
||||
public class RemoteShuffleOutputGateSuiteJ {
|
||||
private RemoteShuffleOutputGate remoteShuffleOutputGate = mock(RemoteShuffleOutputGate.class);
|
||||
private ShuffleClientImpl shuffleClient = mock(ShuffleClientImpl.class);
|
||||
private FlinkShuffleClientImpl shuffleClient = mock(FlinkShuffleClientImpl.class);
|
||||
private static final int BUFFER_SIZE = 20;
|
||||
private NetworkBufferPool networkBufferPool;
|
||||
private BufferPool bufferPool;
|
||||
|
||||
@ -18,10 +18,8 @@
|
||||
package org.apache.celeborn.client;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.slf4j.Logger;
|
||||
@ -196,44 +194,6 @@ public abstract class ShuffleClient {
|
||||
|
||||
public abstract void shutdown();
|
||||
|
||||
// Write data to a specific map partition, input data's type is Bytebuf.
|
||||
// data's type is Bytebuf to avoid copy between application and netty
|
||||
// closecallback will do some clean operations like memory release.
|
||||
public abstract int pushDataToLocation(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
int partitionId,
|
||||
ByteBuf data,
|
||||
PartitionLocation location,
|
||||
Runnable closeCallBack)
|
||||
throws IOException;
|
||||
|
||||
public abstract Optional<PartitionLocation> regionStart(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
PartitionLocation location,
|
||||
int currentRegionIdx,
|
||||
boolean isBroadcast)
|
||||
throws IOException;
|
||||
|
||||
public abstract void regionFinish(
|
||||
String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location)
|
||||
throws IOException;
|
||||
|
||||
public abstract void pushDataHandShake(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
int numPartitions,
|
||||
int bufferSize,
|
||||
PartitionLocation location)
|
||||
throws IOException;
|
||||
|
||||
public abstract PartitionLocation registerMapPartitionTask(
|
||||
String appId, int shuffleId, int numMappers, int mapId, int attemptId) throws IOException;
|
||||
|
||||
|
||||
@ -24,13 +24,10 @@ import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
|
||||
import scala.reflect.ClassTag$;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.util.concurrent.Uninterruptibles;
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import io.netty.buffer.CompositeByteBuf;
|
||||
import io.netty.buffer.Unpooled;
|
||||
import org.slf4j.Logger;
|
||||
@ -48,14 +45,10 @@ import org.apache.celeborn.common.network.client.RpcResponseCallback;
|
||||
import org.apache.celeborn.common.network.client.TransportClient;
|
||||
import org.apache.celeborn.common.network.client.TransportClientFactory;
|
||||
import org.apache.celeborn.common.network.protocol.PushData;
|
||||
import org.apache.celeborn.common.network.protocol.PushDataHandShake;
|
||||
import org.apache.celeborn.common.network.protocol.PushMergedData;
|
||||
import org.apache.celeborn.common.network.protocol.RegionFinish;
|
||||
import org.apache.celeborn.common.network.protocol.RegionStart;
|
||||
import org.apache.celeborn.common.network.server.BaseMessageHandler;
|
||||
import org.apache.celeborn.common.network.util.TransportConf;
|
||||
import org.apache.celeborn.common.protocol.*;
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages;
|
||||
import org.apache.celeborn.common.protocol.message.ControlMessages.*;
|
||||
import org.apache.celeborn.common.protocol.message.StatusCode;
|
||||
import org.apache.celeborn.common.rpc.RpcAddress;
|
||||
@ -72,7 +65,7 @@ import org.apache.celeborn.common.write.PushState;
|
||||
public class ShuffleClientImpl extends ShuffleClient {
|
||||
private static final Logger logger = LoggerFactory.getLogger(ShuffleClientImpl.class);
|
||||
|
||||
private static final byte MASTER_MODE = PartitionLocation.Mode.MASTER.mode();
|
||||
protected static final byte MASTER_MODE = PartitionLocation.Mode.MASTER.mode();
|
||||
|
||||
private static final Random RND = new Random();
|
||||
|
||||
@ -85,24 +78,24 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
private int maxReviveTimes;
|
||||
private boolean testRetryRevive;
|
||||
private final int pushBufferMaxSize;
|
||||
private final long pushDataTimeout;
|
||||
protected final long pushDataTimeout;
|
||||
|
||||
private final RpcEnv rpcEnv;
|
||||
|
||||
private RpcEndpointRef driverRssMetaService;
|
||||
protected RpcEndpointRef driverRssMetaService;
|
||||
|
||||
protected TransportClientFactory dataClientFactory;
|
||||
|
||||
final int BATCH_HEADER_SIZE = 4 * 4;
|
||||
protected final int BATCH_HEADER_SIZE = 4 * 4;
|
||||
|
||||
// key: shuffleId, value: (partitionId, PartitionLocation)
|
||||
private final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>> reducePartitionMap =
|
||||
new ConcurrentHashMap<>();
|
||||
|
||||
private final ConcurrentHashMap<Integer, Set<String>> mapperEndMap = new ConcurrentHashMap<>();
|
||||
protected final ConcurrentHashMap<Integer, Set<String>> mapperEndMap = new ConcurrentHashMap<>();
|
||||
|
||||
// key: shuffleId-mapId-attemptId
|
||||
private final Map<String, PushState> pushStates = new ConcurrentHashMap<>();
|
||||
protected final Map<String, PushState> pushStates = new ConcurrentHashMap<>();
|
||||
|
||||
private final ExecutorService pushDataRetryPool;
|
||||
|
||||
@ -141,8 +134,6 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
// key: shuffleId
|
||||
protected final Map<Integer, ReduceFileGroups> reduceFileGroupsMap = new ConcurrentHashMap<>();
|
||||
|
||||
private ConcurrentHashMap<String, TransportClient> currentClient = new ConcurrentHashMap<>();
|
||||
|
||||
public ShuffleClientImpl(CelebornConf conf, UserIdentifier userIdentifier) {
|
||||
super();
|
||||
this.conf = conf;
|
||||
@ -460,7 +451,7 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
return null;
|
||||
}
|
||||
|
||||
private void limitMaxInFlight(String mapKey, PushState pushState, String hostAndPushPort)
|
||||
protected void limitMaxInFlight(String mapKey, PushState pushState, String hostAndPushPort)
|
||||
throws IOException {
|
||||
boolean reachLimit = pushState.limitMaxInFlight(hostAndPushPort);
|
||||
|
||||
@ -470,7 +461,7 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
}
|
||||
}
|
||||
|
||||
private void limitZeroInFlight(String mapKey, PushState pushState) throws IOException {
|
||||
protected void limitZeroInFlight(String mapKey, PushState pushState) throws IOException {
|
||||
boolean reachLimit = pushState.limitZeroInFlight();
|
||||
|
||||
if (reachLimit) {
|
||||
@ -1296,9 +1287,6 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
pushState.exception.compareAndSet(null, new CelebornIOException("Cleaned Up"));
|
||||
pushState.cleanup();
|
||||
}
|
||||
if (currentClient != null) {
|
||||
currentClient.remove(mapKey);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -1462,7 +1450,7 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
driverRssMetaService = endpointRef;
|
||||
}
|
||||
|
||||
private boolean mapperEnded(int shuffleId, int mapId, int attemptId) {
|
||||
protected boolean mapperEnded(int shuffleId, int mapId, int attemptId) {
|
||||
return mapperEndMap.containsKey(shuffleId)
|
||||
&& mapperEndMap.get(shuffleId).contains(Utils.makeMapKey(shuffleId, mapId, attemptId));
|
||||
}
|
||||
@ -1502,387 +1490,4 @@ public class ShuffleClientImpl extends ShuffleClient {
|
||||
|| (message.equals("Connection reset by peer"))
|
||||
|| (message.startsWith("Failed to send RPC "));
|
||||
}
|
||||
|
||||
public int pushDataToLocation(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
int partitionId,
|
||||
ByteBuf data,
|
||||
PartitionLocation location,
|
||||
Runnable closeCallBack)
|
||||
throws IOException {
|
||||
// mapKey
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
|
||||
// return if shuffle stage already ended
|
||||
if (mapperEnded(shuffleId, mapId, attemptId)) {
|
||||
logger.info(
|
||||
"Push data byteBuf to location {} ignored because mapper already ended for shuffle {} map {} attempt {}.",
|
||||
location.hostAndPushPort(),
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId);
|
||||
PushState pushState = pushStates.get(mapKey);
|
||||
if (pushState != null) {
|
||||
pushState.cleanup();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
PushState pushState = getPushState(mapKey);
|
||||
|
||||
// increment batchId
|
||||
final int nextBatchId = pushState.nextBatchId();
|
||||
int totalLength = data.readableBytes();
|
||||
data.markWriterIndex();
|
||||
data.writerIndex(0);
|
||||
data.writeInt(partitionId);
|
||||
data.writeInt(attemptId);
|
||||
data.writeInt(nextBatchId);
|
||||
data.writeInt(totalLength - BATCH_HEADER_SIZE);
|
||||
data.resetWriterIndex();
|
||||
logger.debug(
|
||||
"Do push data byteBuf size {} for app {} shuffle {} map {} attempt {} reduce {} batch {}.",
|
||||
totalLength,
|
||||
applicationId,
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
partitionId,
|
||||
nextBatchId);
|
||||
// check limit
|
||||
limitMaxInFlight(mapKey, pushState, location.hostAndPushPort());
|
||||
|
||||
// add inFlight requests
|
||||
pushState.addBatch(nextBatchId, location.hostAndPushPort());
|
||||
|
||||
// build PushData request
|
||||
NettyManagedBuffer buffer = new NettyManagedBuffer(data);
|
||||
PushData pushData = new PushData(MASTER_MODE, shuffleKey, location.getUniqueId(), buffer);
|
||||
|
||||
// build callback
|
||||
RpcResponseCallback callback =
|
||||
new RpcResponseCallback() {
|
||||
@Override
|
||||
public void onSuccess(ByteBuffer response) {
|
||||
pushState.removeBatch(nextBatchId, location.hostAndPushPort());
|
||||
if (response.remaining() > 0) {
|
||||
byte reason = response.get();
|
||||
if (reason == StatusCode.STAGE_ENDED.getValue()) {
|
||||
mapperEndMap
|
||||
.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
|
||||
.add(mapKey);
|
||||
}
|
||||
}
|
||||
logger.debug(
|
||||
"Push data byteBuf to {} success for shuffle {} map {} attemptId {} batch {}.",
|
||||
location.hostAndPushPort(),
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
nextBatchId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Throwable e) {
|
||||
pushState.removeBatch(nextBatchId, location.hostAndPushPort());
|
||||
if (pushState.exception.get() != null) {
|
||||
return;
|
||||
}
|
||||
if (!mapperEnded(shuffleId, mapId, attemptId)) {
|
||||
String errorMsg =
|
||||
String.format(
|
||||
"Push data byteBuf to %s failed for shuffle %d map %d attempt %d batch %d.",
|
||||
location.hostAndPushPort(), shuffleId, mapId, attemptId, nextBatchId);
|
||||
pushState.exception.compareAndSet(null, new CelebornIOException(errorMsg, e));
|
||||
} else {
|
||||
logger.warn(
|
||||
"Push data to {} failed but mapper already ended for shuffle {} map {} attempt {} batch {}.",
|
||||
location.hostAndPushPort(),
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
nextBatchId);
|
||||
}
|
||||
}
|
||||
};
|
||||
// do push data
|
||||
try {
|
||||
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
|
||||
client.pushData(pushData, pushDataTimeout, callback, closeCallBack);
|
||||
} catch (Exception e) {
|
||||
logger.error(
|
||||
"Exception raised while pushing data byteBuf for shuffle {} map {} attempt {} partitionId {} batch {} location {}.",
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
partitionId,
|
||||
nextBatchId,
|
||||
location,
|
||||
e);
|
||||
callback.onFailure(
|
||||
new CelebornIOException(StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_MASTER, e));
|
||||
}
|
||||
return totalLength;
|
||||
}
|
||||
|
||||
private TransportClient createClientWaitingInFlightRequest(
|
||||
PartitionLocation location, String mapKey, PushState pushState)
|
||||
throws IOException, InterruptedException {
|
||||
TransportClient client =
|
||||
dataClientFactory.createClient(
|
||||
location.getHost(), location.getPushPort(), location.getId());
|
||||
if (currentClient.get(mapKey) != client) {
|
||||
// makesure that messages have been sent by old client, in order to keep receiving data
|
||||
// orderly
|
||||
if (currentClient.get(mapKey) != null) {
|
||||
limitZeroInFlight(mapKey, pushState);
|
||||
}
|
||||
currentClient.put(mapKey, client);
|
||||
}
|
||||
return currentClient.get(mapKey);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void pushDataHandShake(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
int numPartitions,
|
||||
int bufferSize,
|
||||
PartitionLocation location)
|
||||
throws IOException {
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
|
||||
sendMessageInternal(
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location,
|
||||
pushState,
|
||||
() -> {
|
||||
String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
|
||||
logger.info(
|
||||
"PushDataHandShake shuffleKey {} attemptId {} locationId {}",
|
||||
shuffleKey,
|
||||
attemptId,
|
||||
location.getUniqueId());
|
||||
logger.debug("PushDataHandShake location {}", location.toString());
|
||||
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
|
||||
PushDataHandShake handShake =
|
||||
new PushDataHandShake(
|
||||
MASTER_MODE,
|
||||
shuffleKey,
|
||||
location.getUniqueId(),
|
||||
attemptId,
|
||||
numPartitions,
|
||||
bufferSize);
|
||||
client.sendRpcSync(handShake.toByteBuffer(), conf.pushDataTimeoutMs());
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<PartitionLocation> regionStart(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
PartitionLocation location,
|
||||
int currentRegionIdx,
|
||||
boolean isBroadcast)
|
||||
throws IOException {
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
|
||||
return sendMessageInternal(
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location,
|
||||
pushState,
|
||||
() -> {
|
||||
String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
|
||||
logger.info(
|
||||
"RegionStart for shuffle {} regionId {} attemptId {} locationId {}.",
|
||||
shuffleId,
|
||||
currentRegionIdx,
|
||||
attemptId,
|
||||
location.getUniqueId());
|
||||
logger.debug("RegionStart for location {}.", location.toString());
|
||||
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
|
||||
RegionStart regionStart =
|
||||
new RegionStart(
|
||||
MASTER_MODE,
|
||||
shuffleKey,
|
||||
location.getUniqueId(),
|
||||
attemptId,
|
||||
currentRegionIdx,
|
||||
isBroadcast);
|
||||
ByteBuffer regionStartResponse =
|
||||
client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataTimeoutMs());
|
||||
if (regionStartResponse.hasRemaining()
|
||||
&& regionStartResponse.get() == StatusCode.HARD_SPLIT.getValue()) {
|
||||
// if split then revive
|
||||
PbChangeLocationResponse response =
|
||||
driverRssMetaService.askSync(
|
||||
ControlMessages.Revive$.MODULE$.apply(
|
||||
applicationId,
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location.getId(),
|
||||
location.getEpoch(),
|
||||
location,
|
||||
StatusCode.HARD_SPLIT),
|
||||
conf.requestPartitionLocationRpcAskTimeout(),
|
||||
ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
|
||||
// per partitionKey only serve single PartitionLocation in Client Cache.
|
||||
StatusCode respStatus = Utils.toStatusCode(response.getStatus());
|
||||
if (StatusCode.SUCCESS.equals(respStatus)) {
|
||||
return Optional.of(PbSerDeUtils.fromPbPartitionLocation(response.getLocation()));
|
||||
} else if (StatusCode.MAP_ENDED.equals(respStatus)) {
|
||||
mapperEndMap
|
||||
.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
|
||||
.add(mapKey);
|
||||
return Optional.empty();
|
||||
} else {
|
||||
// throw exception
|
||||
logger.error(
|
||||
"Exception raised while reviving for shuffle {} map {} attemptId {} partition {} epoch {}.",
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location.getId(),
|
||||
location.getEpoch());
|
||||
throw new CelebornIOException("RegionStart revive failed");
|
||||
}
|
||||
}
|
||||
return Optional.empty();
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void regionFinish(
|
||||
String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location)
|
||||
throws IOException {
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
final PushState pushState = pushStates.computeIfAbsent(mapKey, (s) -> new PushState(conf));
|
||||
sendMessageInternal(
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location,
|
||||
pushState,
|
||||
() -> {
|
||||
final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
|
||||
logger.info(
|
||||
"RegionFinish for shuffle {} map {} attemptId {} locationId {}.",
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId,
|
||||
location.getUniqueId());
|
||||
logger.debug("RegionFinish for location {}.", location.toString());
|
||||
TransportClient client = createClientWaitingInFlightRequest(location, mapKey, pushState);
|
||||
RegionFinish regionFinish =
|
||||
new RegionFinish(MASTER_MODE, shuffleKey, location.getUniqueId(), attemptId);
|
||||
client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataTimeoutMs());
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
private <R> R sendMessageInternal(
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
PartitionLocation location,
|
||||
PushState pushState,
|
||||
ThrowingExceptionSupplier<R, Exception> supplier)
|
||||
throws IOException {
|
||||
int batchId = 0;
|
||||
try {
|
||||
// mapKey
|
||||
final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
|
||||
// return if shuffle stage already ended
|
||||
if (mapperEnded(shuffleId, mapId, attemptId)) {
|
||||
logger.debug(
|
||||
"Send message to {} ignored because mapper already ended for shuffle {} map {} attempt {}.",
|
||||
location.hostAndPushPort(),
|
||||
shuffleId,
|
||||
mapId,
|
||||
attemptId);
|
||||
return null;
|
||||
}
|
||||
pushState = getPushState(mapKey);
|
||||
// force data has been send
|
||||
limitZeroInFlight(mapKey, pushState);
|
||||
|
||||
// add inFlight requests
|
||||
batchId = pushState.nextBatchId();
|
||||
pushState.addBatch(batchId, location.hostAndPushPort());
|
||||
return retrySendMessage(supplier);
|
||||
} finally {
|
||||
if (pushState != null) {
|
||||
pushState.removeBatch(batchId, location.hostAndPushPort());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@FunctionalInterface
|
||||
interface ThrowingExceptionSupplier<R, E extends Exception> {
|
||||
R get() throws E;
|
||||
}
|
||||
|
||||
private <R> R retrySendMessage(ThrowingExceptionSupplier<R, Exception> supplier)
|
||||
throws IOException {
|
||||
|
||||
int retryTimes = 0;
|
||||
boolean isSuccess = false;
|
||||
Exception currentException = null;
|
||||
R result = null;
|
||||
while (!Thread.currentThread().isInterrupted()
|
||||
&& !isSuccess
|
||||
&& retryTimes < conf.networkIoMaxRetries(TransportModuleConstants.PUSH_MODULE)) {
|
||||
logger.debug("RetrySendMessage retry times {}.", retryTimes);
|
||||
try {
|
||||
result = supplier.get();
|
||||
isSuccess = true;
|
||||
} catch (Exception e) {
|
||||
currentException = e;
|
||||
if (e instanceof InterruptedException) {
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
if (shouldRetry(e)) {
|
||||
retryTimes++;
|
||||
Uninterruptibles.sleepUninterruptibly(
|
||||
conf.networkIoRetryWaitMs(TransportModuleConstants.PUSH_MODULE),
|
||||
TimeUnit.MILLISECONDS);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!isSuccess) {
|
||||
if (currentException instanceof IOException) {
|
||||
throw (IOException) currentException;
|
||||
} else {
|
||||
throw new CelebornIOException(currentException.getMessage(), currentException);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private boolean shouldRetry(Throwable e) {
|
||||
boolean isIOException =
|
||||
e instanceof IOException
|
||||
|| e instanceof TimeoutException
|
||||
|| (e.getCause() != null && e.getCause() instanceof TimeoutException)
|
||||
|| (e.getCause() != null && e.getCause() instanceof IOException)
|
||||
|| (e instanceof RuntimeException
|
||||
&& e.getMessage() != null
|
||||
&& e.getMessage().startsWith(IOException.class.getName()));
|
||||
return isIOException;
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,10 +28,8 @@ import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import io.netty.buffer.ByteBuf;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@ -150,48 +148,6 @@ public class DummyShuffleClient extends ShuffleClient {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int pushDataToLocation(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
int partitionId,
|
||||
ByteBuf data,
|
||||
PartitionLocation location,
|
||||
Runnable closeCallBack) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Optional<PartitionLocation> regionStart(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
PartitionLocation location,
|
||||
int currentRegionIdx,
|
||||
boolean isBroadcast)
|
||||
throws IOException {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void regionFinish(
|
||||
String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location)
|
||||
throws IOException {}
|
||||
|
||||
@Override
|
||||
public void pushDataHandShake(
|
||||
String applicationId,
|
||||
int shuffleId,
|
||||
int mapId,
|
||||
int attemptId,
|
||||
int numPartitions,
|
||||
int bufferSize,
|
||||
PartitionLocation location)
|
||||
throws IOException {}
|
||||
|
||||
@Override
|
||||
public PartitionLocation registerMapPartitionTask(
|
||||
String appId, int shuffleId, int numMappers, int mapId, int attemptId) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user