From d68deecaaa3a098cbdff806b0ddb522072347ebc Mon Sep 17 00:00:00 2001 From: Shuang Date: Sat, 22 Apr 2023 16:33:22 +0800 Subject: [PATCH] [CELEBORN-546][FLINK] Use autoIncrement partitionId replace encode(mapId, attemptId) for generating partitionId (#1447) --- .../plugin/flink/RemoteShuffleDescriptor.java | 4 ++ .../plugin/flink/RemoteShuffleMaster.java | 20 +++---- .../plugin/flink/RemoteShuffleOutputGate.java | 5 +- .../flink/ShuffleResourceDescriptor.java | 12 ++-- .../celeborn/plugin/flink/ShuffleTask.java | 42 -------------- .../plugin/flink/ShuffleTaskInfo.java | 47 ++++++--------- .../plugin/flink/RemoteShuffleMasterTest.java | 3 - .../flink/RemoteShuffleOutputGateSuiteJ.java | 3 +- .../RemoteShuffleResultPartitionSuiteJ.java | 7 +-- .../plugin/flink/ShuffleTaskInfoSuitJ.java | 14 ++--- .../apache/celeborn/client/ShuffleClient.java | 3 +- .../celeborn/client/ShuffleClientImpl.java | 4 +- .../celeborn/client/DummyShuffleClient.java | 2 +- .../client/WithShuffleClientSuite.scala | 45 +++++++------- .../common/protocol/PartitionLocation.java | 15 +---- .../common/util/PackedPartitionId.java | 58 ------------------- .../protocol/PartitionLocationSuiteJ.java | 18 ++---- .../common/util/PackedPartitionIdSuiteJ.java | 53 ----------------- .../tests/client/ShuffleClientSuite.scala | 2 +- .../deploy/worker/PushDataHandler.scala | 30 +++------- 20 files changed, 95 insertions(+), 292 deletions(-) delete mode 100644 client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/ShuffleTask.java delete mode 100644 common/src/main/java/org/apache/celeborn/common/util/PackedPartitionId.java delete mode 100644 common/src/test/java/org/apache/celeborn/common/util/PackedPartitionIdSuiteJ.java diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleDescriptor.java b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleDescriptor.java index f336a9446..ee0f28cb9 100644 --- a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleDescriptor.java +++ b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleDescriptor.java @@ -58,6 +58,10 @@ public class RemoteShuffleDescriptor implements ShuffleDescriptor { return jobId; } + public String getShuffleId() { + return shuffleId; + } + public RemoteShuffleResource getShuffleResource() { return shuffleResource; } diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java index 959a6b647..1044402f1 100644 --- a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java +++ b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleMaster.java @@ -127,18 +127,16 @@ public class RemoteShuffleMaster implements ShuffleMaster mapId_taskAttemptId -> attemptIdx - private ConcurrentHashMap> - taskShuffleAttemptIdToAttemptId = JavaUtils.newConcurrentHashMap(); // map attemptId index - private ConcurrentHashMap> taskShuffleAttemptIdIndex = - JavaUtils.newConcurrentHashMap(); + private ConcurrentHashMap> + shuffleIdMapAttemptIdIndex = JavaUtils.newConcurrentHashMap(); // task shuffle id -> celeborn shuffle id private ConcurrentHashMap taskShuffleIdToShuffleId = JavaUtils.newConcurrentHashMap(); @@ -36,6 +34,9 @@ public class ShuffleTaskInfo { private ConcurrentHashMap shuffleIdToTaskShuffleId = JavaUtils.newConcurrentHashMap(); + private ConcurrentHashMap shuffleIdPartitionIdIndex = + JavaUtils.newConcurrentHashMap(); + public int getShuffleId(String taskShuffleId) { synchronized (taskShuffleIdToShuffleId) { if (taskShuffleIdToShuffleId.containsKey(taskShuffleId)) { @@ -43,6 +44,8 @@ public class ShuffleTaskInfo { } else { taskShuffleIdToShuffleId.put(taskShuffleId, currentShuffleIndex); shuffleIdToTaskShuffleId.put(currentShuffleIndex, taskShuffleId); + shuffleIdMapAttemptIdIndex.put(currentShuffleIndex, JavaUtils.newConcurrentHashMap()); + shuffleIdPartitionIdIndex.put(currentShuffleIndex, new AtomicInteger(0)); int tempShuffleIndex = currentShuffleIndex; currentShuffleIndex = currentShuffleIndex + 1; return tempShuffleIndex; @@ -50,36 +53,24 @@ public class ShuffleTaskInfo { } } - public int getAttemptId(String taskShuffleId, int mapId, String attemptId) { - ConcurrentHashMap attemptIndex = - taskShuffleAttemptIdIndex.computeIfAbsent( - taskShuffleId, (id) -> JavaUtils.newConcurrentHashMap()); - ConcurrentHashMap attemptIdMap = - taskShuffleAttemptIdToAttemptId.computeIfAbsent( - taskShuffleId, (id) -> JavaUtils.newConcurrentHashMap()); - String mapAttemptId = mapId + "_" + attemptId; - synchronized (attemptIndex) { - if (!attemptIdMap.containsKey(mapAttemptId)) { - if (attemptIndex.containsKey(mapId)) { - int index = attemptIndex.get(mapId); - attemptIdMap.put(mapAttemptId, index + 1); - attemptIndex.put(mapId, index + 1); - } else { - attemptIdMap.put(mapAttemptId, 0); - attemptIndex.put(mapId, 0); - } - } - } + public int genAttemptId(int shuffleId, int mapId) { + AtomicInteger currentAttemptIndex = + shuffleIdMapAttemptIdIndex + .get(shuffleId) + .computeIfAbsent(mapId, (id) -> new AtomicInteger(0)); + return currentAttemptIndex.getAndIncrement(); + } - return attemptIdMap.get(mapAttemptId); + public int genPartitionId(int shuffleId) { + return shuffleIdPartitionIdIndex.get(shuffleId).getAndIncrement(); } public void removeExpiredShuffle(int shuffleId) { if (shuffleIdToTaskShuffleId.containsKey(shuffleId)) { + shuffleIdPartitionIdIndex.remove(shuffleId); + shuffleIdMapAttemptIdIndex.remove(shuffleId); String taskShuffleId = shuffleIdToTaskShuffleId.remove(shuffleId); taskShuffleIdToShuffleId.remove(taskShuffleId); - taskShuffleAttemptIdIndex.remove(taskShuffleId); - taskShuffleAttemptIdToAttemptId.remove(taskShuffleId); } } } diff --git a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterTest.java b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterTest.java index ce1b555c6..dc98b9d07 100644 --- a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterTest.java +++ b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleMasterTest.java @@ -47,7 +47,6 @@ import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.celeborn.common.util.PackedPartitionId; import org.apache.celeborn.plugin.flink.config.PluginConf; import org.apache.celeborn.plugin.flink.utils.FlinkUtils; @@ -124,8 +123,6 @@ public class RemoteShuffleMasterTest { mapPartitionShuffleDescriptor = remoteShuffleDescriptor.getShuffleResource().getMapPartitionShuffleDescriptor(); Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId()); - Assert.assertEquals( - PackedPartitionId.packedPartitionId(1, 1), mapPartitionShuffleDescriptor.getPartitionId()); Assert.assertEquals(1, mapPartitionShuffleDescriptor.getAttemptId()); Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId()); } diff --git a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java index 1b72b9c09..d9dc8d0f5 100644 --- a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java +++ b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGateSuiteJ.java @@ -58,7 +58,8 @@ public class RemoteShuffleOutputGateSuiteJ { PartitionLocation partitionLocation = new PartitionLocation(1, 0, "localhost", 123, 245, 789, 238, PartitionLocation.Mode.MASTER); - when(shuffleClient.registerMapPartitionTask(any(), anyInt(), anyInt(), anyInt(), anyInt())) + when(shuffleClient.registerMapPartitionTask( + any(), anyInt(), anyInt(), anyInt(), anyInt(), anyInt())) .thenAnswer(t -> partitionLocation); doNothing() .when(remoteShuffleOutputGate.flinkShuffleClient) diff --git a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java index e2fe8f9b3..45f8cf5b2 100644 --- a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java +++ b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java @@ -60,7 +60,6 @@ import org.apache.flink.util.function.SupplierWithException; import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.mockito.Mockito; import org.apache.celeborn.common.CelebornConf; import org.apache.celeborn.plugin.flink.buffer.BufferPacker; @@ -557,16 +556,12 @@ public class RemoteShuffleResultPartitionSuiteJ { Random random = new Random(); byte[] bytes = new byte[16]; random.nextBytes(bytes); - ShuffleTask shuffleTask = Mockito.mock(ShuffleTask.class); - Mockito.when(shuffleTask.getAttemptId()).thenReturn(1); - Mockito.when(shuffleTask.getMapId()).thenReturn(1); - Mockito.when(shuffleTask.getShuffleId()).thenReturn(1); return new RemoteShuffleDescriptor( new JobID(bytes).toString(), new JobID(bytes), new JobID(bytes).toString(), new ResultPartitionID(), - new RemoteShuffleResource("1", 2, new ShuffleResourceDescriptor(shuffleTask))); + new RemoteShuffleResource("1", 2, new ShuffleResourceDescriptor(1, 1, 1, 0))); } /** Data written and its {@link Buffer.DataType}. */ diff --git a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/ShuffleTaskInfoSuitJ.java b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/ShuffleTaskInfoSuitJ.java index 7acaa6405..65e747e42 100644 --- a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/ShuffleTaskInfoSuitJ.java +++ b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/ShuffleTaskInfoSuitJ.java @@ -34,21 +34,21 @@ public class ShuffleTaskInfoSuitJ { int encodeShuffleId0 = shuffleTaskInfo.getShuffleId("shuffleId"); Assert.assertEquals(encodeShuffleId0, 0); - int encodeAttemptId011 = shuffleTaskInfo.getAttemptId("shuffleId", 1, "attempt1"); - int encodeAttemptId112 = shuffleTaskInfo.getAttemptId("shuffleId1", 1, "attempt2"); - int encodeAttemptId021 = shuffleTaskInfo.getAttemptId("shuffleId", 2, "attempt1"); - int encodeAttemptId012 = shuffleTaskInfo.getAttemptId("shuffleId", 1, "attempt2"); + int encodeAttemptId011 = shuffleTaskInfo.genAttemptId(encodeShuffleId1, 1); + int encodeAttemptId112 = shuffleTaskInfo.genAttemptId(encodeShuffleId1, 1); + int encodeAttemptId021 = shuffleTaskInfo.genAttemptId(encodeShuffleId0, 2); + int encodeAttemptId012 = shuffleTaskInfo.genAttemptId(encodeShuffleId0, 1); Assert.assertEquals(encodeAttemptId011, 0); - Assert.assertEquals(encodeAttemptId112, 0); + Assert.assertEquals(encodeAttemptId112, 1); Assert.assertEquals(encodeAttemptId021, 0); - Assert.assertEquals(encodeAttemptId012, 1); + Assert.assertEquals(encodeAttemptId012, 0); // remove shuffleId and reEncode shuffleTaskInfo.removeExpiredShuffle(encodeShuffleId); int encodeShuffleIdNew = shuffleTaskInfo.getShuffleId("shuffleId"); Assert.assertEquals(encodeShuffleIdNew, 2); - int encodeAttemptId211 = shuffleTaskInfo.getAttemptId("shuffleId", 1, "attempt1"); + int encodeAttemptId211 = shuffleTaskInfo.genAttemptId(encodeShuffleIdNew, 1); Assert.assertEquals(encodeAttemptId211, 0); } diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java index 51d52f078..2b5d07c5b 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java @@ -196,7 +196,8 @@ public abstract class ShuffleClient { public abstract void shutdown(); public abstract PartitionLocation registerMapPartitionTask( - String appId, int shuffleId, int numMappers, int mapId, int attemptId) throws IOException; + String appId, int shuffleId, int numMappers, int mapId, int attemptId, int partitionId) + throws IOException; public abstract ConcurrentHashMap getPartitionLocation( String applicationId, int shuffleId, int numMappers, int numPartitions); diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 06035575b..f11ffea5f 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -383,8 +383,8 @@ public class ShuffleClientImpl extends ShuffleClient { @VisibleForTesting public PartitionLocation registerMapPartitionTask( - String appId, int shuffleId, int numMappers, int mapId, int attemptId) throws IOException { - int partitionId = PackedPartitionId.packedPartitionId(mapId, attemptId); + String appId, int shuffleId, int numMappers, int mapId, int attemptId, int partitionId) + throws IOException { logger.info( "Register MapPartition task for shuffle {} map {} attempt {} partition {} with {} mapper.", shuffleId, diff --git a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java index 9c88fd683..896ef1a26 100644 --- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java +++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java @@ -151,7 +151,7 @@ public class DummyShuffleClient extends ShuffleClient { @Override public PartitionLocation registerMapPartitionTask( - String appId, int shuffleId, int numMappers, int mapId, int attemptId) { + String appId, int shuffleId, int numMappers, int mapId, int attemptId, int partitionId) { return null; } diff --git a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala index b4e1a9600..9018c3889 100644 --- a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala +++ b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala @@ -28,7 +28,6 @@ import org.apache.celeborn.CelebornFunSuite import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.util.JavaUtils.timeOutOrMeetCondition -import org.apache.celeborn.common.util.PackedPartitionId trait WithShuffleClientSuite extends CelebornFunSuite { @@ -58,12 +57,13 @@ trait WithShuffleClientSuite extends CelebornFunSuite { prepareService() shuffleId = 1 var location = - shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId) - Assert.assertEquals(location.getId, PackedPartitionId.packedPartitionId(mapId, attemptId)) + shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId, 1) + Assert.assertEquals(location.getId, 1) // retry register - location = shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId) - Assert.assertEquals(location.getId, PackedPartitionId.packedPartitionId(mapId, attemptId)) + location = + shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId, 1) + Assert.assertEquals(location.getId, 1) // check all allocated slots var partitionLocationInfos = lifecycleManager.workerSnapshots(shuffleId).values().asScala @@ -73,15 +73,19 @@ trait WithShuffleClientSuite extends CelebornFunSuite { // another mapId location = - shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 1, attemptId) - Assert.assertEquals(location.getId, PackedPartitionId.packedPartitionId(mapId + 1, attemptId)) + shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 1, attemptId, 2) + Assert.assertEquals(location.getId, 2) // another mapId with another attemptId location = - shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 1, attemptId + 1) - Assert.assertEquals( - location.getId, - PackedPartitionId.packedPartitionId(mapId + 1, attemptId + 1)) + shuffleClient.registerMapPartitionTask( + APP, + shuffleId, + numMappers, + mapId + 1, + attemptId + 1, + numMappers + 1) + Assert.assertEquals(location.getId, numMappers + 1) // check all allocated all slots partitionLocationInfos = lifecycleManager.workerSnapshots(shuffleId).values().asScala @@ -101,7 +105,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { // check batch release lifecycleManager.releasePartition( shuffleId, - PackedPartitionId.packedPartitionId(mapId, attemptId + 1)) + 4) timeOutOrMeetCondition(new Callable[java.lang.Boolean] { override def call(): lang.Boolean = { @@ -122,7 +126,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { // check single release lifecycleManager.releasePartition( shuffleId, - PackedPartitionId.packedPartitionId(mapId, attemptId + 1)) + 4) Assert.assertEquals( partitionLocationInfos.map(r => r.getMasterPartitions().size()).sum, @@ -153,15 +157,15 @@ trait WithShuffleClientSuite extends CelebornFunSuite { } private def registerAndFinishPartition(): Unit = { - shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId) - shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 1, attemptId) - shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 2, attemptId) + shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId, 1) + shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 1, attemptId, 2) + shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId + 2, attemptId, 3) // task number incr to numMappers + 1 - shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId + 1) - shuffleClient.mapPartitionMapperEnd(APP, shuffleId, mapId, attemptId, numMappers, mapId) + shuffleClient.registerMapPartitionTask(APP, shuffleId, numMappers, mapId, attemptId + 1, 9) + shuffleClient.mapPartitionMapperEnd(APP, shuffleId, mapId, attemptId, numMappers, 1) // retry - shuffleClient.mapPartitionMapperEnd(APP, shuffleId, mapId, attemptId, numMappers, mapId) + shuffleClient.mapPartitionMapperEnd(APP, shuffleId, mapId, attemptId, numMappers, 1) // another attempt shuffleClient.mapPartitionMapperEnd( APP, @@ -169,8 +173,7 @@ trait WithShuffleClientSuite extends CelebornFunSuite { mapId, attemptId + 1, numMappers, - PackedPartitionId - .packedPartitionId(mapId, attemptId + 1)) + 9) // another mapper shuffleClient.mapPartitionMapperEnd(APP, shuffleId, mapId + 1, attemptId, numMappers, mapId + 1) } diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java b/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java index a590c16e5..932589632 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java +++ b/common/src/main/java/org/apache/celeborn/common/protocol/PartitionLocation.java @@ -22,7 +22,6 @@ import java.io.Serializable; import org.roaringbitmap.RoaringBitmap; import org.apache.celeborn.common.meta.WorkerInfo; -import org.apache.celeborn.common.util.PackedPartitionId; public class PartitionLocation implements Serializable { public enum Mode { @@ -286,13 +285,9 @@ public class PartitionLocation implements Serializable { peerAddr = peer.hostAndPorts(); } return "PartitionLocation[" - + "\n id(rawId-attemptId)-epoch:" + + "\n id-epoch:" + id - + "(" - + getRawId() + "-" - + getAttemptId() - + ")-" + epoch + "\n host-rpcPort-pushPort-fetchPort-replicatePort:" + host @@ -326,12 +321,4 @@ public class PartitionLocation implements Serializable { public void setMapIdBitMap(RoaringBitmap mapIdBitMap) { this.mapIdBitMap = mapIdBitMap; } - - public int getRawId() { - return PackedPartitionId.getRawPartitionId(id); - } - - public int getAttemptId() { - return PackedPartitionId.getAttemptId(id); - } } diff --git a/common/src/main/java/org/apache/celeborn/common/util/PackedPartitionId.java b/common/src/main/java/org/apache/celeborn/common/util/PackedPartitionId.java deleted file mode 100644 index 00163f5a5..000000000 --- a/common/src/main/java/org/apache/celeborn/common/util/PackedPartitionId.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.celeborn.common.util; - -import com.google.common.base.Preconditions; - -/** - * Pack for encode/decode id of partition Location for id of partitionLocation attemptId - * raw_partitionId
- * (upper 8 bits = attemptId) (lower 24 bits = raw id)
- * (0000 0000) (0000 0000 0000 0000 0000 0000)
- * - * @see org.apache.celeborn.common.protocol.PartitionLocation#id - */ -public class PackedPartitionId { - - /** The maximum partition identifier that can be encoded. Note that partition ids start from 0. */ - static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215 - - /** The maximum partition attempt id that can be encoded. Note that attempt ids start from 0. */ - static final int MAXIMUM_ATTEMPT_ID = (1 << 8) - 1; // 255 - - static final int MASK_INT_LOWER_24_BITS = (int) (1L << 24) - 1; - - public static int packedPartitionId(int partitionRawId, int attemptId) { - Preconditions.checkArgument( - partitionRawId <= MAXIMUM_PARTITION_ID, - "packedPartitionId called with invalid partitionRawId: " + partitionRawId); - Preconditions.checkArgument( - attemptId <= MAXIMUM_ATTEMPT_ID, - "packedPartitionId called with invalid attemptId: " + attemptId); - - return (attemptId << 24) | partitionRawId; - } - - public static int getRawPartitionId(int packedPartitionId) { - return packedPartitionId & MASK_INT_LOWER_24_BITS; - } - - public static int getAttemptId(int packedPartitionId) { - return packedPartitionId >>> 24; - } -} diff --git a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java index 010f5afcf..6eb952d2f 100644 --- a/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java +++ b/common/src/test/java/org/apache/celeborn/common/protocol/PartitionLocationSuiteJ.java @@ -20,8 +20,6 @@ package org.apache.celeborn.common.protocol; import org.junit.Test; import org.roaringbitmap.RoaringBitmap; -import org.apache.celeborn.common.util.PackedPartitionId; - public class PartitionLocationSuiteJ { private final int partitionId = 0; @@ -186,12 +184,10 @@ public class PartitionLocationSuiteJ { bitmap.add(2); bitmap.add(3); - int attemptId = 10; - int rawPartitionId = 1000; - int newPartitionId = PackedPartitionId.packedPartitionId(rawPartitionId, attemptId); + int partitionId = 1000; PartitionLocation location3 = new PartitionLocation( - newPartitionId, + partitionId, epoch, host, rpcPort, @@ -205,7 +201,7 @@ public class PartitionLocationSuiteJ { String exp1 = "PartitionLocation[\n" - + " id(rawId-attemptId)-epoch:0(0-0)-0\n" + + " id-epoch:0-0\n" + " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n" + " mode:MASTER\n" + " peer:(empty)\n" @@ -213,7 +209,7 @@ public class PartitionLocationSuiteJ { + " mapIdBitMap:{}]"; String exp2 = "PartitionLocation[\n" - + " id(rawId-attemptId)-epoch:0(0-0)-0\n" + + " id-epoch:0-0\n" + " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n" + " mode:MASTER\n" + " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n" @@ -221,16 +217,12 @@ public class PartitionLocationSuiteJ { + " mapIdBitMap:{}]"; String exp3 = "PartitionLocation[\n" - + " id(rawId-attemptId)-epoch:167773160(1000-10)-0\n" + + " id-epoch:1000-0\n" + " host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4\n" + " mode:MASTER\n" + " peer:(host-rpcPort-pushPort-fetchPort-replicatePort:localhost-3-1-2-4)\n" + " storage hint:StorageInfo{type=MEMORY, mountPoint='/mnt/disk/0', finalResult=false, filePath=null}\n" + " mapIdBitMap:{1,2,3}]"; - System.out.println(location1); - System.out.println(location2); - System.out.println(location3); - assert exp1.equals(location1.toString()); assert exp2.equals(location2.toString()); assert exp3.equals(location3.toString()); diff --git a/common/src/test/java/org/apache/celeborn/common/util/PackedPartitionIdSuiteJ.java b/common/src/test/java/org/apache/celeborn/common/util/PackedPartitionIdSuiteJ.java deleted file mode 100644 index 8ed0bc30d..000000000 --- a/common/src/test/java/org/apache/celeborn/common/util/PackedPartitionIdSuiteJ.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.celeborn.common.util; - -import org.junit.Assert; -import org.junit.Test; - -public class PackedPartitionIdSuiteJ { - - @Test - public void testNormalPackedPartitionId() { - assertTest(0, 0); - assertTest(555, 1); - assertTest(888, 1); - assertTest(10001, 100); - - // testUseMaxPartitionId or MaxAttemptId - assertTest(PackedPartitionId.MAXIMUM_PARTITION_ID, 11); - assertTest(100, PackedPartitionId.MAXIMUM_ATTEMPT_ID); - assertTest(PackedPartitionId.MAXIMUM_PARTITION_ID, PackedPartitionId.MAXIMUM_ATTEMPT_ID); - } - - @Test(expected = IllegalArgumentException.class) - public void testAttemptIdGreaterThanMaximumAttemptId() { - PackedPartitionId.packedPartitionId(0, PackedPartitionId.MAXIMUM_ATTEMPT_ID + 1); - } - - @Test(expected = IllegalArgumentException.class) - public void testPartitionIdGreaterThanMaximumPartitionId() { - PackedPartitionId.packedPartitionId(PackedPartitionId.MAXIMUM_PARTITION_ID + 1, 1); - } - - private void assertTest(int partitionRawId, int attemptId) { - int packedPartitionId = PackedPartitionId.packedPartitionId(partitionRawId, attemptId); - Assert.assertTrue(partitionRawId == PackedPartitionId.getRawPartitionId(packedPartitionId)); - Assert.assertTrue(attemptId == PackedPartitionId.getAttemptId(packedPartitionId)); - } -} diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala index 8b1fde495..d02293380 100644 --- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/ShuffleClientSuite.scala @@ -53,7 +53,7 @@ class ShuffleClientSuite extends WithShuffleClientSuite with MiniClusterFeature } assertThrows[IOException] { - () -> shuffleClient.registerMapPartitionTask(APP, 1, 1, 0, 0) + () -> shuffleClient.registerMapPartitionTask(APP, 1, 1, 0, 0, 1) } lifecycleManager.stop() diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala index d57017c38..1e84a899c 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala @@ -37,7 +37,6 @@ import org.apache.celeborn.common.network.server.BaseMessageHandler import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionSplitMode, PartitionType} import org.apache.celeborn.common.protocol.message.StatusCode import org.apache.celeborn.common.unsafe.Platform -import org.apache.celeborn.common.util.PackedPartitionId import org.apache.celeborn.service.deploy.worker.congestcontrol.CongestionController import org.apache.celeborn.service.deploy.worker.storage.{FileWriter, HdfsFlusher, LocalFlusher, MapPartitionFileWriter, StorageManager} @@ -965,22 +964,13 @@ class PushDataHandler extends BaseMessageHandler with Logging { callback: RpcResponseCallback, wrappedCallback: RpcResponseCallback): Boolean = { if (location == null) { - val (mapId, attemptId) = getMapAttempt(partitionUniqueId) - if (shuffleMapperAttempts.containsKey(shuffleKey) && - -1 != shuffleMapperAttempts.get(shuffleKey).get(mapId)) { - // partition data has already been committed - logInfo(s"Receive push data from speculative task(shuffle $shuffleKey, map $mapId, " + - s" attempt $attemptId), but this mapper has already been ended.") - wrappedCallback.onSuccess(ByteBuffer.wrap(Array[Byte](StatusCode.STAGE_ENDED.getValue))) - } else { - val msg = s"Partition location wasn't found for task(shuffle $shuffleKey, map $mapId, " + - s"attempt $attemptId, uniqueId $partitionUniqueId)." - logWarning(s"[handle$messageType] $msg") - messageType match { - case Type.PUSH_MERGED_DATA => callback.onFailure(new CelebornIOException(msg)) - case _ => callback.onFailure( - new CelebornIOException(StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND)) - } + val msg = + s"Partition location wasn't found for task(shuffle $shuffleKey, uniqueId $partitionUniqueId)." + logWarning(s"[handle$messageType] $msg") + messageType match { + case Type.PUSH_MERGED_DATA => callback.onFailure(new CelebornIOException(msg)) + case _ => callback.onFailure( + new CelebornIOException(StatusCode.PUSH_DATA_FAIL_PARTITION_NOT_FOUND)) } return true } @@ -1058,12 +1048,6 @@ class PushDataHandler extends BaseMessageHandler with Logging { false } - private def getMapAttempt( - partitionUniqueId: String): (Int, Int) = { - val id = partitionUniqueId.split("-")(0).toInt - (PackedPartitionId.getRawPartitionId(id), PackedPartitionId.getAttemptId(id)) - } - private def getClient(host: String, port: Int, partitionId: Int): TransportClient = { if (conf.workerReplicateRandomConnectionEnabled) { pushClientFactory.createClient(host, port)