From ff0cf15770227f090c5a674dc13ebc9ae54d78ee Mon Sep 17 00:00:00 2001 From: SteNicholas Date: Fri, 23 Feb 2024 15:30:24 +0800 Subject: [PATCH] [CELEBORN-1283] TransportClientFactory avoid contention and get or create clientPools quickly ### What changes were proposed in this pull request? `TransportClientFactory` avoid contention and get or create clientPools quickly. ### Why are the changes needed? Avoid contention for getting or creating clientPools, and clean up the code. Backport: [[SPARK-38555][NETWORK][SHUFFLE] Avoid contention and get or create clientPools quickly in the TransportClientFactory](https://github.com/apache/spark/pull/35860) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? No. Closes #2322 from SteNicholas/CELEBORN-1283. Authored-by: SteNicholas Signed-off-by: mingji --- .../common/network/client/TransportClientFactory.java | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java index 68bad239b..f51d04a57 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java +++ b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java @@ -136,13 +136,9 @@ public class TransportClientFactory implements Closeable { InetSocketAddress.createUnresolved(remoteHost, remotePort); // Create the ClientPool if we don't have it yet. - ClientPool clientPool = connectionPool.get(unresolvedAddress); - if (clientPool == null) { - connectionPool.computeIfAbsent( - unresolvedAddress, key -> new ClientPool(numConnectionsPerPeer)); - clientPool = connectionPool.get(unresolvedAddress); - } - + ClientPool clientPool = + connectionPool.computeIfAbsent( + unresolvedAddress, key -> new ClientPool(numConnectionsPerPeer)); int clientIndex = partitionId < 0 ? rand.nextInt(numConnectionsPerPeer) : partitionId % numConnectionsPerPeer; TransportClient cachedClient = clientPool.clients[clientIndex];