diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/client/KyuubiSyncThriftClient.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/client/KyuubiSyncThriftClient.scala index 65a034037..6bdc3e5fc 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/client/KyuubiSyncThriftClient.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/client/KyuubiSyncThriftClient.scala @@ -17,12 +17,14 @@ package org.apache.kyuubi.client -import java.util.concurrent.{ScheduledExecutorService, TimeUnit} +import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ +import scala.concurrent.ExecutionException import scala.concurrent.duration.Duration +import com.google.common.annotations.VisibleForTesting import org.apache.hive.service.rpc.thrift._ import org.apache.thrift.TException import org.apache.thrift.protocol.{TBinaryProtocol, TProtocol} @@ -56,6 +58,24 @@ class KyuubiSyncThriftClient private ( private var engineAliveThreadPool: ScheduledExecutorService = _ @volatile private var engineLastAlive: Long = _ + private var asyncRequestExecutor: ExecutorService = _ + + @VisibleForTesting + @volatile private[kyuubi] var asyncRequestInterrupted: Boolean = false + + @VisibleForTesting + private[kyuubi] def getEngineAliveProbeProtocol: Option[TProtocol] = engineAliveProbeProtocol + + private def newAsyncRequestExecutor(): ExecutorService = { + ThreadUtils.newDaemonSingleThreadScheduledExecutor( + "async-request-executor-" + _remoteSessionHandle) + } + + private def shutdownAsyncRequestExecutor(): Unit = { + Option(asyncRequestExecutor).filterNot(_.isShutdown).foreach(ThreadUtils.shutdown(_)) + asyncRequestInterrupted = true + } + private def startEngineAliveProbe(): Unit = { engineAliveThreadPool = ThreadUtils.newDaemonSingleThreadScheduledExecutor( "engine-alive-probe-" + _aliveProbeSessionHandle) @@ -82,6 +102,8 @@ class KyuubiSyncThriftClient private ( } } } + } else { + shutdownAsyncRequestExecutor() } } } @@ -106,16 +128,23 @@ class KyuubiSyncThriftClient private ( } finally lock.unlock() } - private def withRetryingRequest[T](block: => T, request: String): T = withLockAcquired { - val (resp, shouldResetEngineBroken) = KyuubiSyncThriftClient.withRetryingRequestNoLock( - block, - request, - maxAttempts, - remoteEngineBroken, - isConnectionValid) + private def withLockAcquiredAsyncRequest[T](block: => T): T = withLockAcquired { + if (asyncRequestExecutor == null || asyncRequestExecutor.isShutdown) { + asyncRequestExecutor = newAsyncRequestExecutor() + } - if (shouldResetEngineBroken) remoteEngineBroken = false - resp + val task = asyncRequestExecutor.submit(() => { + val resp = block + remoteEngineBroken = false + resp + }) + + try { + task.get() + } catch { + case e: ExecutionException => throw e.getCause + case e: Throwable => throw e + } } def engineId: Option[String] = _engineId @@ -132,7 +161,7 @@ class KyuubiSyncThriftClient private ( req.setUsername(user) req.setPassword(password) req.setConfiguration(configs.asJava) - val resp = withRetryingRequest(OpenSession(req), "OpenSession") + val resp = withLockAcquired(OpenSession(req)) ThriftUtils.verifyTStatus(resp.getStatus) _remoteSessionHandle = resp.getSessionHandle _engineId = Option(resp.getConfiguration) @@ -157,7 +186,7 @@ class KyuubiSyncThriftClient private ( try { if (_remoteSessionHandle != null) { val req = new TCloseSessionReq(_remoteSessionHandle) - val resp = withRetryingRequest(CloseSession(req), "CloseSession") + val resp = withLockAcquiredAsyncRequest(CloseSession(req)) ThriftUtils.verifyTStatus(resp.getStatus) } } catch { @@ -179,6 +208,7 @@ class KyuubiSyncThriftClient private ( Seq(protocol).union(engineAliveProbeProtocol.toSeq).foreach { tProtocol => if (tProtocol.getTransport.isOpen) tProtocol.getTransport.close() } + shutdownAsyncRequestExecutor() } } @@ -193,21 +223,21 @@ class KyuubiSyncThriftClient private ( req.setConfOverlay(confOverlay.asJava) req.setRunAsync(shouldRunAsync) req.setQueryTimeout(queryTimeout) - val resp = withRetryingRequest(ExecuteStatement(req), "ExecuteStatement") + val resp = withLockAcquiredAsyncRequest(ExecuteStatement(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getOperationHandle } def getTypeInfo: TOperationHandle = { val req = new TGetTypeInfoReq(_remoteSessionHandle) - val resp = withRetryingRequest(GetTypeInfo(req), "GetTypeInfo") + val resp = withLockAcquiredAsyncRequest(GetTypeInfo(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getOperationHandle } def getCatalogs: TOperationHandle = { val req = new TGetCatalogsReq(_remoteSessionHandle) - val resp = withRetryingRequest(GetCatalogs(req), "GetCatalogs") + val resp = withLockAcquiredAsyncRequest(GetCatalogs(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getOperationHandle } @@ -217,7 +247,7 @@ class KyuubiSyncThriftClient private ( req.setSessionHandle(_remoteSessionHandle) req.setCatalogName(catalogName) req.setSchemaName(schemaName) - val resp = withRetryingRequest(GetSchemas(req), "GetSchemas") + val resp = withLockAcquiredAsyncRequest(GetSchemas(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getOperationHandle } @@ -233,14 +263,14 @@ class KyuubiSyncThriftClient private ( req.setSchemaName(schemaName) req.setTableName(tableName) req.setTableTypes(tableTypes) - val resp = withRetryingRequest(GetTables(req), "GetTables") + val resp = withLockAcquiredAsyncRequest(GetTables(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getOperationHandle } def getTableTypes: TOperationHandle = { val req = new TGetTableTypesReq(_remoteSessionHandle) - val resp = withRetryingRequest(GetTableTypes(req), "GetTableTypes") + val resp = withLockAcquiredAsyncRequest(GetTableTypes(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getOperationHandle } @@ -255,7 +285,7 @@ class KyuubiSyncThriftClient private ( req.setSchemaName(schemaName) req.setTableName(tableName) req.setColumnName(columnName) - val resp = withRetryingRequest(GetColumns(req), "GetColumns") + val resp = withLockAcquiredAsyncRequest(GetColumns(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getOperationHandle } @@ -267,7 +297,7 @@ class KyuubiSyncThriftClient private ( val req = new TGetFunctionsReq(_remoteSessionHandle, functionName) req.setCatalogName(catalogName) req.setSchemaName(schemaName) - val resp = withRetryingRequest(GetFunctions(req), "GetFunctions") + val resp = withLockAcquiredAsyncRequest(GetFunctions(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getOperationHandle } @@ -281,7 +311,7 @@ class KyuubiSyncThriftClient private ( req.setCatalogName(catalogName) req.setSchemaName(schemaName) req.setTableName(tableName) - val resp = withRetryingRequest(GetPrimaryKeys(req), "GetPrimaryKeys") + val resp = withLockAcquiredAsyncRequest(GetPrimaryKeys(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getOperationHandle } @@ -301,26 +331,26 @@ class KyuubiSyncThriftClient private ( req.setForeignCatalogName(foreignCatalog) req.setForeignSchemaName(foreignSchema) req.setForeignTableName(foreignTable) - val resp = withRetryingRequest(GetCrossReference(req), "GetCrossReference") + val resp = withLockAcquiredAsyncRequest(GetCrossReference(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getOperationHandle } def getQueryId(operationHandle: TOperationHandle): TGetQueryIdResp = { val req = new TGetQueryIdReq(operationHandle) - val resp = withRetryingRequest(GetQueryId(req), "GetQueryId") + val resp = withLockAcquiredAsyncRequest(GetQueryId(req)) resp } def getOperationStatus(operationHandle: TOperationHandle): TGetOperationStatusResp = { val req = new TGetOperationStatusReq(operationHandle) - val resp = withRetryingRequest(GetOperationStatus(req), "GetOperationStatus") + val resp = withLockAcquiredAsyncRequest(GetOperationStatus(req)) resp } def cancelOperation(operationHandle: TOperationHandle): Unit = { val req = new TCancelOperationReq(operationHandle) - val resp = withRetryingRequest(CancelOperation(req), "CancelOperation") + val resp = withLockAcquiredAsyncRequest(CancelOperation(req)) if (resp.getStatus.getStatusCode == TStatusCode.SUCCESS_STATUS) { info(s"$req succeed on engine side") } else { @@ -330,7 +360,7 @@ class KyuubiSyncThriftClient private ( def closeOperation(operationHandle: TOperationHandle): Unit = { val req = new TCloseOperationReq(operationHandle) - val resp = withRetryingRequest(CloseOperation(req), "CloseOperation") + val resp = withLockAcquiredAsyncRequest(CloseOperation(req)) if (resp.getStatus.getStatusCode == TStatusCode.SUCCESS_STATUS) { info(s"$req succeed on engine side") } else { @@ -340,7 +370,7 @@ class KyuubiSyncThriftClient private ( def getResultSetMetadata(operationHandle: TOperationHandle): TTableSchema = { val req = new TGetResultSetMetadataReq(operationHandle) - val resp = withRetryingRequest(GetResultSetMetadata(req), "GetResultSetMetadata") + val resp = withLockAcquiredAsyncRequest(GetResultSetMetadata(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getSchema } @@ -354,7 +384,7 @@ class KyuubiSyncThriftClient private ( val req = new TFetchResultsReq(operationHandle, or, maxRows) val fetchType = if (fetchLog) 1.toShort else 0.toShort req.setFetchType(fetchType) - val resp = withRetryingRequest(FetchResults(req), "FetchResults") + val resp = withLockAcquiredAsyncRequest(FetchResults(req)) ThriftUtils.verifyTStatus(resp.getStatus) resp.getResults } @@ -366,7 +396,7 @@ class KyuubiSyncThriftClient private ( req.setSessionHandle(_remoteSessionHandle) req.setDelegationToken(encodedCredentials) try { - val resp = withLockAcquired(RenewDelegationToken(req)) + val resp = withLockAcquiredAsyncRequest(RenewDelegationToken(req)) if (resp.getStatus.getStatusCode == TStatusCode.SUCCESS_STATUS) { debug(s"$req succeed on engine side") } else { @@ -376,10 +406,6 @@ class KyuubiSyncThriftClient private ( case e: Exception => warn(s"$req failed on engine side", e) } } - - def isConnectionValid(): Boolean = { - !remoteEngineBroken && protocol.getTransport.isOpen - } } private[kyuubi] object KyuubiSyncThriftClient extends Logging { @@ -437,21 +463,11 @@ private[kyuubi] object KyuubiSyncThriftClient extends Logging { val aliveProbeInterval = conf.get(KyuubiConf.ENGINE_ALIVE_PROBE_INTERVAL).toInt val aliveTimeout = conf.get(KyuubiConf.ENGINE_ALIVE_TIMEOUT) - val (tProtocol, _) = withRetryingRequestNoLock( - createTProtocol(user, passwd, host, port, 0, loginTimeout), - "CreatingTProtocol", - requestMaxAttempts, - false, - () => true) + val tProtocol = createTProtocol(user, passwd, host, port, 0, loginTimeout) val aliveProbeProtocol = if (aliveProbeEnabled) { - Option(withRetryingRequestNoLock( - createTProtocol(user, passwd, host, port, aliveProbeInterval, loginTimeout), - "CreatingTProtocol", - requestMaxAttempts, - false, - () => true)._1) + Option(createTProtocol(user, passwd, host, port, aliveProbeInterval, loginTimeout)) } else { None } diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerConnectionSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerConnectionSuite.scala index 4059eb084..566c24731 100644 --- a/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerConnectionSuite.scala +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerConnectionSuite.scala @@ -75,7 +75,7 @@ class KyuubiOperationPerConnectionSuite extends WithKyuubiServer with HiveJDBCTe assert(executeStmtResp.getStatus.getStatusCode === TStatusCode.ERROR_STATUS) assert(executeStmtResp.getOperationHandle === null) assert(executeStmtResp.getStatus.getErrorMessage contains - "Caused by: java.net.SocketException: Broken pipe (Write failed)") + "Caused by: java.net.SocketException: Connection reset") } } @@ -232,6 +232,8 @@ class KyuubiOperationPerConnectionSuite extends WithKyuubiServer with HiveJDBCTe val startTime = System.currentTimeMillis() val executeStmtResp = client.ExecuteStatement(executeStmtReq) assert(executeStmtResp.getStatus.getStatusCode === TStatusCode.ERROR_STATUS) + assert(executeStmtResp.getStatus.getErrorMessage contains + "Caused by: java.net.SocketException: Connection reset") val elapsedTime = System.currentTimeMillis() - startTime assert(elapsedTime > 3 * 1000 && elapsedTime < 20 * 1000) } diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerUserSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerUserSuite.scala index df4f1ac6c..bb72cd620 100644 --- a/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerUserSuite.scala +++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerUserSuite.scala @@ -17,10 +17,12 @@ package org.apache.kyuubi.operation +import org.apache.hive.service.rpc.thrift.{TExecuteStatementReq, TStatusCode} import org.scalatest.time.SpanSugar._ import org.apache.kyuubi.{Utils, WithKyuubiServer} import org.apache.kyuubi.config.KyuubiConf +import org.apache.kyuubi.session.{KyuubiSessionImpl, KyuubiSessionManager, SessionHandle} class KyuubiOperationPerUserSuite extends WithKyuubiServer with SparkQueryTests { @@ -153,4 +155,50 @@ class KyuubiOperationPerUserSuite extends WithKyuubiServer with SparkQueryTests } } } + + test("support to interrupt the thrift request if remote engine is broken") { + if (!httpMode) { + withSessionConf(Map( + KyuubiConf.ENGINE_ALIVE_PROBE_ENABLED.key -> "true", + KyuubiConf.ENGINE_ALIVE_PROBE_INTERVAL.key -> "1000", + KyuubiConf.ENGINE_ALIVE_TIMEOUT.key -> "3000", + KyuubiConf.OPERATION_THRIFT_CLIENT_REQUEST_MAX_ATTEMPTS.key -> "10000"))(Map.empty)( + Map.empty) { + withSessionHandle { (client, handle) => + val preReq = new TExecuteStatementReq() + preReq.setStatement("select engine_name()") + preReq.setSessionHandle(handle) + preReq.setRunAsync(false) + client.ExecuteStatement(preReq) + + val sessionHandle = SessionHandle(handle) + val session = server.backendService.sessionManager.asInstanceOf[KyuubiSessionManager] + .getSession(sessionHandle).asInstanceOf[KyuubiSessionImpl] + session.client.getEngineAliveProbeProtocol.foreach(_.getTransport.close()) + + val exitReq = new TExecuteStatementReq() + exitReq.setStatement("SELECT java_method('java.lang.Thread', 'sleep', 1000L)," + + "java_method('java.lang.System', 'exit', 1)") + exitReq.setSessionHandle(handle) + exitReq.setRunAsync(true) + client.ExecuteStatement(exitReq) + + val executeStmtReq = new TExecuteStatementReq() + executeStmtReq.setStatement("SELECT java_method('java.lang.Thread', 'sleep', 30000l)") + executeStmtReq.setSessionHandle(handle) + executeStmtReq.setRunAsync(false) + val startTime = System.currentTimeMillis() + val executeStmtResp = client.ExecuteStatement(executeStmtReq) + assert(executeStmtResp.getStatus.getStatusCode === TStatusCode.ERROR_STATUS) + assert(executeStmtResp.getStatus.getErrorMessage.contains( + "java.net.SocketException: Connection reset") || + executeStmtResp.getStatus.getErrorMessage.contains( + "Caused by: java.net.SocketException: Broken pipe (Write failed)")) + val elapsedTime = System.currentTimeMillis() - startTime + assert(elapsedTime < 20 * 1000) + assert(session.client.asyncRequestInterrupted) + } + } + } + } }