diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 819014f1b..af29b57dd 100644
--- a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -75,8 +75,8 @@ class CelebornShuffleReader[K, C](
 
   override def read(): Iterator[Product2[K, C]] = {
+    val startTime = System.currentTimeMillis()
     val serializerInstance = newSerializerInstance(dep)
-
     val shuffleId = SparkUtils.celebornShuffleId(shuffleClient, handle, context, false)
     shuffleIdTracker.track(handle.shuffleId, shuffleId)
     logDebug(
@@ -104,7 +104,6 @@ class CelebornShuffleReader[K, C](
       }
     }
 
-    val startTime = System.currentTimeMillis()
     val fetchTimeoutMs = conf.clientFetchTimeoutMs
     val localFetchEnabled = conf.enableReadLocalShuffleFile
     val localHostAddress = Utils.localHostName(conf)
@@ -121,6 +120,7 @@ class CelebornShuffleReader[K, C](
       case e: Throwable => throw e
     }
 
+    val batchOpenStreamStartTime = System.currentTimeMillis()
     // host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
     val workerRequestMap = new JHashMap[
       String,
@@ -223,7 +223,9 @@ class CelebornShuffleReader[K, C](
     // wait for all futures to complete
     futures.foreach(f => f.get())
     val end = System.currentTimeMillis()
-    logInfo(s"BatchOpenStream for $partCnt cost ${end - startTime}ms")
+    // readTime should include batchOpenStreamTime, getShuffleId Rpc time and updateFileGroup Rpc time
+    metricsCallback.incReadTime(end - startTime)
+    logInfo(s"BatchOpenStream for $partCnt cost ${end - batchOpenStreamStartTime}ms")
 
     val streams = JavaUtils.newConcurrentHashMap[Integer, CelebornInputStream]()
 
@@ -304,14 +306,15 @@
           }
         }
         if (sleepCnt == 0) {
-          logInfo(s"inputStream for partition: $partitionId is null, sleeping...")
+          logInfo(s"inputStream for partition: $partitionId is null, sleeping 5ms")
         }
         sleepCnt += 1
-        Thread.sleep(50)
+        Thread.sleep(5)
         inputStream = streams.get(partitionId)
       }
       if (sleepCnt > 0) {
-        logInfo(s"inputStream for partition: $partitionId is not null, sleep count: $sleepCnt")
+        logInfo(
+          s"inputStream for partition: $partitionId is not null, sleep $sleepCnt times for ${5 * sleepCnt} ms")
       }
       metricsCallback.incReadTime(
         TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))