diff --git a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/ExecuteStatement.scala b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/ExecuteStatement.scala index 96aff0ec1..37dc6f40b 100644 --- a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/ExecuteStatement.scala +++ b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/operation/ExecuteStatement.scala @@ -179,7 +179,7 @@ class ExecuteStatement( private def addTimeoutMonitor(): Unit = { if (queryTimeout > 0) { val timeoutExecutor = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("query-timeout-thread") + ThreadUtils.newDaemonSingleThreadScheduledExecutor("query-timeout-thread", false) val action: Runnable = () => cleanup(OperationState.TIMEOUT) timeoutExecutor.schedule(action, queryTimeout, TimeUnit.SECONDS) statementTimeoutCleaner = Some(timeoutExecutor) diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala index 7db23874f..c531e1cc9 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteStatement.scala @@ -127,7 +127,7 @@ class ExecuteStatement( private def addTimeoutMonitor(): Unit = { if (queryTimeout > 0) { val timeoutExecutor = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("query-timeout-thread") + ThreadUtils.newDaemonSingleThreadScheduledExecutor("query-timeout-thread", false) timeoutExecutor.schedule( new Runnable { override def run(): Unit = { diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/util/ThreadUtils.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/util/ThreadUtils.scala index a540b954d..48e8d0f59 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/util/ThreadUtils.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/util/ThreadUtils.scala @@ -26,11 +26,15 @@ import org.apache.kyuubi.{KyuubiException, Logging} object ThreadUtils extends Logging { - def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { + def newDaemonSingleThreadScheduledExecutor( + threadName: String, + executeExistingDelayedTasksAfterShutdown: Boolean = true): ScheduledExecutorService = { val threadFactory = new NamedThreadFactory(threadName, daemon = true) val executor = new ScheduledThreadPoolExecutor(1, threadFactory) executor.setRemoveOnCancelPolicy(true) executor + .setExecuteExistingDelayedTasksAfterShutdownPolicy(executeExistingDelayedTasksAfterShutdown) + executor } def newDaemonQueuedThreadPool( diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/util/ThreadUtilsSuite.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/util/ThreadUtilsSuite.scala index 858a71570..6bf0247b6 100644 --- a/kyuubi-common/src/test/scala/org/apache/kyuubi/util/ThreadUtilsSuite.scala +++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/util/ThreadUtilsSuite.scala @@ -35,4 +35,30 @@ class ThreadUtilsSuite extends KyuubiFunSuite { service.awaitTermination(10, TimeUnit.SECONDS) assert(threadName startsWith "ThreadUtilsTest") } + + test("New daemon single thread scheduled executor for shutdownNow") { + val service = ThreadUtils.newDaemonSingleThreadScheduledExecutor("ThreadUtilsTest") + @volatile var threadName = "" + service.submit(new Runnable { + override def run(): Unit = { + threadName = Thread.currentThread().getName + } + }) + service.shutdownNow() + service.awaitTermination(10, TimeUnit.SECONDS) + assert(threadName startsWith "") + } + + test("New daemon single thread scheduled executor for cancel delayed tasks") { + val service = ThreadUtils.newDaemonSingleThreadScheduledExecutor("ThreadUtilsTest", false) + @volatile var threadName = "" + service.submit(new Runnable { + override def run(): Unit = { + threadName = Thread.currentThread().getName + } + }) + service.shutdown() + service.awaitTermination(10, TimeUnit.SECONDS) + assert(threadName startsWith "") + } }