diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java index 3517cf7a4..2f7ccd2d2 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java @@ -49,7 +49,7 @@ public class SortBasedPusher extends MemoryConsumer { private final ShuffleClient rssShuffleClient; private final DataPusher dataPusher; - private final int pushBufferSize; + private final int pushBufferMaxSize; private final long PushThreshold; final int uaoSize = UnsafeAlignedOffset.getUaoSize(); @@ -110,7 +110,7 @@ public class SortBasedPusher extends MemoryConsumer { afterPush, mapStatusLengths); - pushBufferSize = RssConf.pushBufferMaxSize(conf); + pushBufferMaxSize = conf.pushBufferMaxSize(); PushThreshold = RssConf.sortPushThreshold(conf); inMemSorter = new ShuffleInMemorySorter(this, 4 * 1024 * 1024); @@ -124,7 +124,7 @@ public class SortBasedPusher extends MemoryConsumer { final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = inMemSorter.getSortedIterator(); - byte[] dataBuf = new byte[pushBufferSize]; + byte[] dataBuf = new byte[pushBufferMaxSize]; int offSet = 0; int currentPartition = -1; while (sortedRecords.hasNext()) { diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index bd858682b..c34f67663 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -59,8 +59,8 @@ public class HashBasedShuffleWriter extends ShuffleWriter { private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; - private final int SEND_BUFFER_INIT_SIZE; - private final int SEND_BUFFER_SIZE; + private final int PUSH_BUFFER_INIT_SIZE; + private final int PUSH_BUFFER_MAX_SIZE; private final ShuffleDependency dep; private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; @@ -130,8 +130,8 @@ public class HashBasedShuffleWriter extends ShuffleWriter { mapStatusRecords = new long[numPartitions]; tmpRecords = new long[numPartitions]; - SEND_BUFFER_INIT_SIZE = RssConf.pushBufferInitialSize(conf); - SEND_BUFFER_SIZE = RssConf.pushBufferMaxSize(conf); + PUSH_BUFFER_INIT_SIZE = conf.pushBufferInitialSize(); + PUSH_BUFFER_MAX_SIZE = conf.pushBufferMaxSize(); this.sendBufferPool = sendBufferPool; sendBuffers = sendBufferPool.acquireBuffer(numPartitions); @@ -194,7 +194,7 @@ public class HashBasedShuffleWriter extends ShuffleWriter { dataSize.add(serializedRecordSize); } - if (serializedRecordSize > SEND_BUFFER_SIZE) { + if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) { byte[] giantBuffer = new byte[serializedRecordSize]; Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize)); Platform.copyMemory( @@ -235,7 +235,7 @@ public class HashBasedShuffleWriter extends ShuffleWriter { final int serializedRecordSize = serBuffer.size(); assert (serializedRecordSize > 0); - if (serializedRecordSize > SEND_BUFFER_SIZE) { + if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) { pushGiantRecord(partitionId, serBuffer.getBuf(), serializedRecordSize); } else { int offset = getOrUpdateOffset(partitionId, serializedRecordSize); @@ -250,9 +250,9 @@ public class HashBasedShuffleWriter extends ShuffleWriter { private byte[] getOrCreateBuffer(int partitionId) { byte[] buffer = sendBuffers[partitionId]; if (buffer == null) { - buffer = new byte[SEND_BUFFER_INIT_SIZE]; + buffer = new byte[PUSH_BUFFER_INIT_SIZE]; sendBuffers[partitionId] = buffer; - peakMemoryUsedBytes += SEND_BUFFER_INIT_SIZE; + peakMemoryUsedBytes += PUSH_BUFFER_INIT_SIZE; } return buffer; } @@ -280,9 +280,10 @@ public class HashBasedShuffleWriter extends ShuffleWriter { private int getOrUpdateOffset(int partitionId, int serializedRecordSize) throws IOException { int offset = sendOffsets[partitionId]; byte[] buffer = getOrCreateBuffer(partitionId); - while ((buffer.length - offset) < serializedRecordSize && buffer.length < SEND_BUFFER_SIZE) { + while ((buffer.length - offset) < serializedRecordSize + && buffer.length < PUSH_BUFFER_MAX_SIZE) { - byte[] newBuffer = new byte[Math.min(buffer.length * 2, SEND_BUFFER_SIZE)]; + byte[] newBuffer = new byte[Math.min(buffer.length * 2, PUSH_BUFFER_MAX_SIZE)]; peakMemoryUsedBytes += newBuffer.length - buffer.length; System.arraycopy(buffer, 0, newBuffer, 0, offset); sendBuffers[partitionId] = newBuffer; diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java index 32661633c..a8612331e 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java @@ -160,15 +160,15 @@ public class RssShuffleManager implements ShuffleManager { ShuffleClient client = ShuffleClient.get( h.rssMetaServiceHost(), h.rssMetaServicePort(), rssConf, h.userIdentifier()); - if ("sort".equals(RssConf.shuffleWriterMode(rssConf))) { + if ("sort".equals(rssConf.shuffleWriterMode())) { return new SortBasedShuffleWriter<>( h.dependency(), h.newAppId(), h.numMaps(), context, rssConf, client); - } else if ("hash".equals(RssConf.shuffleWriterMode(rssConf))) { + } else if ("hash".equals(rssConf.shuffleWriterMode())) { return new HashBasedShuffleWriter<>( h, mapId, context, rssConf, client, SendBufferPool.get(cores)); } else { throw new UnsupportedOperationException( - "Unrecognized shuffle write mode!" + RssConf.shuffleWriterMode(rssConf)); + "Unrecognized shuffle write mode!" + rssConf.shuffleWriterMode()); } } else { return sortShuffleManager().getWriter(handle, mapId, context); diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index 76478138b..89aa14cd9 100644 --- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -67,7 +67,7 @@ public class SortBasedShuffleWriter extends ShuffleWriter { private final int numMappers; private final int numPartitions; - private final long pushBufferSize; + private final long pushBufferMaxSize; private SortBasedPusher sortBasedPusher; private long peakMemoryUsedBytes = 0; @@ -118,7 +118,7 @@ public class SortBasedShuffleWriter extends ShuffleWriter { this.mapStatusRecords = new long[numPartitions]; tmpRecords = new long[numPartitions]; - pushBufferSize = RssConf.pushBufferMaxSize(conf); + pushBufferMaxSize = conf.pushBufferMaxSize(); sortBasedPusher = new SortBasedPusher( @@ -175,7 +175,7 @@ public class SortBasedShuffleWriter extends ShuffleWriter { dataSize.add(serializedRecordSize); } - if (serializedRecordSize > pushBufferSize) { + if (serializedRecordSize > pushBufferMaxSize) { byte[] giantBuffer = new byte[serializedRecordSize]; Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize)); Platform.copyMemory( @@ -210,7 +210,7 @@ public class SortBasedShuffleWriter extends ShuffleWriter { final int serializedRecordSize = serBuffer.size(); assert (serializedRecordSize > 0); - if (serializedRecordSize > pushBufferSize) { + if (serializedRecordSize > pushBufferMaxSize) { pushGiantRecord(partitionId, serBuffer.getBuf(), serializedRecordSize); } else { long insertStartTime = System.nanoTime(); diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java index 8d059e3be..ff5881b90 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java @@ -61,8 +61,8 @@ public class HashBasedShuffleWriter extends ShuffleWriter { private static final ClassTag OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object(); private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; - private final int SEND_BUFFER_INIT_SIZE; - private final int SEND_BUFFER_SIZE; + private final int PUSH_BUFFER_INIT_SIZE; + private final int PUSH_BUFFER_MAX_SIZE; private final ShuffleDependency dep; private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; @@ -135,8 +135,8 @@ public class HashBasedShuffleWriter extends ShuffleWriter { } tmpRecords = new long[numPartitions]; - SEND_BUFFER_INIT_SIZE = RssConf.pushBufferInitialSize(conf); - SEND_BUFFER_SIZE = RssConf.pushBufferMaxSize(conf); + PUSH_BUFFER_INIT_SIZE = conf.pushBufferInitialSize(); + PUSH_BUFFER_MAX_SIZE = conf.pushBufferMaxSize(); this.sendBufferPool = sendBufferPool; sendBuffers = sendBufferPool.acquireBuffer(numPartitions); @@ -214,7 +214,7 @@ public class HashBasedShuffleWriter extends ShuffleWriter { rssColumnBuilders[partitionId] = columnBuilders; } rssColumnBuilders[partitionId].writeRow(row); - if (rssColumnBuilders[partitionId].getTotalSize() > SEND_BUFFER_SIZE + if (rssColumnBuilders[partitionId].getTotalSize() > PUSH_BUFFER_MAX_SIZE || rssColumnBuilders[partitionId].rowCnt() == RssConf.columnarShuffleBatchSize(rssConf)) { byte[] arr = rssColumnBuilders[partitionId].buildColumnBytes(); pushGiantRecord(partitionId, arr, arr.length); @@ -244,7 +244,7 @@ public class HashBasedShuffleWriter extends ShuffleWriter { dataSize.add(rowSize); } - if (serializedRecordSize > SEND_BUFFER_SIZE) { + if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) { byte[] giantBuffer = new byte[serializedRecordSize]; Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize)); Platform.copyMemory( @@ -285,7 +285,7 @@ public class HashBasedShuffleWriter extends ShuffleWriter { final int serializedRecordSize = serBuffer.size(); assert (serializedRecordSize > 0); - if (serializedRecordSize > SEND_BUFFER_SIZE) { + if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) { pushGiantRecord(partitionId, serBuffer.getBuf(), serializedRecordSize); } else { int offset = getOrUpdateOffset(partitionId, serializedRecordSize); @@ -300,9 +300,9 @@ public class HashBasedShuffleWriter extends ShuffleWriter { private byte[] getOrCreateBuffer(int partitionId) { byte[] buffer = sendBuffers[partitionId]; if (buffer == null) { - buffer = new byte[SEND_BUFFER_INIT_SIZE]; + buffer = new byte[PUSH_BUFFER_INIT_SIZE]; sendBuffers[partitionId] = buffer; - peakMemoryUsedBytes += SEND_BUFFER_INIT_SIZE; + peakMemoryUsedBytes += PUSH_BUFFER_INIT_SIZE; } return buffer; } @@ -330,9 +330,10 @@ public class HashBasedShuffleWriter extends ShuffleWriter { private int getOrUpdateOffset(int partitionId, int serializedRecordSize) throws IOException { int offset = sendOffsets[partitionId]; byte[] buffer = getOrCreateBuffer(partitionId); - while ((buffer.length - offset) < serializedRecordSize && buffer.length < SEND_BUFFER_SIZE) { + while ((buffer.length - offset) < serializedRecordSize + && buffer.length < PUSH_BUFFER_MAX_SIZE) { - byte[] newBuffer = new byte[Math.min(buffer.length * 2, SEND_BUFFER_SIZE)]; + byte[] newBuffer = new byte[Math.min(buffer.length * 2, PUSH_BUFFER_MAX_SIZE)]; peakMemoryUsedBytes += newBuffer.length - buffer.length; System.arraycopy(buffer, 0, newBuffer, 0, offset); sendBuffers[partitionId] = newBuffer; diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java index 781427854..6c9433d3a 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/RssShuffleManager.java @@ -158,15 +158,15 @@ public class RssShuffleManager implements ShuffleManager { ShuffleClient client = ShuffleClient.get( h.rssMetaServiceHost(), h.rssMetaServicePort(), rssConf, h.userIdentifier()); - if ("sort".equals(RssConf.shuffleWriterMode(rssConf))) { + if ("sort".equals(rssConf.shuffleWriterMode())) { return new SortBasedShuffleWriter<>( h.dependency(), h.newAppId(), h.numMappers(), context, rssConf, client, metrics); - } else if ("hash".equals(RssConf.shuffleWriterMode(rssConf))) { + } else if ("hash".equals(rssConf.shuffleWriterMode())) { return new HashBasedShuffleWriter<>( h, context, rssConf, client, metrics, SendBufferPool.get(cores)); } else { throw new UnsupportedOperationException( - "Unrecognized shuffle write mode!" + RssConf.shuffleWriterMode(rssConf)); + "Unrecognized shuffle write mode!" + rssConf.shuffleWriterMode()); } } else { return sortShuffleManager().getWriter(handle, mapId, context, metrics); diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java index 7e544eacd..098e4a394 100644 --- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java +++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java @@ -69,7 +69,7 @@ public class SortBasedShuffleWriter extends ShuffleWriter { private final int numMappers; private final int numPartitions; - private final long pushBufferSize; + private final long pushBufferMaxSize; private SortBasedPusher sortBasedPusher; @Nullable private long peakMemoryUsedBytes = 0; @@ -121,7 +121,7 @@ public class SortBasedShuffleWriter extends ShuffleWriter { } tmpRecords = new long[numPartitions]; - pushBufferSize = RssConf.pushBufferMaxSize(conf); + pushBufferMaxSize = conf.pushBufferMaxSize(); sortBasedPusher = new SortBasedPusher( @@ -178,7 +178,7 @@ public class SortBasedShuffleWriter extends ShuffleWriter { dataSize.add(serializedRecordSize); } - if (serializedRecordSize > pushBufferSize) { + if (serializedRecordSize > pushBufferMaxSize) { byte[] giantBuffer = new byte[serializedRecordSize]; Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize)); Platform.copyMemory( @@ -213,7 +213,7 @@ public class SortBasedShuffleWriter extends ShuffleWriter { final int serializedRecordSize = serBuffer.size(); assert (serializedRecordSize > 0); - if (serializedRecordSize > pushBufferSize) { + if (serializedRecordSize > pushBufferMaxSize) { pushGiantRecord(partitionId, serBuffer.getBuf(), serializedRecordSize); } else { long insertStartTime = System.nanoTime(); diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index 54c7cb8da..b66c9eb7f 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -75,7 +75,7 @@ public class ShuffleClientImpl extends ShuffleClient { private final int registerShuffleMaxRetries; private final long registerShuffleRetryWait; private final int maxInFlight; - private final int pushBufferSize; + private final int pushBufferMaxSize; private final RpcEnv rpcEnv; @@ -124,10 +124,10 @@ public class ShuffleClientImpl extends ShuffleClient { super(); this.conf = conf; this.userIdentifier = userIdentifier; - registerShuffleMaxRetries = RssConf.registerShuffleMaxRetry(conf); - registerShuffleRetryWait = RssConf.registerShuffleRetryWait(conf); - maxInFlight = RssConf.pushMaxReqsInFlight(conf); - pushBufferSize = RssConf.pushBufferMaxSize(conf); + registerShuffleMaxRetries = conf.registerShuffleMaxRetry(); + registerShuffleRetryWait = conf.registerShuffleRetryWait(); + maxInFlight = conf.pushMaxReqsInFlight(); + pushBufferMaxSize = conf.pushBufferMaxSize(); // init rpc env and master endpointRef rpcEnv = RpcEnv.create("ShuffleClient", Utils.localHostName(), 0, conf); @@ -801,7 +801,7 @@ public class ShuffleClientImpl extends ShuffleClient { while (!batchesArr.isEmpty()) { limitMaxInFlight(mapKey, pushState, maxInFlight); Map.Entry entry = batchesArr.get(rand.nextInt(batchesArr.size())); - ArrayList batches = entry.getValue().requireBatches(pushBufferSize); + ArrayList batches = entry.getValue().requireBatches(pushBufferMaxSize); if (entry.getValue().getTotalSize() == 0) { batchesArr.remove(entry); } diff --git a/client/src/main/java/org/apache/celeborn/client/compress/Compressor.java b/client/src/main/java/org/apache/celeborn/client/compress/Compressor.java index bee3361c7..3c6130315 100644 --- a/client/src/main/java/org/apache/celeborn/client/compress/Compressor.java +++ b/client/src/main/java/org/apache/celeborn/client/compress/Compressor.java @@ -38,7 +38,7 @@ public interface Compressor { static Compressor getCompressor(RssConf conf) { String codec = RssConf.compressionCodec(conf); - int blockSize = RssConf.pushBufferMaxSize(conf); + int blockSize = conf.pushBufferMaxSize(); switch (codec) { case "lz4": return new RssLz4Compressor(blockSize); diff --git a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java index 313b46841..c9fed804e 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java @@ -44,8 +44,8 @@ import org.apache.celeborn.common.util.ShuffleBlockInfoUtils; import org.apache.celeborn.common.util.Utils; public class DfsPartitionReader implements PartitionReader { - private final int chunkSize; - private final int maxInFlight; + private final int shuffleChunkSize; + private final int fetchMaxReqsInFlight; private final LinkedBlockingQueue results; private final AtomicReference exception = new AtomicReference<>(); private volatile boolean closed = false; @@ -62,19 +62,19 @@ public class DfsPartitionReader implements PartitionReader { int startMapIndex, int endMapIndex) throws IOException { - chunkSize = (int) RssConf.shuffleChunkSize(conf); - maxInFlight = RssConf.fetchMaxReqsInFlight(conf); + shuffleChunkSize = (int) conf.shuffleChunkSize(); + fetchMaxReqsInFlight = conf.fetchMaxReqsInFlight(); results = new LinkedBlockingQueue<>(); final List chunkOffsets = new ArrayList<>(); if (endMapIndex != Integer.MAX_VALUE) { - long timeoutMs = RssConf.fetchTimeoutMs(conf); + long fetchTimeoutMs = conf.fetchTimeoutMs(); try { TransportClient client = clientFactory.createClient(location.getHost(), location.getFetchPort()); OpenStream openBlocks = new OpenStream(shuffleKey, location.getFileName(), startMapIndex, endMapIndex); - ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), timeoutMs); + ByteBuffer response = client.sendRpcSync(openBlocks.toByteBuffer(), fetchTimeoutMs); Message.decode(response); // Parse this message to ensure sort is done. } catch (IOException | InterruptedException e) { @@ -100,7 +100,7 @@ public class DfsPartitionReader implements PartitionReader { () -> { try { while (!closed && currentChunkIndex.get() < numChunks) { - while (results.size() >= maxInFlight) { + while (results.size() >= fetchMaxReqsInFlight) { Thread.sleep(50); } long offset = chunkOffsets.get(currentChunkIndex.get()); @@ -148,7 +148,7 @@ public class DfsPartitionReader implements PartitionReader { ShuffleBlockInfoUtils.getChunkOffsetsFromShuffleBlockInfos( startMapIndex, endMapIndex, - chunkSize, + shuffleChunkSize, ShuffleBlockInfoUtils.parseShuffleBlockInfosFromByteBuffer(indexBuffer))); indexInputStream.close(); return offsets; diff --git a/client/src/main/java/org/apache/celeborn/client/read/RetryingChunkClient.java b/client/src/main/java/org/apache/celeborn/client/read/RetryingChunkClient.java index fd2153155..c61fc132b 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/RetryingChunkClient.java +++ b/client/src/main/java/org/apache/celeborn/client/read/RetryingChunkClient.java @@ -84,19 +84,21 @@ public class RetryingChunkClient { this.callback = callback; this.retryWaitMs = transportConf.ioRetryWaitTimeMs(); - long timeoutMs = RssConf.fetchTimeoutMs(conf); + long fetchTimeoutMs = conf.fetchTimeoutMs(); if (location == null) { throw new IllegalArgumentException("Must contain at least one available PartitionLocation."); } else { Replica main = - new Replica(timeoutMs, shuffleKey, location, clientFactory, startMapIndex, endMapIndex); + new Replica( + fetchTimeoutMs, shuffleKey, location, clientFactory, startMapIndex, endMapIndex); PartitionLocation peerLoc = location.getPeer(); if (peerLoc == null) { replicas = new Replica[] {main}; } else { Replica peer = - new Replica(timeoutMs, shuffleKey, peerLoc, clientFactory, startMapIndex, endMapIndex); + new Replica( + fetchTimeoutMs, shuffleKey, peerLoc, clientFactory, startMapIndex, endMapIndex); replicas = new Replica[] {main, peer}; } } diff --git a/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java b/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java index 48c52d5c2..40bba49bf 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java +++ b/client/src/main/java/org/apache/celeborn/client/read/RssInputStream.java @@ -138,7 +138,7 @@ public abstract class RssInputStream extends InputStream { this.rangeReadFilter = RssConf.rangeReadFilterEnabled(conf); int headerLen = Decompressor.getCompressionHeaderLength(conf); - int blockSize = RssConf.pushBufferMaxSize(conf) + headerLen; + int blockSize = conf.pushBufferMaxSize() + headerLen; compressedBuf = new byte[blockSize]; decompressedBuf = new byte[blockSize]; diff --git a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java index 32fe41b78..556b8b085 100644 --- a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java +++ b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java @@ -45,7 +45,7 @@ public class WorkerPartitionReader implements PartitionReader { private final LinkedBlockingQueue results; private final AtomicReference exception = new AtomicReference<>(); - private final int maxInFlight; + private final int fetchMaxReqsInFlight; private boolean closed = false; WorkerPartitionReader( @@ -56,7 +56,7 @@ public class WorkerPartitionReader implements PartitionReader { int startMapIndex, int endMapIndex) throws IOException { - maxInFlight = RssConf.fetchMaxReqsInFlight(conf); + fetchMaxReqsInFlight = conf.fetchMaxReqsInFlight(); results = new LinkedBlockingQueue<>(); // only add the buffer to results queue if this reader is not closed. ChunkReceivedCallback callback = @@ -123,8 +123,8 @@ public class WorkerPartitionReader implements PartitionReader { private void fetchChunks() { final int inFlight = chunkIndex - returnedChunks; - if (inFlight < maxInFlight) { - final int toFetch = Math.min(maxInFlight - inFlight + 1, numChunks - chunkIndex); + if (inFlight < fetchMaxReqsInFlight) { + final int toFetch = Math.min(fetchMaxReqsInFlight - inFlight + 1, numChunks - chunkIndex); for (int i = 0; i < toFetch; i++) { client.fetchChunk(chunkIndex++); } diff --git a/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java b/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java index 59a06ffc4..53ffd0dcb 100644 --- a/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java +++ b/client/src/main/java/org/apache/celeborn/client/write/DataPusher.java @@ -65,15 +65,15 @@ public class DataPusher { Consumer afterPush, LongAdder[] mapStatusLengths) throws IOException { - final int capacity = RssConf.pushQueueCapacity(conf); - final int bufferSize = RssConf.pushBufferMaxSize(conf); + final int pushQueueCapacity = conf.pushQueueCapacity(); + final int pushBufferMaxSize = conf.pushBufferMaxSize(); - idleQueue = new LinkedBlockingQueue<>(capacity); - workingQueue = new LinkedBlockingQueue<>(capacity); + idleQueue = new LinkedBlockingQueue<>(pushQueueCapacity); + workingQueue = new LinkedBlockingQueue<>(pushQueueCapacity); - for (int i = 0; i < capacity; i++) { + for (int i = 0; i < pushQueueCapacity; i++) { try { - idleQueue.put(new PushTask(bufferSize)); + idleQueue.put(new PushTask(pushBufferMaxSize)); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IOException(e); diff --git a/client/src/main/java/org/apache/celeborn/client/write/PushState.java b/client/src/main/java/org/apache/celeborn/client/write/PushState.java index 9de9878ec..ef1279094 100644 --- a/client/src/main/java/org/apache/celeborn/client/write/PushState.java +++ b/client/src/main/java/org/apache/celeborn/client/write/PushState.java @@ -34,7 +34,7 @@ import org.apache.celeborn.common.protocol.PartitionLocation; public class PushState { private static final Logger logger = LoggerFactory.getLogger(PushState.class); - private int pushBufferSize; + private int pushBufferMaxSize; public final AtomicInteger batchId = new AtomicInteger(); public final ConcurrentHashMap inFlightBatches = @@ -43,7 +43,7 @@ public class PushState { public AtomicReference exception = new AtomicReference<>(); public PushState(RssConf conf) { - pushBufferSize = RssConf.pushBufferMaxSize(conf); + pushBufferMaxSize = conf.pushBufferMaxSize(); } public void addFuture(int batchId, ChannelFuture future) { @@ -82,7 +82,7 @@ public class PushState { public boolean addBatchData(String addressPair, PartitionLocation loc, int batchId, byte[] body) { DataBatches batches = batchesMap.computeIfAbsent(addressPair, (s) -> new DataBatches()); batches.addDataBatch(loc, batchId, body); - return batches.getTotalSize() > pushBufferSize; + return batches.getTotalSize() > pushBufferMaxSize; } public DataBatches takeDataBatches(String addressPair) { diff --git a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala index 89537f1f1..9b8a7742a 100644 --- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala +++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala @@ -44,9 +44,9 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit private val lifecycleHost = Utils.localHostName - private val RemoveShuffleDelayMs = RssConf.shuffleExpiredCheckIntervalMs(conf) - private val GetBlacklistDelayMs = RssConf.workerExcludedCheckIntervalMs(conf) - private val ShouldReplicate = RssConf.pushReplicateEnabled(conf) + private val shuffleExpiredCheckIntervalMs = conf.shuffleExpiredCheckIntervalMs + private val workerExcludedCheckIntervalMs = conf.workerExcludedCheckIntervalMs + private val pushReplicateEnabled = conf.pushReplicateEnabled private val splitThreshold = RssConf.partitionSplitThreshold(conf) private val splitMode = RssConf.partitionSplitMode(conf) private val partitionType = RssConf.partitionType(conf) @@ -114,8 +114,9 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit private var getBlacklist: ScheduledFuture[_] = _ // Use independent app heartbeat threads to avoid being blocked by other operations. - private val heartbeatIntervalMs = RssConf.appHeartbeatIntervalMs(conf) - private val heartbeatThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("app-heartbeat") + private val appHeartbeatIntervalMs = conf.appHeartbeatIntervalMs + private val appHeartbeatHandlerThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("app-heartbeat") private var appHeartbeat: ScheduledFuture[_] = _ private val responseCheckerThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("rss-master-resp-checker") @@ -154,7 +155,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit // `rssHARetryClient` is called. Therefore, it's necessary to uniformly execute the initialization // method at the end of the construction of the class to perform the initialization operations. private def initialize(): Unit = { - appHeartbeat = heartbeatThread.scheduleAtFixedRate( + appHeartbeat = appHeartbeatHandlerThread.scheduleAtFixedRate( new Runnable { override def run(): Unit = { try { @@ -177,7 +178,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit } }, 0, - heartbeatIntervalMs, + appHeartbeatIntervalMs, TimeUnit.MILLISECONDS) handleChangePartitionInBatchSchedulerThread.foreach { @@ -228,8 +229,8 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit self.send(RemoveExpiredShuffle) } }, - RemoveShuffleDelayMs, - RemoveShuffleDelayMs, + shuffleExpiredCheckIntervalMs, + shuffleExpiredCheckIntervalMs, TimeUnit.MILLISECONDS) getBlacklist = forwardMessageThread.scheduleAtFixedRate( @@ -238,8 +239,8 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit self.send(GetBlacklist(new util.ArrayList[WorkerInfo](blacklist))) } }, - GetBlacklistDelayMs, - GetBlacklistDelayMs, + workerExcludedCheckIntervalMs, + workerExcludedCheckIntervalMs, TimeUnit.MILLISECONDS) } @@ -251,7 +252,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit ThreadUtils.shutdown(forwardMessageThread, 800.millis) appHeartbeat.cancel(true) - ThreadUtils.shutdown(heartbeatThread, 800.millis) + ThreadUtils.shutdown(appHeartbeatHandlerThread, 800.millis) ThreadUtils.shutdown(responseCheckerThread, 800.millis) @@ -430,7 +431,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit val connectFailedWorkers = ConcurrentHashMap.newKeySet[WorkerInfo]() // Second, for each worker, try to initialize the endpoint. - val parallelism = Math.min(Math.max(1, slots.size()), RssConf.rpcMaxParallelism(conf)) + val parallelism = Math.min(Math.max(1, slots.size()), conf.rpcMaxParallelism) ThreadUtils.parmap(slots.asScala.to, "InitWorkerRef", parallelism) { case (workerInfo, _) => try { workerInfo.endpoint = @@ -657,7 +658,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit } val candidates = workersNotBlacklisted(shuffleId) - if (candidates.size < 1 || (ShouldReplicate && candidates.size < 2)) { + if (candidates.size < 1 || (pushReplicateEnabled && candidates.size < 2)) { logError("[Update partition] failed for not enough candidates for revive.") replyFailure(ChangeLocationResponse(StatusCode.SLOT_NOT_AVAILABLE, None)) return @@ -822,7 +823,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit val currentShuffleFileCount = new LongAdder val commitFileStartTime = System.nanoTime() - val parallelism = Math.min(workerSnapshots(shuffleId).size(), RssConf.rpcMaxParallelism(conf)) + val parallelism = Math.min(workerSnapshots(shuffleId).size(), conf.rpcMaxParallelism) ThreadUtils.parmap( allocatedWorkers.asScala.to, "CommitFiles", @@ -897,7 +898,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit def hasCommitFailedIds: Boolean = { val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId) - if (!ShouldReplicate && failedMasterPartitionIds.size() != 0) { + if (!pushReplicateEnabled && failedMasterPartitionIds.size() != 0) { val msg = failedMasterPartitionIds.asScala.map { case (partitionId, workerInfo) => s"Lost partition $partitionId in worker [${workerInfo.readableAddress()}]" }.mkString("\n") @@ -1054,7 +1055,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit shuffleId: Int, slots: WorkerResource): util.List[WorkerInfo] = { val reserveSlotFailedWorkers = ConcurrentHashMap.newKeySet[WorkerInfo]() - val parallelism = Math.min(Math.max(1, slots.size()), RssConf.rpcMaxParallelism(conf)) + val parallelism = Math.min(Math.max(1, slots.size()), conf.rpcMaxParallelism) ThreadUtils.parmap(slots.asScala.to, "ReserveSlot", parallelism) { case (workerInfo, (masterLocations, slaveLocations)) => val res = requestReserveSlots( @@ -1210,14 +1211,14 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit candidates: List[WorkerInfo], slots: WorkerResource): Boolean = { var requestSlots = slots - val maxRetryTimes = RssConf.reserveSlotsMaxRetries(conf) - val retryWaitInterval = RssConf.reserveSlotsRetryWait(conf) + val reserveSlotsMaxRetries = conf.reserveSlotsMaxRetries + val reserveSlotsRetryWait = conf.reserveSlotsRetryWait var retryTimes = 1 var noAvailableSlots = false var success = false - while (retryTimes <= maxRetryTimes && !success && !noAvailableSlots) { + while (retryTimes <= reserveSlotsMaxRetries && !success && !noAvailableSlots) { if (retryTimes > 1) { - Thread.sleep(retryWaitInterval) + Thread.sleep(reserveSlotsRetryWait) } // reserve buffers logInfo(s"Try reserve slots for ${Utils.makeShuffleKey(applicationId, shuffleId)} " + @@ -1231,17 +1232,17 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit val failedPartitionLocations = getFailedPartitionLocations(reserveFailedWorkers, slots) // When enable replicate, if one of the partition location reserve slots failed, we also // need to release another corresponding partition location and remove it from slots. - if (ShouldReplicate && failedPartitionLocations.nonEmpty && !slots.isEmpty) { + if (pushReplicateEnabled && failedPartitionLocations.nonEmpty && !slots.isEmpty) { releasePeerPartitionLocation(applicationId, shuffleId, slots, failedPartitionLocations) } - if (retryTimes < maxRetryTimes) { + if (retryTimes < reserveSlotsMaxRetries) { // get retryCandidates resource and retry reserve buffer val retryCandidates = new util.HashSet(slots.keySet()) // add candidates to avoid revive action passed in slots only 2 worker retryCandidates.addAll(candidates.asJava) // remove blacklist from retryCandidates retryCandidates.removeAll(blacklist) - if (retryCandidates.size < 1 || (ShouldReplicate && retryCandidates.size < 2)) { + if (retryCandidates.size < 1 || (pushReplicateEnabled && retryCandidates.size < 2)) { logError("Retry reserve slots failed caused by not enough slots.") noAvailableSlots = true } else { @@ -1259,7 +1260,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit } } } else { - logError(s"Try reserve slots failed after $maxRetryTimes retry.") + logError(s"Try reserve slots failed after $reserveSlotsMaxRetries retry.") } } retryTimes += 1 @@ -1307,7 +1308,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit candidates(masterIndex).replicatePort, PartitionLocation.Mode.MASTER) - if (ShouldReplicate) { + if (pushReplicateEnabled) { val slaveIndex = (masterIndex + 1) % candidates.size val slaveLocation = new PartitionLocation( id, @@ -1381,7 +1382,7 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit private def removeExpiredShuffle(): Unit = { val currentTime = System.currentTimeMillis() unregisterShuffleTime.keys().asScala.foreach { shuffleId => - if (unregisterShuffleTime.get(shuffleId) < currentTime - RemoveShuffleDelayMs) { + if (unregisterShuffleTime.get(shuffleId) < currentTime - shuffleExpiredCheckIntervalMs) { logInfo(s"Clear shuffle $shuffleId.") // clear for the shuffle registeredShuffle.remove(shuffleId) @@ -1422,7 +1423,13 @@ class LifecycleManager(appId: String, val conf: RssConf) extends RpcEndpoint wit shuffleId: Int, ids: util.ArrayList[Integer]): RequestSlotsResponse = { val req = - RequestSlots(applicationId, shuffleId, ids, lifecycleHost, ShouldReplicate, userIdentifier) + RequestSlots( + applicationId, + shuffleId, + ids, + lifecycleHost, + pushReplicateEnabled, + userIdentifier) val res = requestRequestSlots(rssHARetryClient, req) if (res.status != StatusCode.SUCCESS) { requestRequestSlots(rssHARetryClient, req) diff --git a/client/src/test/java/org/apache/celeborn/client/compress/CodecSuite.java b/client/src/test/java/org/apache/celeborn/client/compress/CodecSuite.java index 3a4a30cc2..863fd4cbd 100644 --- a/client/src/test/java/org/apache/celeborn/client/compress/CodecSuite.java +++ b/client/src/test/java/org/apache/celeborn/client/compress/CodecSuite.java @@ -29,7 +29,7 @@ public class CodecSuite { @Test public void testLz4Codec() { - int blockSize = RssConf.pushBufferMaxSize(new RssConf()); + int blockSize = (new RssConf()).pushBufferMaxSize(); RssLz4Compressor rssLz4Compressor = new RssLz4Compressor(blockSize); byte[] data = RandomStringUtils.random(1024).getBytes(StandardCharsets.UTF_8); int oriLength = data.length; @@ -49,7 +49,7 @@ public class CodecSuite { public void testZstdCodec() { for (int level = -5; level <= 22; level++) { System.out.println("level is " + level); - int blockSize = RssConf.pushBufferMaxSize(new RssConf()); + int blockSize = (new RssConf()).pushBufferMaxSize(); RssZstdCompressor rssZstdCompressor = new RssZstdCompressor(blockSize, level); byte[] data = RandomStringUtils.random(1024).getBytes(StandardCharsets.UTF_8); int oriLength = data.length; diff --git a/common/src/main/java/org/apache/celeborn/common/haclient/RssHARetryClient.java b/common/src/main/java/org/apache/celeborn/common/haclient/RssHARetryClient.java index a7d7899a6..58503ddc9 100644 --- a/common/src/main/java/org/apache/celeborn/common/haclient/RssHARetryClient.java +++ b/common/src/main/java/org/apache/celeborn/common/haclient/RssHARetryClient.java @@ -63,7 +63,7 @@ public class RssHARetryClient { public RssHARetryClient(RpcEnv rpcEnv, RssConf conf) { this.rpcEnv = rpcEnv; - this.masterEndpoints = RssConf.masterEndpoints(conf); + this.masterEndpoints = conf.masterEndpoints(); this.maxTries = Math.max(masterEndpoints.length, RssConf.haClientMaxTries(conf)); this.rpcTimeout = RpcUtils.haClientAskRpcTimeout(conf); this.rpcEndpointRef = new AtomicReference<>(); diff --git a/common/src/main/scala/org/apache/celeborn/common/RssConf.scala b/common/src/main/scala/org/apache/celeborn/common/RssConf.scala index dd912e7ec..22f952c94 100644 --- a/common/src/main/scala/org/apache/celeborn/common/RssConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/RssConf.scala @@ -364,6 +364,139 @@ class RssConf(loadDefaults: Boolean) extends Cloneable with Logging with Seriali } } + // ////////////////////////////////////////////////////// + // Master // + // ////////////////////////////////////////////////////// + + // ////////////////////////////////////////////////////// + // Worker // + // ////////////////////////////////////////////////////// + def workerHeartbeatTimeoutMs: Long = get(WORKER_HEARTBEAT_TIMEOUT) + def workerReplicateThreads: Int = get(WORKER_REPLICATE_THREADS) + def workerCommitThreads: Int = get(WORKER_COMMIT_THREADS) + def shuffleCommitTimeout: Long = get(WORKER_SHUFFLE_COMMIT_TIMEOUT) + + // ////////////////////////////////////////////////////// + // Client // + // ////////////////////////////////////////////////////// + def shuffleWriterMode: String = get(SHUFFLE_WRITER_MODE) + def shuffleChunkSize: Long = get(SHUFFLE_CHUCK_SIZE) + def registerShuffleMaxRetry: Int = get(SHUFFLE_REGISTER_MAX_RETRIES) + def registerShuffleRetryWait: Long = get(SHUFFLE_REGISTER_RETRY_WAIT) + def reserveSlotsMaxRetries: Int = get(RESERVE_SLOTS_MAX_RETRIES) + def reserveSlotsRetryWait: Long = get(RESERVE_SLOTS_RETRY_WAIT) + def rpcMaxParallelism: Int = get(CLIENT_RPC_MAX_PARALLELISM) + def appHeartbeatTimeoutMs: Long = get(APPLICATION_HEARTBEAT_TIMEOUT) + def appHeartbeatIntervalMs: Long = get(APPLICATION_HEARTBEAT_INTERVAL) + def shuffleExpiredCheckIntervalMs: Long = get(SHUFFLE_EXPIRED_CHECK_INTERVAL) + def workerExcludedCheckIntervalMs: Long = get(WORKER_EXCLUDED_INTERVAL) + + // ////////////////////////////////////////////////////// + // Address && HA && RATIS // + // ////////////////////////////////////////////////////// + def masterEndpoints: Array[String] = + get(MASTER_ENDPOINTS).toArray.map { endpoint => + Utils.parseHostPort(endpoint) match { + case (host, 0) => s"$host:${HA_MASTER_NODE_PORT.defaultValue.get}" + case (host, port) => s"$host:$port" + } + } + + def masterHost: String = get(MASTER_HOST) + + def masterPort: Int = get(MASTER_PORT) + + def haEnabled: Boolean = get(HA_ENABLED) + + def haMasterNodeId: Option[String] = get(HA_MASTER_NODE_ID) + + def haMasterNodeIds: Array[String] = { + def extractPrefix(original: String, stop: String): String = { + val i = original.indexOf(stop) + assert(i >= 0, s"$original does not contain $stop") + original.substring(0, i) + } + + val nodeConfPrefix = extractPrefix(HA_MASTER_NODE_HOST.key, "") + getAllWithPrefix(nodeConfPrefix) + .map(_._1) + .map(k => extractPrefix(k, ".")) + .distinct + } + + def haMasterNodeHost(nodeId: String): String = { + val key = HA_MASTER_NODE_HOST.key.replace("", nodeId) + get(key, Utils.localHostName) + } + + def haMasterNodePort(nodeId: String): Int = { + val key = HA_MASTER_NODE_PORT.key.replace("", nodeId) + getInt(key, HA_MASTER_NODE_PORT.defaultValue.get) + } + + def haMasterRatisHost(nodeId: String): String = { + val key = HA_MASTER_NODE_RATIS_HOST.key.replace("", nodeId) + val fallbackKey = HA_MASTER_NODE_HOST.key.replace("", nodeId) + get(key, get(fallbackKey)) + } + + def haMasterRatisPort(nodeId: String): Int = { + val key = HA_MASTER_NODE_RATIS_PORT.key.replace("", nodeId) + getInt(key, HA_MASTER_NODE_RATIS_PORT.defaultValue.get) + } + + def haMasterRatisRpcType: String = get(HA_MASTER_RATIS_RPC_TYPE) + def haMasterRatisStorageDir: String = get(HA_MASTER_RATIS_STORAGE_DIR) + def haMasterRatisLogSegmentSizeMax: Long = get(HA_MASTER_RATIS_LOG_SEGMENT_SIZE_MAX) + def haMasterRatisLogPreallocatedSize: Long = get(HA_MASTER_RATIS_LOG_PREALLOCATED_SIZE) + def haMasterRatisLogAppenderQueueNumElements: Int = + get(HA_MASTER_RATIS_LOG_APPENDER_QUEUE_NUM_ELEMENTS) + def haMasterRatisLogAppenderQueueBytesLimit: Long = + get(HA_MASTER_RATIS_LOG_APPENDER_QUEUE_BYTE_LIMIT) + def haMasterRatisLogPurgeGap: Int = get(HA_MASTER_RATIS_LOG_PURGE_GAP) + def haMasterRatisRpcRequestTimeout: Long = get(HA_MASTER_RATIS_RPC_REQUEST_TIMEOUT) + def haMasterRatisRetryCacheExpiryTime: Long = get(HA_MASTER_RATIS_SERVER_RETRY_CACHE_EXPIRY_TIME) + def haMasterRatisRpcTimeoutMin: Long = get(HA_MASTER_RATIS_RPC_TIMEOUT_MIN) + def haMasterRatisRpcTimeoutMax: Long = get(HA_MASTER_RATIS_RPC_TIMEOUT_MAX) + def haMasterRatisNotificationNoLeaderTimeout: Long = + get(HA_MASTER_RATIS_NOTIFICATION_NO_LEADER_TIMEOUT) + def haMasterRatisRpcSlownessTimeout: Long = get(HA_MASTER_RATIS_RPC_SLOWNESS_TIMEOUT) + def haMasterRatisRoleCheckInterval: Long = get(HA_MASTER_RATIS_ROLE_CHECK_INTERVAL) + def haMasterRatisSnapshotAutoTriggerEnabled: Boolean = + get(HA_MASTER_RATIS_SNAPSHOT_AUTO_TRIGGER_ENABLED) + def haMasterRatisSnapshotAutoTriggerThreshold: Long = + get(HA_MASTER_RATIS_SNAPSHOT_AUTO_TRIGGER_THRESHOLD) + def haMasterRatisSnapshotRetentionFileNum: Int = get(HA_MASTER_RATIS_SNAPSHOT_RETENTION_FILE_NUM) + + // ////////////////////////////////////////////////////// + // Metrics System // + // ////////////////////////////////////////////////////// + def metricsSystemEnable: Boolean = get(METRICS_ENABLED) + def metricsSampleRate: Double = get(METRICS_SAMPLE_RATE) + def metricsSlidingWindowSize: Int = get(METRICS_SLIDING_WINDOW_SIZE) + def masterPrometheusMetricHost: String = get(MASTER_PROMETHEUS_HOST) + def masterPrometheusMetricPort: Int = get(MASTER_PROMETHEUS_PORT) + def workerPrometheusMetricHost: String = get(WORKER_PROMETHEUS_HOST) + def workerPrometheusMetricPort: Int = get(WORKER_PROMETHEUS_PORT) + + // ////////////////////////////////////////////////////// + // Shuffle Client Fetch // + // ////////////////////////////////////////////////////// + def fetchTimeoutMs: Long = get(FETCH_TIMEOUT) + def fetchMaxReqsInFlight: Int = get(FETCH_MAX_REQS_IN_FLIGHT) + + // ////////////////////////////////////////////////////// + // Shuffle Client Push // + // ////////////////////////////////////////////////////// + def pushReplicateEnabled: Boolean = get(PUSH_REPLICATE_ENABLED) + def pushBufferInitialSize: Int = get(PUSH_BUFFER_INITIAL_SIZE).toInt + def pushBufferMaxSize: Int = get(PUSH_BUFFER_MAX_SIZE).toInt + def pushQueueCapacity: Int = get(PUSH_QUEUE_CAPACITY) + def pushMaxReqsInFlight: Int = get(PUSH_MAX_REQS_IN_FLIGHT) + + // ////////////////////////////////////////////////////// + // GraceFul Shutdown & Recover // + // ////////////////////////////////////////////////////// def workerGracefulShutdown: Boolean = get(WORKER_GRACEFUL_SHUTDOWN_ENABLED) def shutdownTimeoutMs: Long = get(WORKER_GRACEFUL_SHUTDOWN_TIMEOUT) def checkSlotsFinishedInterval: Long = get(WORKER_CHECK_SLOTS_FINISHED_INTERVAL) @@ -371,7 +504,11 @@ class RssConf(loadDefaults: Boolean) extends Cloneable with Logging with Seriali def workerRecoverPath: String = get(WORKER_RECOVER_PATH) def partitionSorterCloseAwaitTimeMs: Long = get(PARTITION_SORTER_SHUTDOWN_TIMEOUT) def workerFlusherShutdownTimeoutMs: Long = get(WORKER_FLUSHER_SHUTDOWN_TIMEOUT) - def shuffleCommitTimeout: Long = get(WORKER_SHUFFLE_COMMIT_TIMEOUT) + + // ////////////////////////////////////////////////////// + // Flusher // + // ////////////////////////////////////////////////////// + def workerFlusherBufferSize: Long = get(WORKER_FLUSHER_BUFFER_SIZE) def writerCloseTimeoutMs: Long = get(WORKER_WRITER_CLOSE_TIMEOUT) def hddFlusherThreads: Int = get(WORKER_FLUSHER_HDD_THREADS) def ssdFlusherThreads: Int = get(WORKER_FLUSHER_SSD_THREADS) @@ -797,40 +934,6 @@ object RssConf extends Logging { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("3s") - def pushReplicateEnabled(conf: RssConf): Boolean = conf.get(PUSH_REPLICATE_ENABLED) - - def pushBufferInitialSize(conf: RssConf): Int = conf.get(PUSH_BUFFER_INITIAL_SIZE).toInt - - def pushBufferMaxSize(conf: RssConf): Int = conf.get(PUSH_BUFFER_MAX_SIZE).toInt - - def pushQueueCapacity(conf: RssConf): Int = conf.get(PUSH_QUEUE_CAPACITY) - - def pushMaxReqsInFlight(conf: RssConf): Int = conf.get(PUSH_MAX_REQS_IN_FLIGHT) - - def fetchTimeoutMs(conf: RssConf): Long = conf.get(FETCH_TIMEOUT) - - def fetchMaxReqsInFlight(conf: RssConf): Int = conf.get(FETCH_MAX_REQS_IN_FLIGHT) - - def rpcMaxParallelism(conf: RssConf): Int = conf.get(CLIENT_RPC_MAX_PARALLELISM) - - def appHeartbeatTimeoutMs(conf: RssConf): Long = conf.get(APPLICATION_HEARTBEAT_TIMEOUT) - - def appHeartbeatIntervalMs(conf: RssConf): Long = conf.get(APPLICATION_HEARTBEAT_INTERVAL) - - def shuffleExpiredCheckIntervalMs(conf: RssConf): Long = conf.get(SHUFFLE_EXPIRED_CHECK_INTERVAL) - - def workerExcludedCheckIntervalMs(conf: RssConf): Long = conf.get(WORKER_EXCLUDED_INTERVAL) - - def shuffleChunkSize(conf: RssConf): Long = conf.get(SHUFFLE_CHUCK_SIZE) - - def registerShuffleMaxRetry(conf: RssConf): Int = conf.get(SHUFFLE_REGISTER_MAX_RETRIES) - - def registerShuffleRetryWait(conf: RssConf): Long = conf.get(SHUFFLE_REGISTER_RETRY_WAIT) - - def reserveSlotsMaxRetries(conf: RssConf): Int = conf.get(RESERVE_SLOTS_MAX_RETRIES) - - def reserveSlotsRetryWait(conf: RssConf): Long = conf.get(RESERVE_SLOTS_RETRY_WAIT) - val MASTER_HOST: ConfigEntry[String] = buildConf("celeborn.master.host") .categories("master") @@ -1055,102 +1158,6 @@ object RssConf extends Logging { .intConf .createWithDefault(3) - def masterEndpoints(conf: RssConf): Array[String] = - conf.get(MASTER_ENDPOINTS).toArray.map { endpoint => - Utils.parseHostPort(endpoint) match { - case (host, 0) => s"$host:${HA_MASTER_NODE_PORT.defaultValue.get}" - case (host, port) => s"$host:$port" - } - } - - def masterHost(conf: RssConf): String = conf.get(MASTER_HOST) - - def masterPort(conf: RssConf): Int = conf.get(MASTER_PORT) - - def haEnabled(conf: RssConf): Boolean = conf.get(HA_ENABLED) - - def haMasterNodeId(conf: RssConf): Option[String] = conf.get(HA_MASTER_NODE_ID) - - def haMasterNodeIds(conf: RssConf): Array[String] = { - def extractPrefix(original: String, stop: String): String = { - val i = original.indexOf(stop) - assert(i >= 0, s"$original does not contain $stop") - original.substring(0, i) - } - val nodeConfPrefix = extractPrefix(HA_MASTER_NODE_HOST.key, "") - conf.getAllWithPrefix(nodeConfPrefix) - .map(_._1) - .map(k => extractPrefix(k, ".")) - .distinct - } - - def haMasterNodeHost(conf: RssConf, nodeId: String): String = { - val key = HA_MASTER_NODE_HOST.key.replace("", nodeId) - conf.get(key, Utils.localHostName) - } - - def haMasterNodePort(conf: RssConf, nodeId: String): Int = { - val key = HA_MASTER_NODE_PORT.key.replace("", nodeId) - conf.getInt(key, HA_MASTER_NODE_PORT.defaultValue.get) - } - - def haMasterRatisHost(conf: RssConf, nodeId: String): String = { - val key = HA_MASTER_NODE_RATIS_HOST.key.replace("", nodeId) - val fallbackKey = HA_MASTER_NODE_HOST.key.replace("", nodeId) - conf.get(key, conf.get(fallbackKey)) - } - - def haMasterRatisPort(conf: RssConf, nodeId: String): Int = { - val key = HA_MASTER_NODE_RATIS_PORT.key.replace("", nodeId) - conf.getInt(key, HA_MASTER_NODE_RATIS_PORT.defaultValue.get) - } - - def haMasterRatisRpcType(conf: RssConf): String = conf.get(HA_MASTER_RATIS_RPC_TYPE) - - def haMasterRatisStorageDir(conf: RssConf): String = conf.get(HA_MASTER_RATIS_STORAGE_DIR) - - def haMasterRatisLogSegmentSizeMax(conf: RssConf): Long = - conf.get(HA_MASTER_RATIS_LOG_SEGMENT_SIZE_MAX) - - def haMasterRatisLogPreallocatedSize(conf: RssConf): Long = - conf.get(HA_MASTER_RATIS_LOG_PREALLOCATED_SIZE) - - def haMasterRatisLogAppenderQueueNumElements(conf: RssConf): Int = - conf.get(HA_MASTER_RATIS_LOG_APPENDER_QUEUE_NUM_ELEMENTS) - - def haMasterRatisLogAppenderQueueBytesLimit(conf: RssConf): Long = - conf.get(HA_MASTER_RATIS_LOG_APPENDER_QUEUE_BYTE_LIMIT) - - def haMasterRatisLogPurgeGap(conf: RssConf): Int = conf.get(HA_MASTER_RATIS_LOG_PURGE_GAP) - - def haMasterRatisRpcRequestTimeout(conf: RssConf): Long = - conf.get(HA_MASTER_RATIS_RPC_REQUEST_TIMEOUT) - - def haMasterRatisRetryCacheExpiryTime(conf: RssConf): Long = - conf.get(HA_MASTER_RATIS_SERVER_RETRY_CACHE_EXPIRY_TIME) - - def haMasterRatisRpcTimeoutMin(conf: RssConf): Long = conf.get(HA_MASTER_RATIS_RPC_TIMEOUT_MIN) - - def haMasterRatisRpcTimeoutMax(conf: RssConf): Long = conf.get(HA_MASTER_RATIS_RPC_TIMEOUT_MAX) - - def haMasterRatisNotificationNoLeaderTimeout(conf: RssConf): Long = - conf.get(HA_MASTER_RATIS_NOTIFICATION_NO_LEADER_TIMEOUT) - - def haMasterRatisRpcSlownessTimeout(conf: RssConf): Long = - conf.get(HA_MASTER_RATIS_RPC_SLOWNESS_TIMEOUT) - - def haMasterRatisRoleCheckInterval(conf: RssConf): Long = - conf.get(HA_MASTER_RATIS_ROLE_CHECK_INTERVAL) - - def haMasterRatisSnapshotAutoTriggerEnabled(conf: RssConf): Boolean = - conf.get(HA_MASTER_RATIS_SNAPSHOT_AUTO_TRIGGER_ENABLED) - - def haMasterRatisSnapshotAutoTriggerThreshold(conf: RssConf): Long = - conf.get(HA_MASTER_RATIS_SNAPSHOT_AUTO_TRIGGER_THRESHOLD) - - def haMasterRatisSnapshotRetentionFileNum(conf: RssConf): Int = - conf.get(HA_MASTER_RATIS_SNAPSHOT_RETENTION_FILE_NUM) - val WORKER_HEARTBEAT_TIMEOUT: ConfigEntry[Long] = buildConf("celeborn.worker.heartbeat.timeout") .withAlternative("rss.worker.timeout") @@ -1196,13 +1203,6 @@ object RssConf extends Logging { .doc("Size of buffer used by a single flusher.") .bytesConf(ByteUnit.BYTE) .createWithDefaultString("256k") - def workerHeartbeatTimeoutMs(conf: RssConf): Long = conf.get(WORKER_HEARTBEAT_TIMEOUT) - - def workerReplicateThreads(conf: RssConf): Int = conf.get(WORKER_REPLICATE_THREADS) - - def workerCommitThreads(conf: RssConf): Int = conf.get(WORKER_COMMIT_THREADS) - - def workerFlusherBufferSize(conf: RssConf): Long = conf.get(WORKER_FLUSHER_BUFFER_SIZE) val WORKER_SHUFFLE_COMMIT_TIMEOUT: ConfigEntry[Long] = buildConf("celeborn.worker.shuffle.commit.timeout") @@ -1415,28 +1415,14 @@ object RssConf extends Logging { .checkValue(p => p >= 1024 && p < 65535, "invalid port") .createWithDefault(9096) - def metricsSystemEnable(conf: RssConf): Boolean = conf.get(METRICS_ENABLED) - - def metricsSampleRate(conf: RssConf): Double = conf.get(METRICS_SAMPLE_RATE) - def metricsSamplePerfCritical(conf: RssConf): Boolean = { conf.getBoolean("rss.metrics.system.sample.perf.critical", false) } - def metricsSlidingWindowSize(conf: RssConf): Int = conf.get(METRICS_SLIDING_WINDOW_SIZE) - def innerMetricsSize(conf: RssConf): Int = { conf.getInt("rss.inner.metrics.size", 4096) } - def masterPrometheusMetricHost(conf: RssConf): String = conf.get(MASTER_PROMETHEUS_HOST) - - def masterPrometheusMetricPort(conf: RssConf): Int = conf.get(MASTER_PROMETHEUS_PORT) - - def workerPrometheusMetricHost(conf: RssConf): String = conf.get(WORKER_PROMETHEUS_HOST) - - def workerPrometheusMetricPort(conf: RssConf): Int = conf.get(WORKER_PROMETHEUS_PORT) - def workerRPCPort(conf: RssConf): Int = { conf.getInt("rss.worker.rpc.port", 0) } @@ -1445,8 +1431,6 @@ object RssConf extends Logging { conf.getInt("rss.offer.slots.extra.size", 2) } - def shuffleWriterMode(conf: RssConf): String = conf.get(SHUFFLE_WRITER_MODE) - def sortPushThreshold(conf: RssConf): Long = { conf.getSizeAsBytes("rss.sort.push.data.threshold", "64m") } diff --git a/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala b/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala index f43144997..d0319b5db 100644 --- a/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala +++ b/common/src/main/scala/org/apache/celeborn/common/metrics/source/AbstractSource.scala @@ -38,21 +38,21 @@ case class NamedHistogram(name: String, histogram: Histogram) case class NamedTimer(name: String, timer: Timer) -abstract class AbstractSource(rssConf: RssConf, role: String) +abstract class AbstractSource(conf: RssConf, role: String) extends Source with Logging { override val metricRegistry = new MetricRegistry() - val slidingWindowSize: Int = RssConf.metricsSlidingWindowSize(rssConf) + val metricsSlidingWindowSize: Int = conf.metricsSlidingWindowSize - val sampleRate: Double = RssConf.metricsSampleRate(rssConf) + val metricsSampleRate: Double = conf.metricsSampleRate - val samplePerfCritical: Boolean = RssConf.metricsSamplePerfCritical(rssConf) + val samplePerfCritical: Boolean = RssConf.metricsSamplePerfCritical(conf) - final val InnerMetricsSize = RssConf.innerMetricsSize(rssConf) + final val innerMetricsSize = RssConf.innerMetricsSize(conf) val innerMetrics: ConcurrentLinkedQueue[String] = new ConcurrentLinkedQueue[String]() - val timerSupplier = new TimerSupplier(slidingWindowSize) + val timerSupplier = new TimerSupplier(metricsSlidingWindowSize) val metricsCleaner: ScheduledExecutorService = ThreadUtils.newDaemonSingleThreadScheduledExecutor(s"worker-metrics-cleaner") @@ -102,12 +102,12 @@ abstract class AbstractSource(rssConf: RssConf, role: String) } def needSample(): Boolean = { - if (sampleRate >= 1) { + if (metricsSampleRate >= 1) { true - } else if (sampleRate <= 0) { + } else if (metricsSampleRate <= 0) { false } else { - Random.nextDouble() <= sampleRate + Random.nextDouble() <= metricsSampleRate } } @@ -153,7 +153,7 @@ abstract class AbstractSource(rssConf: RssConf, role: String) startTime match { case Some(t) => namedTimer.timer.update(System.nanoTime() - t, TimeUnit.NANOSECONDS) - if (namedTimer.timer.getCount % slidingWindowSize == 0) { + if (namedTimer.timer.getCount % metricsSlidingWindowSize == 0) { recordTimer(namedTimer) } case None => @@ -199,7 +199,7 @@ abstract class AbstractSource(rssConf: RssConf, role: String) private def updateInnerMetrics(str: String): Unit = { innerMetrics.synchronized { - if (innerMetrics.size() >= InnerMetricsSize) { + if (innerMetrics.size() >= innerMetricsSize) { innerMetrics.remove() } innerMetrics.offer(str) diff --git a/common/src/test/scala/org/apache/celeborn/common/RssConfSuite.scala b/common/src/test/scala/org/apache/celeborn/common/RssConfSuite.scala index 56d57619b..612aa2bcd 100644 --- a/common/src/test/scala/org/apache/celeborn/common/RssConfSuite.scala +++ b/common/src/test/scala/org/apache/celeborn/common/RssConfSuite.scala @@ -18,7 +18,6 @@ package org.apache.celeborn.common import org.apache.celeborn.RssFunSuite -import org.apache.celeborn.common.RssConf.masterEndpoints import org.apache.celeborn.common.util.Utils class RssConfSuite extends RssFunSuite { @@ -26,39 +25,39 @@ class RssConfSuite extends RssFunSuite { test("celeborn.master.endpoints support multi nodes") { val conf = new RssConf() .set("celeborn.master.endpoints", "localhost1:9097,localhost2:9097") - val endpoints = masterEndpoints(conf) - assert(endpoints.length == 2) - assert(endpoints(0) == "localhost1:9097") - assert(endpoints(1) == "localhost2:9097") + val masterEndpoints = conf.masterEndpoints + assert(masterEndpoints.length == 2) + assert(masterEndpoints(0) == "localhost1:9097") + assert(masterEndpoints(1) == "localhost2:9097") } test("basedir test") { val conf = new RssConf() val defaultMaxUsableSpace = 1024L * 1024 * 1024 * 1024 * 1024 conf.set("celeborn.worker.storage.dirs", "/mnt/disk1") - val parsedDirs = conf.workerBaseDirs - assert(parsedDirs.size == 1) - assert(parsedDirs.head._3 == 1) - assert(parsedDirs.head._2 == defaultMaxUsableSpace) + val workerBaseDirs = conf.workerBaseDirs + assert(workerBaseDirs.size == 1) + assert(workerBaseDirs.head._3 == 1) + assert(workerBaseDirs.head._2 == defaultMaxUsableSpace) } test("basedir test2") { val conf = new RssConf() val defaultMaxUsableSpace = 1024L * 1024 * 1024 * 1024 * 1024 conf.set("celeborn.worker.storage.dirs", "/mnt/disk1:disktype=SSD:capacity=10g") - val parsedDirs = conf.workerBaseDirs - assert(parsedDirs.size == 1) - assert(parsedDirs.head._3 == 8) - assert(parsedDirs.head._2 == 10 * 1024 * 1024 * 1024L) + val workerBaseDirs = conf.workerBaseDirs + assert(workerBaseDirs.size == 1) + assert(workerBaseDirs.head._3 == 8) + assert(workerBaseDirs.head._2 == 10 * 1024 * 1024 * 1024L) } test("basedir test3") { val conf = new RssConf() conf.set("celeborn.worker.storage.dirs", "/mnt/disk1:disktype=SSD:capacity=10g:flushthread=3") - val parsedDirs = conf.workerBaseDirs - assert(parsedDirs.size == 1) - assert(parsedDirs.head._3 == 3) - assert(parsedDirs.head._2 == 10 * 1024 * 1024 * 1024L) + val workerBaseDirs = conf.workerBaseDirs + assert(workerBaseDirs.size == 1) + assert(workerBaseDirs.head._3 == 3) + assert(workerBaseDirs.head._2 == 10 * 1024 * 1024 * 1024L) } test("basedir test4") { @@ -67,15 +66,15 @@ class RssConfSuite extends RssFunSuite { "celeborn.worker.storage.dirs", "/mnt/disk1:disktype=SSD:capacity=10g:flushthread=3," + "/mnt/disk2:disktype=HDD:capacity=15g:flushthread=7") - val parsedDirs = conf.workerBaseDirs - assert(parsedDirs.size == 2) - assert(parsedDirs.head._1 == "/mnt/disk1") - assert(parsedDirs.head._3 == 3) - assert(parsedDirs.head._2 == 10 * 1024 * 1024 * 1024L) + val workerBaseDirs = conf.workerBaseDirs + assert(workerBaseDirs.size == 2) + assert(workerBaseDirs.head._1 == "/mnt/disk1") + assert(workerBaseDirs.head._3 == 3) + assert(workerBaseDirs.head._2 == 10 * 1024 * 1024 * 1024L) - assert(parsedDirs(1)._1 == "/mnt/disk2") - assert(parsedDirs(1)._3 == 7) - assert(parsedDirs(1)._2 == 15 * 1024 * 1024 * 1024L) + assert(workerBaseDirs(1)._1 == "/mnt/disk2") + assert(workerBaseDirs(1)._3 == 7) + assert(workerBaseDirs(1)._2 == 15 * 1024 * 1024 * 1024L) } test("zstd level") { @@ -94,10 +93,10 @@ class RssConfSuite extends RssFunSuite { test("replace placeholder") { val conf = new RssConf() - val replacedHost = RssConf.masterHost(conf) + val replacedHost = conf.masterHost assert(!replacedHost.contains("")) assert(replacedHost === Utils.localHostName) - val replacedHosts = RssConf.masterEndpoints(conf) + val replacedHosts = conf.masterEndpoints replacedHosts.foreach { replacedHost => assert(!replacedHost.contains("")) assert(replacedHost contains Utils.localHostName) @@ -109,6 +108,6 @@ class RssConfSuite extends RssFunSuite { .set("celeborn.ha.master.node.1.host", "clb-1") .set("celeborn.ha.master.node.2.host", "clb-1") .set("celeborn.ha.master.node.3.host", "clb-1") - assert(RssConf.haMasterNodeIds(conf).sorted === Array("1", "2", "3")) + assert(conf.haMasterNodeIds.sorted === Array("1", "2", "3")) } } diff --git a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java index 618d2d61f..c1cb10b95 100644 --- a/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java +++ b/master/src/main/java/org/apache/celeborn/service/deploy/master/clustermeta/ha/HARaftServer.java @@ -288,7 +288,7 @@ public class HARaftServer { private RaftProperties newRaftProperties(RssConf conf) { final RaftProperties properties = new RaftProperties(); // Set RPC type - final String rpcType = RssConf.haMasterRatisRpcType(conf); + final String rpcType = conf.haMasterRatisRpcType(); final RpcType rpc = SupportedRpcType.valueOfIgnoreCase(rpcType); RaftConfigKeys.Rpc.setType(properties, rpc); @@ -300,18 +300,18 @@ public class HARaftServer { } // Set Ratis storage directory - String storageDir = RssConf.haMasterRatisStorageDir(conf); + String storageDir = conf.haMasterRatisStorageDir(); RaftServerConfigKeys.setStorageDir(properties, Collections.singletonList(new File(storageDir))); // Set RAFT segment size - long raftSegmentSize = RssConf.haMasterRatisLogSegmentSizeMax(conf); + long raftSegmentSize = conf.haMasterRatisLogSegmentSizeMax(); RaftServerConfigKeys.Log.setSegmentSizeMax(properties, SizeInBytes.valueOf(raftSegmentSize)); RaftServerConfigKeys.Log.setPurgeUptoSnapshotIndex(properties, true); // Set RAFT segment pre-allocated size - long raftSegmentPreallocatedSize = RssConf.haMasterRatisLogPreallocatedSize(conf); - int logAppenderQueueNumElements = RssConf.haMasterRatisLogAppenderQueueNumElements(conf); - long logAppenderQueueByteLimit = RssConf.haMasterRatisLogAppenderQueueBytesLimit(conf); + long raftSegmentPreallocatedSize = conf.haMasterRatisLogPreallocatedSize(); + int logAppenderQueueNumElements = conf.haMasterRatisLogAppenderQueueNumElements(); + long logAppenderQueueByteLimit = conf.haMasterRatisLogAppenderQueueBytesLimit(); RaftServerConfigKeys.Log.Appender.setBufferElementLimit( properties, logAppenderQueueNumElements); RaftServerConfigKeys.Log.Appender.setBufferByteLimit( @@ -319,7 +319,7 @@ public class HARaftServer { RaftServerConfigKeys.Log.setPreallocatedSize( properties, SizeInBytes.valueOf(raftSegmentPreallocatedSize)); RaftServerConfigKeys.Log.Appender.setInstallSnapshotEnabled(properties, false); - int logPurgeGap = RssConf.haMasterRatisLogPurgeGap(conf); + int logPurgeGap = conf.haMasterRatisLogPurgeGap(); RaftServerConfigKeys.Log.setPurgeGap(properties, logPurgeGap); // For grpc set the maximum message size @@ -327,19 +327,19 @@ public class HARaftServer { // Set the server request timeout TimeDuration serverRequestTimeout = - TimeDuration.valueOf(RssConf.haMasterRatisRpcRequestTimeout(conf), TimeUnit.SECONDS); + TimeDuration.valueOf(conf.haMasterRatisRpcRequestTimeout(), TimeUnit.SECONDS); RaftServerConfigKeys.Rpc.setRequestTimeout(properties, serverRequestTimeout); // Set timeout for server retry cache entry TimeDuration retryCacheExpiryTime = - TimeDuration.valueOf(RssConf.haMasterRatisRetryCacheExpiryTime(conf), TimeUnit.SECONDS); + TimeDuration.valueOf(conf.haMasterRatisRetryCacheExpiryTime(), TimeUnit.SECONDS); RaftServerConfigKeys.RetryCache.setExpiryTime(properties, retryCacheExpiryTime); // Set the server min and max timeout TimeDuration rpcTimeoutMin = - TimeDuration.valueOf(RssConf.haMasterRatisRpcTimeoutMin(conf), TimeUnit.SECONDS); + TimeDuration.valueOf(conf.haMasterRatisRpcTimeoutMin(), TimeUnit.SECONDS); TimeDuration rpcTimeoutMax = - TimeDuration.valueOf(RssConf.haMasterRatisRpcTimeoutMax(conf), TimeUnit.SECONDS); + TimeDuration.valueOf(conf.haMasterRatisRpcTimeoutMax(), TimeUnit.SECONDS); RaftServerConfigKeys.Rpc.setTimeoutMin(properties, rpcTimeoutMin); RaftServerConfigKeys.Rpc.setTimeoutMax(properties, rpcTimeoutMax); @@ -347,25 +347,24 @@ public class HARaftServer { RaftServerConfigKeys.Log.setSegmentCacheNumMax(properties, 2); TimeDuration noLeaderTimeout = - TimeDuration.valueOf( - RssConf.haMasterRatisNotificationNoLeaderTimeout(conf), TimeUnit.SECONDS); + TimeDuration.valueOf(conf.haMasterRatisNotificationNoLeaderTimeout(), TimeUnit.SECONDS); RaftServerConfigKeys.Notification.setNoLeaderTimeout(properties, noLeaderTimeout); TimeDuration slownessTimeout = - TimeDuration.valueOf(RssConf.haMasterRatisRpcSlownessTimeout(conf), TimeUnit.SECONDS); + TimeDuration.valueOf(conf.haMasterRatisRpcSlownessTimeout(), TimeUnit.SECONDS); RaftServerConfigKeys.Rpc.setSlownessTimeout(properties, slownessTimeout); // Set role checker time - this.roleCheckIntervalMs = RssConf.haMasterRatisRoleCheckInterval(conf); + this.roleCheckIntervalMs = conf.haMasterRatisRoleCheckInterval(); // snapshot retention - int numSnapshotRetentionFileNum = RssConf.haMasterRatisSnapshotRetentionFileNum(conf); + int numSnapshotRetentionFileNum = conf.haMasterRatisSnapshotRetentionFileNum(); RaftServerConfigKeys.Snapshot.setRetentionFileNum(properties, numSnapshotRetentionFileNum); // snapshot interval RaftServerConfigKeys.Snapshot.setAutoTriggerEnabled( - properties, RssConf.haMasterRatisSnapshotAutoTriggerEnabled(conf)); + properties, conf.haMasterRatisSnapshotAutoTriggerEnabled()); - long snapshotAutoTriggerThreshold = RssConf.haMasterRatisSnapshotAutoTriggerThreshold(conf); + long snapshotAutoTriggerThreshold = conf.haMasterRatisSnapshotAutoTriggerThreshold(); RaftServerConfigKeys.Snapshot.setAutoTriggerThreshold(properties, snapshotAutoTriggerThreshold); return properties; diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala index 4bf6cb65d..9b1f6b7d8 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala @@ -26,7 +26,6 @@ import scala.collection.JavaConverters._ import scala.util.Random import org.apache.celeborn.common.RssConf -import org.apache.celeborn.common.RssConf.haEnabled import org.apache.celeborn.common.haclient.RssHARetryClient import org.apache.celeborn.common.identity.UserIdentifier import org.apache.celeborn.common.internal.Logging @@ -64,7 +63,7 @@ private[celeborn] class Master( Math.max(64, Runtime.getRuntime.availableProcessors())) private val statusSystem = - if (haEnabled(conf)) { + if (conf.haEnabled) { val sys = new HAMasterMetaManager(rpcEnv, conf) val handler = new MetaHandler(sys) try { @@ -94,8 +93,8 @@ private[celeborn] class Master( private val nonEagerHandler = ThreadUtils.newDaemonCachedThreadPool("master-noneager-handler", 64) // Config constants - private val WorkerTimeoutMs = RssConf.workerHeartbeatTimeoutMs(conf) - private val ApplicationTimeoutMs = RssConf.appHeartbeatTimeoutMs(conf) + private val workerHeartbeatTimeoutMs = conf.workerHeartbeatTimeoutMs + private val appHeartbeatTimeoutMs = conf.appHeartbeatTimeoutMs private val quotaManager = QuotaManager.instantiate(conf) @@ -155,7 +154,7 @@ private[celeborn] class Master( } }, 0, - WorkerTimeoutMs, + workerHeartbeatTimeoutMs, TimeUnit.MILLISECONDS) checkForApplicationTimeOutTask = forwardMessageThread.scheduleAtFixedRate( @@ -165,7 +164,7 @@ private[celeborn] class Master( } }, 0, - ApplicationTimeoutMs / 2, + appHeartbeatTimeoutMs / 2, TimeUnit.MILLISECONDS) } @@ -309,7 +308,7 @@ private[celeborn] class Master( val currentTime = System.currentTimeMillis() var ind = 0 workersSnapShot.asScala.foreach { worker => - if (worker.lastHeartbeat < currentTime - WorkerTimeoutMs + if (worker.lastHeartbeat < currentTime - workerHeartbeatTimeoutMs && !statusSystem.workerLostEvents.contains(worker)) { logWarning(s"Worker ${worker.readableAddress()} timeout! Trigger WorkerLost event.") // trigger WorkerLost event @@ -328,7 +327,7 @@ private[celeborn] class Master( private def timeoutDeadApplications(): Unit = { val currentTime = System.currentTimeMillis() statusSystem.appHeartbeatTime.keySet().asScala.foreach { key => - if (statusSystem.appHeartbeatTime.get(key) < currentTime - ApplicationTimeoutMs) { + if (statusSystem.appHeartbeatTime.get(key) < currentTime - appHeartbeatTimeoutMs) { logWarning(s"Application $key timeout, trigger applicationLost event.") val requestId = RssHARetryClient.genRequestId() var res = self.askSync[ApplicationLostResponse](ApplicationLost(key, requestId)) @@ -723,7 +722,7 @@ private[celeborn] class Master( private def isMasterActive: Int = { // use int rather than bool for better monitoring on dashboard val isActive = - if (haEnabled(conf)) { + if (conf.haEnabled) { if (statusSystem.asInstanceOf[HAMasterMetaManager].getRatisServer.isLeader) { 1 } else { diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/MasterArguments.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/MasterArguments.scala index c8b4d6ee4..91c762561 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/MasterArguments.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/MasterArguments.scala @@ -40,15 +40,15 @@ class MasterArguments(args: Array[String], conf: RssConf) { // 3rd read from configuration file _propertiesFile = Some(Utils.loadDefaultRssProperties(conf, _propertiesFile.orNull)) - if (haEnabled(conf)) { + if (conf.haEnabled) { val clusterInfo = MasterClusterInfo.loadHAConfig(conf) val localNode = clusterInfo.localNode - _host = _host.orElse(Some(haMasterNodeHost(conf, localNode.nodeId))) - _port = _port.orElse(Some(haMasterNodePort(conf, localNode.nodeId))) + _host = _host.orElse(Some(conf.haMasterNodeHost(localNode.nodeId))) + _port = _port.orElse(Some(conf.haMasterNodePort(localNode.nodeId))) _masterClusterInfo = Some(clusterInfo) } else { - _host = _host.orElse(Some(masterHost(conf))) - _port = _port.orElse(Some(masterPort(conf))) + _host = _host.orElse(Some(conf.masterHost)) + _port = _port.orElse(Some(conf.masterPort)) } if (_host.isEmpty || _port.isEmpty) { diff --git a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala index c3c400253..2f71718c7 100644 --- a/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala +++ b/master/src/main/scala/org/apache/celeborn/service/deploy/master/clustermeta/ha/MasterClusterInfo.scala @@ -38,15 +38,15 @@ object MasterClusterInfo extends Logging { @throws[IllegalArgumentException] def loadHAConfig(conf: RssConf): MasterClusterInfo = { - val localNodeIdOpt = haMasterNodeId(conf) - val clusterNodeIds = haMasterNodeIds(conf) + val localNodeIdOpt = conf.haMasterNodeId + val clusterNodeIds = conf.haMasterNodeIds val masterNodes = clusterNodeIds.map { nodeId => - val ratisHost = RssConf.haMasterRatisHost(conf, nodeId) - val ratisPort = RssConf.haMasterRatisPort(conf, nodeId) + val ratisHost = conf.haMasterRatisHost(nodeId) + val ratisPort = conf.haMasterRatisPort(nodeId) val ratisAddr = createSocketAddr(ratisHost, ratisPort) - val rpcHost = RssConf.haMasterNodeHost(conf, nodeId) - val rpcPort = RssConf.haMasterNodePort(conf, nodeId) + val rpcHost = conf.haMasterNodeHost(nodeId) + val rpcPort = conf.haMasterNodePort(nodeId) val rpcAddr = createSocketAddr(rpcHost, rpcPort) MasterNode(nodeId, ratisAddr, rpcAddr) } diff --git a/service/src/main/scala/org/apache/celeborn/server/common/HttpService.scala b/service/src/main/scala/org/apache/celeborn/server/common/HttpService.scala index 1bd2949d6..382f251bf 100644 --- a/service/src/main/scala/org/apache/celeborn/server/common/HttpService.scala +++ b/service/src/main/scala/org/apache/celeborn/server/common/HttpService.scala @@ -53,18 +53,18 @@ abstract class HttpService extends Service with Logging { private def prometheusHost(): String = { serviceName match { case Service.MASTER => - RssConf.masterPrometheusMetricHost(conf) + conf.masterPrometheusMetricHost case Service.WORKER => - RssConf.workerPrometheusMetricHost(conf) + conf.workerPrometheusMetricHost } } private def prometheusPort(): Int = { serviceName match { case Service.MASTER => - RssConf.masterPrometheusMetricPort(conf) + conf.masterPrometheusMetricPort case Service.WORKER => - RssConf.workerPrometheusMetricPort(conf) + conf.workerPrometheusMetricPort } } diff --git a/service/src/main/scala/org/apache/celeborn/server/common/Service.scala b/service/src/main/scala/org/apache/celeborn/server/common/Service.scala index 28f7af54c..35b2ac928 100644 --- a/service/src/main/scala/org/apache/celeborn/server/common/Service.scala +++ b/service/src/main/scala/org/apache/celeborn/server/common/Service.scala @@ -29,7 +29,7 @@ abstract class Service extends Logging { def metricsSystem: MetricsSystem def initialize(): Unit = { - if (RssConf.metricsSystemEnable(conf)) { + if (conf.metricsSystemEnable) { logInfo(s"Metrics system enabled.") metricsSystem.start() } diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java index 0b7c6694e..9477197dc 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/FileWriter.java @@ -65,7 +65,7 @@ public final class FileWriter implements DeviceObserver { private final int flushWorkerIndex; private CompositeByteBuf flushBuffer; - private final long chunkSize; + private final long shuffleChunkSize; private final long writerCloseTimeoutMs; private final long flusherBufferSize; @@ -107,11 +107,11 @@ public final class FileWriter implements DeviceObserver { this.fileInfo = fileInfo; this.flusher = flusher; this.flushWorkerIndex = flusher.getWorkerIndex(); - this.chunkSize = RssConf.shuffleChunkSize(rssConf); - this.nextBoundary = this.chunkSize; + this.shuffleChunkSize = rssConf.shuffleChunkSize(); + this.nextBoundary = this.shuffleChunkSize; this.writerCloseTimeoutMs = rssConf.writerCloseTimeoutMs(); this.splitThreshold = splitThreshold; - this.flusherBufferSize = RssConf.workerFlusherBufferSize(rssConf); + this.flusherBufferSize = rssConf.workerFlusherBufferSize(); this.deviceMonitor = deviceMonitor; this.splitMode = splitMode; this.partitionType = partitionType; @@ -164,7 +164,7 @@ public final class FileWriter implements DeviceObserver { private void maybeSetChunkOffsets(boolean forceSet) { if (bytesFlushed >= nextBoundary || forceSet) { fileInfo.addChunkOffset(bytesFlushed); - nextBoundary = bytesFlushed + chunkSize; + nextBoundary = bytesFlushed + shuffleChunkSize; } } diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java index 3ff74e4f1..7118ac6aa 100644 --- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java +++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/PartitionFilesSorter.java @@ -83,7 +83,7 @@ public class PartitionFilesSorter extends ShuffleRecoverHelper { private final AtomicInteger sortedFileCount = new AtomicInteger(); private final AtomicLong sortedFilesSize = new AtomicLong(); protected final long sortTimeout; - protected final long fetchChunkSize; + protected final long shuffleChunkSize; protected final long initialReserveSingleSortMemory; private boolean gracefulShutdown; private long partitionSorterShutdownAwaitTime; @@ -101,7 +101,7 @@ public class PartitionFilesSorter extends ShuffleRecoverHelper { public PartitionFilesSorter(MemoryTracker memoryTracker, RssConf conf, AbstractSource source) { this.sortTimeout = RssConf.partitionSortTimeout(conf); - this.fetchChunkSize = RssConf.shuffleChunkSize(conf); + this.shuffleChunkSize = conf.shuffleChunkSize(); this.initialReserveSingleSortMemory = RssConf.initialReserveSingleSortMemory(conf); this.partitionSorterShutdownAwaitTime = conf.partitionSorterCloseAwaitTimeMs(); this.source = source; @@ -497,7 +497,7 @@ public class PartitionFilesSorter extends ShuffleRecoverHelper { return new FileInfo( sortedFilePath, ShuffleBlockInfoUtils.getChunkOffsetsFromShuffleBlockInfos( - startMapIndex, endMapIndex, fetchChunkSize, indexMap), + startMapIndex, endMapIndex, shuffleChunkSize, indexMap), userIdentifier); } diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala index 12d0a7c15..35827dd07 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/Worker.scala @@ -183,15 +183,15 @@ private[celeborn] class Worker( private var checkFastfailTask: ScheduledFuture[_] = _ val replicateThreadPool = ThreadUtils.newDaemonCachedThreadPool( "worker-replicate-data", - RssConf.workerReplicateThreads(conf)) + conf.workerReplicateThreads) val commitThreadPool = ThreadUtils.newDaemonCachedThreadPool( "Worker-CommitFiles", - RssConf.workerCommitThreads(conf)) + conf.workerCommitThreads) val asyncReplyPool = ThreadUtils.newDaemonSingleThreadScheduledExecutor("async-reply") val timer = new HashedWheelTimer() // Configs - private val HEARTBEAT_MILLIS = RssConf.workerHeartbeatTimeoutMs(conf) / 4 + private val HEARTBEAT_MILLIS = conf.workerHeartbeatTimeoutMs / 4 private val REPLICATE_FAST_FAIL_DURATION = RssConf.replicateFastFailDurationMs(conf) private val cleanTaskQueue = new LinkedBlockingQueue[JHashSet[String]]