diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SQLOperationListener.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SQLOperationListener.scala
index 4e4a940d2..686cb1f35 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SQLOperationListener.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SQLOperationListener.scala
@@ -20,6 +20,8 @@ package org.apache.spark.kyuubi
 import java.util.Properties
 import java.util.concurrent.ConcurrentHashMap
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
@@ -44,7 +46,7 @@ class SQLOperationListener(
     spark: SparkSession) extends StatsReportListener with Logging {
 
   private val operationId: String = operation.getHandle.identifier.toString
-  private lazy val activeJobs = new java.util.HashSet[Int]()
+  private lazy val activeJobs = new ConcurrentHashMap[Int, SparkJobInfo]()
   private lazy val activeStages = new ConcurrentHashMap[SparkStageAttempt, SparkStageInfo]()
   private var executionId: Option[Long] = None
 
@@ -53,6 +55,7 @@ class SQLOperationListener(
     if (conf.get(ENGINE_SPARK_SHOW_PROGRESS)) {
       Some(new SparkConsoleProgressBar(
         operation,
+        activeJobs,
         activeStages,
         conf.get(ENGINE_SPARK_SHOW_PROGRESS_UPDATE_INTERVAL),
         conf.get(ENGINE_SPARK_SHOW_PROGRESS_TIME_FORMAT)))
@@ -79,9 +82,10 @@ class SQLOperationListener(
     }
   }
 
-  override def onJobStart(jobStart: SparkListenerJobStart): Unit = activeJobs.synchronized {
+  override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
     if (sameGroupId(jobStart.properties)) {
       val jobId = jobStart.jobId
+      val stageIds = jobStart.stageInfos.map(_.stageId).toSet
       val stageSize = jobStart.stageInfos.size
       if (executionId.isEmpty) {
         executionId = Option(jobStart.properties.getProperty(SPARK_SQL_EXECUTION_ID_KEY))
@@ -93,17 +97,19 @@ class SQLOperationListener(
           case _ =>
         }
       }
+      activeJobs.put(
+        jobId,
+        new SparkJobInfo(stageSize, stageIds))
       withOperationLog {
-        activeJobs.add(jobId)
         info(s"Query [$operationId]: Job $jobId started with $stageSize stages," +
           s" ${activeJobs.size()} active jobs running")
       }
     }
   }
 
-  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = activeJobs.synchronized {
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
     val jobId = jobEnd.jobId
-    if (activeJobs.remove(jobId)) {
+    if (activeJobs.remove(jobId) != null) {
       val hint = jobEnd.jobResult match {
         case JobSucceeded => "succeeded"
         case _ => "failed" // TODO: Handle JobFailed(exception: Exception)
@@ -134,9 +140,19 @@
   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
     val stageInfo = stageCompleted.stageInfo
+    val stageId = stageInfo.stageId
     val stageAttempt = SparkStageAttempt(stageInfo.stageId, stageInfo.attemptNumber())
     activeStages.synchronized {
       if (activeStages.remove(stageAttempt) != null) {
+        stageInfo.getStatusString match {
+          case "succeeded" =>
+            activeJobs.asScala.foreach { case (_, jobInfo) =>
+              if (jobInfo.stageIds.contains(stageId)) {
+                jobInfo.numCompleteStages.getAndIncrement()
+              }
+            }
+          case _ => // ignore non-successful statuses so only completed stages are counted
+        }
         withOperationLog(super.onStageCompleted(stageCompleted))
       }
     }
   }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SparkConsoleProgressBar.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SparkConsoleProgressBar.scala
index dc8b493cc..feb0d16a1 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SparkConsoleProgressBar.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SparkConsoleProgressBar.scala
@@ -29,6 +29,7 @@ import org.apache.kyuubi.operation.Operation
 
 class SparkConsoleProgressBar(
     operation: Operation,
+    liveJobs: ConcurrentHashMap[Int, SparkJobInfo],
     liveStages: ConcurrentHashMap[SparkStageAttempt, SparkStageInfo],
     updatePeriodMSec: Long,
     timeFormat: String)
@@ -72,6 +73,17 @@ class SparkConsoleProgressBar(
     }
   }
 
+  /**
+   * Find the jobId of the job that contains the given stage.
+   * @param stageId the id of the stage to look up
+   * @return the jobId, if the stage belongs to a tracked job
+   */
+  private def findJobId(stageId: Int): Option[Int] = {
+    liveJobs.asScala.collectFirst {
+      case (jobId, jobInfo) if jobInfo.stageIds.contains(stageId) => jobId
+    }
+  }
+
   /**
    * Show progress bar in console. The progress bar is displayed in the next line
    * after your last output, keeps overwriting itself to hold in one line. The logging will follow
@@ -81,9 +93,13 @@ class SparkConsoleProgressBar(
     val width = TerminalWidth / stages.size
     val bar = stages.map { s =>
       val total = s.numTasks
-      val header = s"[Stage ${s.stageId}:"
+      val jobHeader = findJobId(s.stageId).map(jobId =>
+        s"[Job $jobId (${liveJobs.get(jobId).numCompleteStages} " +
+          s"/ ${liveJobs.get(jobId).numStages}) Stages] ").getOrElse(
+        "[No job found for this stage] ")
+      val header = jobHeader + s"[Stage ${s.stageId}:"
       val tailer = s"(${s.numCompleteTasks} + ${s.numActiveTasks}) / $total]"
-      val w = width - header.length - tailer.length
+      val w = width + jobHeader.length - header.length - tailer.length
       val bar = if (w > 0) {
         val percent = w * s.numCompleteTasks.get / total
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/StageStatus.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/StageStatus.scala
index 2ea9c3fda..29644f9f4 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/StageStatus.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/StageStatus.scala
@@ -24,6 +24,10 @@ case class SparkStageAttempt(stageId: Int, stageAttemptId: Int) {
 }
 
 class SparkStageInfo(val stageId: Int, val numTasks: Int) {
-  var numActiveTasks = new AtomicInteger(0)
-  var numCompleteTasks = new AtomicInteger(0)
+  val numActiveTasks = new AtomicInteger(0)
+  val numCompleteTasks = new AtomicInteger(0)
+}
+
+class SparkJobInfo(val numStages: Int, val stageIds: Set[Int]) {
+  val numCompleteStages = new AtomicInteger(0)
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/kyuubi/SQLOperationListenerSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/kyuubi/SQLOperationListenerSuite.scala
index 04277fca4..f732f7c38 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/kyuubi/SQLOperationListenerSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/kyuubi/SQLOperationListenerSuite.scala
@@ -22,13 +22,16 @@ import scala.collection.JavaConverters.asScalaBufferConverter
 import org.apache.hive.service.rpc.thrift.{TExecuteStatementReq, TFetchOrientation, TFetchResultsReq, TOperationHandle}
 import org.scalatest.time.SpanSugar._
 
+import org.apache.kyuubi.config.KyuubiConf
 import org.apache.kyuubi.config.KyuubiConf.OPERATION_SPARK_LISTENER_ENABLED
 import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
 import org.apache.kyuubi.operation.HiveJDBCTestHelper
 
 class SQLOperationListenerSuite extends WithSparkSQLEngine with HiveJDBCTestHelper {
 
-  override def withKyuubiConf: Map[String, String] = Map.empty
+  override def withKyuubiConf: Map[String, String] = Map(
+    KyuubiConf.ENGINE_SPARK_SHOW_PROGRESS.key -> "true",
+    KyuubiConf.ENGINE_SPARK_SHOW_PROGRESS_UPDATE_INTERVAL.key -> "200")
 
   override protected def jdbcUrl: String = getJdbcUrl
 
@@ -54,6 +57,24 @@ class SQLOperationListenerSuite extends WithSparkSQLEngine with HiveJDBCTestHelper {
     }
   }
 
+  test("operation listener with progress job info") {
+    val sql = "SELECT java_method('java.lang.Thread', 'sleep', 10000l) FROM range(1, 3, 1, 2);"
+    withSessionHandle { (client, handle) =>
+      val req = new TExecuteStatementReq()
+      req.setSessionHandle(handle)
+      req.setStatement(sql)
+      val tExecuteStatementResp = client.ExecuteStatement(req)
+      val opHandle = tExecuteStatementResp.getOperationHandle
+      val fetchResultsReq = new TFetchResultsReq(opHandle, TFetchOrientation.FETCH_NEXT, 1000)
+      fetchResultsReq.setFetchType(1.toShort)
+      eventually(timeout(90.seconds), interval(500.milliseconds)) {
+        val resultsResp = client.FetchResults(fetchResultsReq)
+        val logs = resultsResp.getResults.getColumns.get(0).getStringVal.getValues.asScala
+        assert(logs.exists(_.matches(".*\\[Job .* Stages\\] \\[Stage .*\\]")))
+      }
+    }
+  }
+
   test("SQLOperationListener configurable") {
     val sql = "select /*+ REPARTITION(3, a) */ a from values(1) t(a);"
     withSessionHandle { (client, handle) =>