diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala index eaeb91bdc..ab7c58917 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/SparkSQLEngine.scala @@ -30,7 +30,7 @@ import org.apache.kyuubi.{KyuubiException, Logging} import org.apache.kyuubi.Utils._ import org.apache.kyuubi.config.KyuubiConf import org.apache.kyuubi.config.KyuubiConf._ -import org.apache.kyuubi.engine.spark.SparkSQLEngine.countDownLatch +import org.apache.kyuubi.engine.spark.SparkSQLEngine.{countDownLatch, currentEngine} import org.apache.kyuubi.engine.spark.events.{EngineEvent, EngineEventsStore, EventLoggingService} import org.apache.kyuubi.ha.HighAvailabilityConf._ import org.apache.kyuubi.ha.client.RetryPolicies @@ -54,7 +54,10 @@ case class SparkSQLEngine( super.start() // Start engine self-terminating checker after all services are ready and it can be reached by // all servers in engine spaces. - backendService.sessionManager.startTerminatingChecker() + backendService.sessionManager.startTerminatingChecker(() => { + assert(currentEngine.isDefined) + currentEngine.get.stop() + }) } override protected def stopServer(): Unit = { diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/session/SessionManager.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/session/SessionManager.scala index 9ba51d636..2ea57df09 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/session/SessionManager.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/session/SessionManager.scala @@ -265,7 +265,7 @@ abstract class SessionManager(name: String) extends CompositeService(name) { timeoutChecker.scheduleWithFixedDelay(checkTask, interval, interval, TimeUnit.MILLISECONDS) } - private[kyuubi] def startTerminatingChecker(): Unit = if (!isServer) { + private[kyuubi] def startTerminatingChecker(stop: () => Unit): Unit = if (!isServer) { // initialize `_latestLogoutTime` at start _latestLogoutTime = System.currentTimeMillis() val interval = conf.get(ENGINE_CHECK_INTERVAL) @@ -275,7 +275,7 @@ abstract class SessionManager(name: String) extends CompositeService(name) { if (!shutdown && System.currentTimeMillis() - latestLogoutTime > idleTimeout && getOpenSessionCount <= 0) { info(s"Idled for more than $idleTimeout ms, terminating") - sys.exit(0) + stop() } } }