diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index edbfa529f..af3207e83 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -323,10 +323,10 @@ public class SortBasedShuffleWriter extends ShuffleWriter { } long pushStartTime = System.nanoTime(); if (pipelined) { - for (int i = 0; i < pushers.length; i++) { - pushers[i].waitPushFinish(); - pushers[i].pushData(); - pushers[i].close(); + for (SortBasedPusher pusher : pushers) { + pusher.waitPushFinish(); + pusher.pushData(); + pusher.close(); } } else { currentPusher.pushData(); @@ -344,13 +344,11 @@ public class SortBasedShuffleWriter extends ShuffleWriter { } private void updateMapStatus() { - long recordsWritten = 0; - for (int i = 0; i < partitioner.numPartitions(); i++) { + for (int i = 0; i < tmpRecords.length; i++) { mapStatusRecords[i] += tmpRecords[i]; - recordsWritten += tmpRecords[i]; + writeMetrics.incRecordsWritten(tmpRecords[i]); tmpRecords[i] = 0; } - writeMetrics.incRecordsWritten(recordsWritten); } @Override diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index 070aa828f..ec92b0e46 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -83,7 +83,6 @@ public class SortBasedShuffleWriter extends ShuffleWriter { private final SerializationStream serOutputStream; private final LongAdder[] mapStatusLengths; - private final long[] mapStatusRecords; private final long[] tmpRecords; /** @@ -124,7 +123,6 @@ public class SortBasedShuffleWriter extends ShuffleWriter { serOutputStream = serializer.serializeStream(serBuffer); this.mapStatusLengths = new LongAdder[numPartitions]; - this.mapStatusRecords = new long[numPartitions]; for (int i = 0; i < numPartitions; i++) { this.mapStatusLengths[i] = new LongAdder(); } @@ -331,10 +329,10 @@ public class SortBasedShuffleWriter extends ShuffleWriter { } long pushStartTime = System.nanoTime(); if (pipelined) { - for (int i = 0; i < pushers.length; i++) { - pushers[i].waitPushFinish(); - pushers[i].pushData(); - pushers[i].close(); + for (SortBasedPusher pusher : pushers) { + pusher.waitPushFinish(); + pusher.pushData(); + pusher.close(); } } else { currentPusher.pushData(); @@ -354,13 +352,10 @@ public class SortBasedShuffleWriter extends ShuffleWriter { } private void updateMapStatus() { - long recordsWritten = 0; - for (int i = 0; i < partitioner.numPartitions(); i++) { - mapStatusRecords[i] += tmpRecords[i]; - recordsWritten += tmpRecords[i]; + for (int i = 0; i < tmpRecords.length; i++) { + writeMetrics.incRecordsWritten(tmpRecords[i]); tmpRecords[i] = 0; } - writeMetrics.incRecordsWritten(recordsWritten); } @Override