[CELEBORN-1506][BUG] Revert "[CELEBORN-1036][FOLLOWUP] totalInflightReqs should decrement when batchIdSet contains the batchId to avoid duplicate caller of removeBatch"

### What changes were proposed in this pull request?
One of our users reported a data loss issue in https://github.com/apache/celeborn/pull/2612 . I tried to reproduce
the bug with the following setup:
1. Partition data is far larger than `spark.celeborn.client.shuffle.partitionSplit.threshold`, so splits happen very often
2. `spark.celeborn.client.shuffle.partitionSplit.threshold` is larger than `celeborn.worker.shuffle.partitionSplit.max`, so when a split happens it is a `HARD_SPLIT`
3. `celeborn.client.shuffle.batchHandleChangePartition.enabled` is true, so when a hard split happens, `LifecycleManager` will commit the splits before the stage finishes

Configs in spark side:
```
spark.celeborn.client.push.maxReqsInFlight.perWorker | 256
spark.celeborn.client.push.maxReqsInFlight.total | 2048
spark.celeborn.client.shuffle.batchHandleCommitPartition.enabled | true
spark.celeborn.client.shuffle.compression.codec | zstd
spark.celeborn.client.shuffle.partitionSplit.threshold | 48m
spark.celeborn.client.spark.fetch.throwsFetchFailure | true
spark.celeborn.client.spark.push.sort.memory.adaptiveThreshold | true
spark.celeborn.client.spark.push.sort.memory.threshold | 512m
spark.celeborn.client.spark.shuffle.writer | sort
spark.celeborn.master.endpoints | master-1-1:9097

```
Configs in celeborn side:
```
celeborn.metrics.enabled=false
celeborn.replicate.io.numConnectionsPerPeer=24
celeborn.application.heartbeat.timeout=120s
celeborn.worker.storage.dirs=/mnt/disk1,/mnt/disk2
celeborn.network.timeout=2000s
celeborn.ha.enabled=false
celeborn.worker.closeIdleConnections=true
celeborn.worker.monitor.disk.enabled=false
celeborn.worker.flusher.threads=16

celeborn.worker.graceful.shutdown.enabled=true
celeborn.worker.rpc.port=9100
celeborn.worker.push.port=9101
celeborn.worker.fetch.port=9102
celeborn.worker.replicate.port=9103

celeborn.worker.shuffle.partitionSplit.max=10m  # deliberately set small to trigger HARD_SPLIT
```

My query on the 10 TB TPC-DS dataset:
```
select
max(ss_sold_time_sk      ),
max(ss_item_sk           ),
max(ss_customer_sk       ),
max(ss_cdemo_sk          ),
max(ss_hdemo_sk          ),
max(ss_addr_sk           ),
max(ss_store_sk          ),
max(ss_promo_sk          ),
max(ss_ticket_number     ),
max(ss_quantity          ),
max(ss_wholesale_cost    ),
max(ss_list_price        ),
max(ss_sales_price       ),
max(ss_ext_discount_amt  ),
max(ss_ext_sales_price   ),
max(ss_ext_wholesale_cost),
max(ss_ext_list_price    ),
max(ss_ext_tax           ),
max(ss_coupon_amt        ),
max(ss_net_paid          ),
max(ss_net_paid_inc_tax  ),
max(ss_net_profit        ),
max(ss_sold_date_sk      )
from (
select * from store_sales where ss_sold_date_sk is not null distribute by ss_sold_date_sk
) a;
```

After digging into it, I found the bug was introduced by https://github.com/apache/celeborn/pull/2134 . #2134 added
checks in `InFlightRequestTracker#addBatch` and `InFlightRequestTracker#removeBatch` so that `totalInflightReqs` is
only incremented/decremented when `batchIdSet` is actually modified. This conflicts with
`ShuffleClientImpl#PushDataRpcResponseCallback#updateLatestPartition`, which calls `addBatch` first and then
`removeBatch` with the same batchId. As a result, the call to `addBatch` fails to increment `totalInflightReqs`
(the batchId is already in the set), but the call to `removeBatch` still decrements it. The retried push is therefore
not counted, and `limitZeroInFlight` in `mapperEnd` can return even though the retried push later fails.
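The counter drift described above can be reproduced with a minimal sketch. Only `totalInflightReqs` and the batch-id set mirror `InFlightRequestTracker`; the class name, the single shared set (the real tracker keeps one set per worker address), and the single-threaded flow are hypothetical simplifications for illustration:

```java
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.LongAdder;

// Sketch of the accounting introduced by #2134 (hypothetical simplification).
public class InFlightSketch {
    private final Set<Integer> batchIdSet = new HashSet<>();
    private final LongAdder totalInflightReqs = new LongAdder();

    // #2134: increment only when the batchId was newly added to the set.
    void addBatch(int batchId) {
        if (batchIdSet.add(batchId)) {
            totalInflightReqs.increment();
        }
    }

    // #2134: decrement only when the batchId was actually present.
    void removeBatch(int batchId) {
        if (batchIdSet.remove(batchId)) {
            totalInflightReqs.decrement();
        }
    }

    long total() {
        return totalInflightReqs.sum();
    }

    public static void main(String[] args) {
        InFlightSketch tracker = new InFlightSketch();
        tracker.addBatch(42);    // original push: counter -> 1
        // HARD_SPLIT triggers a retry; the retry path calls addBatch and then
        // removeBatch with the same batchId:
        tracker.addBatch(42);    // already in the set: increment is skipped
        tracker.removeBatch(42); // entry removed: counter -> 0
        // The retried push for batch 42 is still in flight, yet the counter
        // reads 0, so limitZeroInFlight in mapperEnd no longer waits for it.
        System.out.println(tracker.total()); // prints 0
    }
}
```

The net effect is that the retry sequence cancels the original push's increment, leaving an in-flight request invisible to the counter.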

This PR fixes the bug by reverting #2134.

### Why are the changes needed?
ditto

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Manual test.

Closes #2621 from waitinfuture/1506.

Authored-by: zky.zhoukeyong <zky.zhoukeyong@alibaba-inc.com>
Signed-off-by: zky.zhoukeyong <zky.zhoukeyong@alibaba-inc.com>
Commit 8d0b4cf4cd (parent 09d3a3b05f), 2024-07-13 13:09:46 +08:00

```diff
@@ -58,28 +58,22 @@ public class InFlightRequestTracker {
   }

   public void addBatch(int batchId, String hostAndPushPort) {
-    Set<Integer> batchIdSet =
+    Set<Integer> batchIdSetPerPair =
         inflightBatchesPerAddress.computeIfAbsent(
             hostAndPushPort, id -> ConcurrentHashMap.newKeySet());
-    if (batchIdSet.add(batchId)) {
-      totalInflightReqs.increment();
-    } else {
-      logger.debug("{} has already been inflight.", batchId);
-    }
+    batchIdSetPerPair.add(batchId);
+    totalInflightReqs.increment();
   }

   public void removeBatch(int batchId, String hostAndPushPort) {
     Set<Integer> batchIdSet = inflightBatchesPerAddress.get(hostAndPushPort);
     // TODO: Need to debug why batchIdSet will be null.
     if (batchIdSet != null) {
-      if (batchIdSet.remove(batchId)) {
-        totalInflightReqs.decrement();
-      } else {
-        logger.debug("BatchIdSet has removed {}.", batchId);
-      }
+      batchIdSet.remove(batchId);
     } else {
       logger.warn("BatchIdSet of {} is null.", hostAndPushPort);
     }
+    totalInflightReqs.decrement();
   }

   public void onSuccess(String hostAndPushPort) {
@@ -103,7 +97,7 @@ public class InFlightRequestTracker {
     pushStrategy.limitPushSpeed(pushState, hostAndPushPort);
     int currentMaxReqsInFlight = pushStrategy.getCurrentMaxReqsInFlight(hostAndPushPort);
-    Set<Integer> batchIdSet = getBatchIdSetByAddressPair(hostAndPushPort);
+    Set batchIdSet = getBatchIdSetByAddressPair(hostAndPushPort);
     long times = waitInflightTimeoutMs / delta;
     try {
       while (times > 0) {
```
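For contrast, a sketch of the accounting after the revert, with the same hypothetical simplifications as before (single thread, one shared batch-id set rather than one per worker address): the counter is updated unconditionally, so the retried push stays counted.

```java
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.LongAdder;

// Sketch of the pre-#2134 accounting restored by this revert.
public class RevertedSketch {
    private final Set<Integer> batchIdSet = new HashSet<>();
    private final LongAdder totalInflightReqs = new LongAdder();

    void addBatch(int batchId) {
        batchIdSet.add(batchId);
        totalInflightReqs.increment(); // unconditional, as before #2134
    }

    void removeBatch(int batchId) {
        batchIdSet.remove(batchId);
        totalInflightReqs.decrement(); // unconditional, as before #2134
    }

    long total() {
        return totalInflightReqs.sum();
    }

    public static void main(String[] args) {
        RevertedSketch tracker = new RevertedSketch();
        tracker.addBatch(42);    // original push: counter -> 1
        tracker.addBatch(42);    // retry re-adds the batch: counter -> 2
        tracker.removeBatch(42); // original entry removed: counter -> 1
        // The retried push is still counted, so limitZeroInFlight keeps
        // waiting until its own removeBatch brings the counter to 0.
        System.out.println(tracker.total()); // prints 1
    }
}
```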