diff --git a/externals/kyuubi-trino-engine/pom.xml b/externals/kyuubi-trino-engine/pom.xml index d51d2e054..a97fa9a3b 100644 --- a/externals/kyuubi-trino-engine/pom.xml +++ b/externals/kyuubi-trino-engine/pom.xml @@ -82,6 +82,13 @@ test + + org.apache.kyuubi + kyuubi-hive-jdbc-shaded + ${project.version} + test + + diff --git a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/TrinoContextSuite.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoBackendService.scala similarity index 64% rename from externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/TrinoContextSuite.scala rename to externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoBackendService.scala index dd9f6c545..6cc4141e1 100644 --- a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/TrinoContextSuite.scala +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoBackendService.scala @@ -17,16 +17,13 @@ package org.apache.kyuubi.engine.trino -class TrinoContextSuite extends WithTrinoContainerServer { +import org.apache.kyuubi.engine.trino.session.TrinoSessionManager +import org.apache.kyuubi.service.AbstractBackendService +import org.apache.kyuubi.session.SessionManager - test("set current schema") { - withTrinoContainer { trinoContext => - val trinoStatement = TrinoStatement(trinoContext, kyuubiConf, "select 1") - assert("tiny" === trinoStatement.getCurrentDatabase) +class TrinoBackendService + extends AbstractBackendService("TrinoBackendService") { + + override val sessionManager: SessionManager = new TrinoSessionManager() - trinoContext.setCurrentSchema("sf1") - val trinoStatement2 = TrinoStatement(trinoContext, kyuubiConf, "select 1") - assert("sf1" === trinoStatement2.getCurrentDatabase) - } - } } diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoConf.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoConf.scala index 161a1d4cf..9441112d3 100644 --- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoConf.scala +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoConf.scala @@ -17,6 +17,8 @@ package org.apache.kyuubi.engine.trino +import java.time.Duration + import org.apache.kyuubi.config.ConfigBuilder import org.apache.kyuubi.config.ConfigEntry import org.apache.kyuubi.config.KyuubiConf @@ -30,4 +32,11 @@ object TrinoConf { .version("1.5.0") .intConf .createWithDefault(3) + + val CLIENT_REQUEST_TIMEOUT: ConfigEntry[Long] = + buildConf("trino.client.request.timeout") + .doc("Timeout for Trino client request to trino cluster") + .version("1.5.0") + .timeConf + .createWithDefault(Duration.ofMinutes(2).toMillis) } diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoContext.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoContext.scala index cad7c97be..ed7a98fa3 100644 --- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoContext.scala +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoContext.scala @@ -22,19 +22,11 @@ import java.util.concurrent.atomic.AtomicReference import io.trino.client.ClientSession import okhttp3.OkHttpClient -class TrinoContext( - val httpClient: OkHttpClient, - val clientSession: AtomicReference[ClientSession]) { - - def getClientSession: ClientSession = clientSession.get - - def setCurrentSchema(schema: String): Unit = { - clientSession.set(ClientSession.builder(clientSession.get).withSchema(schema).build()) - } - -} +case class TrinoContext( + httpClient: OkHttpClient, + clientSession: AtomicReference[ClientSession]) object TrinoContext { def apply(httpClient: OkHttpClient, clientSession: ClientSession): TrinoContext = - new TrinoContext(httpClient, new AtomicReference(clientSession)) + TrinoContext(httpClient, new AtomicReference(clientSession)) } diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoSqlEngine.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoSqlEngine.scala index 73ec32a62..d05bc641c 100644 --- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoSqlEngine.scala +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoSqlEngine.scala @@ -17,19 +17,74 @@ package org.apache.kyuubi.engine.trino +import java.util.concurrent.CountDownLatch + import org.apache.kyuubi.Logging +import org.apache.kyuubi.Utils.TRINO_ENGINE_SHUTDOWN_PRIORITY +import org.apache.kyuubi.Utils.addShutdownHook import org.apache.kyuubi.config.KyuubiConf +import org.apache.kyuubi.engine.trino.TrinoSqlEngine.countDownLatch +import org.apache.kyuubi.engine.trino.TrinoSqlEngine.currentEngine +import org.apache.kyuubi.ha.HighAvailabilityConf.HA_ZK_CONN_RETRY_POLICY +import org.apache.kyuubi.ha.client.RetryPolicies +import org.apache.kyuubi.service.Serverable import org.apache.kyuubi.util.SignalRegister +case class TrinoSqlEngine() + extends Serverable("TrinoSQLEngine") { + + override val backendService = new TrinoBackendService() + + override val frontendServices = Seq(new TrinoTBinaryFrontendService(this)) + + override def start(): Unit = { + super.start() + // Start engine self-terminating checker after all services are ready and it can be reached by + // all servers in engine spaces. + backendService.sessionManager.startTerminatingChecker(() => { + assert(currentEngine.isDefined) + currentEngine.get.stop() + }) + } + + override protected def stopServer(): Unit = { + countDownLatch.countDown() + } +} + object TrinoSqlEngine extends Logging { + private val countDownLatch = new CountDownLatch(1) val kyuubiConf: KyuubiConf = KyuubiConf() + var currentEngine: Option[TrinoSqlEngine] = None + + def startEngine(): Unit = { + currentEngine = Some(new TrinoSqlEngine()) + currentEngine.foreach { engine => + engine.initialize(kyuubiConf) + engine.start() + addShutdownHook(() => engine.stop(), TRINO_ENGINE_SHUTDOWN_PRIORITY + 1) + } + } + def main(args: Array[String]): Unit = { SignalRegister.registerLogger(logger) - // TODO start engine - warn("Trino engine under development...") - info(kyuubiConf.getAll) + try { + kyuubiConf.setIfMissing(KyuubiConf.FRONTEND_THRIFT_BINARY_BIND_PORT, 0) + kyuubiConf.setIfMissing(HA_ZK_CONN_RETRY_POLICY, RetryPolicies.N_TIME.toString) + + startEngine() + // blocking main thread + countDownLatch.await() + } catch { + case t: Throwable if currentEngine.isDefined => + currentEngine.foreach { engine => + error(t) + engine.stop() + } + case t: Throwable => error("Create Trino Engine Failed", t) + } } } diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoStatement.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoStatement.scala index c13ae49bb..c1b2472f7 100644 --- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoStatement.scala +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/TrinoStatement.scala @@ -27,7 +27,6 @@ import scala.concurrent.ExecutionContext import scala.concurrent.Future import scala.concurrent.duration import scala.concurrent.duration.Duration -import scala.util.control.Breaks._ import com.google.common.base.Verify import io.trino.client.ClientSession @@ -46,7 +45,7 @@ import org.apache.kyuubi.engine.trino.TrinoStatement._ class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String) { private lazy val trino = StatementClientFactory - .newStatementClient(trinoContext.httpClient, trinoContext.getClientSession, sql) + .newStatementClient(trinoContext.httpClient, trinoContext.clientSession.get, sql) private lazy val dataProcessingPoolSize = kyuubiConf.get(DATA_PROCESSING_POOL_SIZE) @@ -55,7 +54,7 @@ class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: St def getTrinoClient: StatementClient = trino - def getCurrentDatabase: String = trinoContext.getClientSession.getSchema + def getCurrentDatabase: String = trinoContext.clientSession.get.getSchema def getColumns: List[Column] = { while (trino.isRunning) { @@ -99,50 +98,44 @@ class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: St val rowBuffer = new ArrayList[List[Any]](MAX_BUFFERED_ROWS) var bufferStart = System.nanoTime() val result = ArrayBuffer[List[Any]]() - try { - breakable { - while (!dataProcessing.isCompleted) { - val atEnd = drainDetectingEnd(rowQueue, rowBuffer, MAX_BUFFERED_ROWS, END_TOKEN) - if (!atEnd) { - // Flush if needed - if (rowBuffer.size() >= MAX_BUFFERED_ROWS || - Duration.fromNanos(bufferStart).compareTo(MAX_BUFFER_TIME) >= 0) { - result ++= rowBuffer.asScala - rowBuffer.clear() - bufferStart = System.nanoTime() - } - val row = rowQueue.poll(MAX_BUFFER_TIME.toMillis, duration.MILLISECONDS) - row match { - case END_TOKEN => break - case null => - case _ => rowBuffer.add(row) - } - } + var getDataEnd = false + while (!dataProcessing.isCompleted && !getDataEnd) { + val atEnd = drainDetectingEnd(rowQueue, rowBuffer, MAX_BUFFERED_ROWS, END_TOKEN) + if (!atEnd) { + // Flush if needed + if (rowBuffer.size() >= MAX_BUFFERED_ROWS || + Duration.fromNanos(bufferStart).compareTo(MAX_BUFFER_TIME) >= 0) { + result ++= rowBuffer.asScala + rowBuffer.clear() + bufferStart = System.nanoTime() + } + + val row = rowQueue.poll(MAX_BUFFER_TIME.toMillis, duration.MILLISECONDS) + row match { + case END_TOKEN => getDataEnd = true + case null => + case _ => rowBuffer.add(row) } } - if (!rowQueue.isEmpty()) { - drainDetectingEnd(rowQueue, rowBuffer, Integer.MAX_VALUE, END_TOKEN) - } - result ++= rowBuffer.asScala - - val finalStatus = trino.finalStatusInfo() - if (finalStatus.getError() != null) { - val exception = KyuubiSQLException( - s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}") - throw exception - } - - updateTrinoContext() - } catch { - case e: Exception => - throw KyuubiSQLException(e) } + if (!rowQueue.isEmpty()) { + drainDetectingEnd(rowQueue, rowBuffer, Integer.MAX_VALUE, END_TOKEN) + } + result ++= rowBuffer.asScala + + val finalStatus = trino.finalStatusInfo() + if (finalStatus.getError() != null) { + throw KyuubiSQLException( + s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}") + } + updateTrinoContext() + result } def updateTrinoContext(): Unit = { - val session = trinoContext.getClientSession + val session = trinoContext.clientSession.get var builder = ClientSession.builder(session) // update catalog and schema diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala new file mode 100644 index 000000000..e6ee4a1a5 --- /dev/null +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/ExecuteStatement.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.trino.operation + +import java.util.concurrent.RejectedExecutionException + +import org.apache.kyuubi.KyuubiSQLException +import org.apache.kyuubi.Logging +import org.apache.kyuubi.engine.trino.TrinoStatement +import org.apache.kyuubi.operation.ArrayFetchIterator +import org.apache.kyuubi.operation.IterableFetchIterator +import org.apache.kyuubi.operation.OperationState +import org.apache.kyuubi.operation.OperationType +import org.apache.kyuubi.operation.log.OperationLog +import org.apache.kyuubi.session.Session + +class ExecuteStatement( + session: Session, + override val statement: String, + override val shouldRunAsync: Boolean, + incrementalCollect: Boolean) + extends TrinoOperation(OperationType.EXECUTE_STATEMENT, session) with Logging { + + private val operationLog: OperationLog = OperationLog.createOperationLog(session, getHandle) + override def getOperationLog: Option[OperationLog] = Option(operationLog) + + override protected def beforeRun(): Unit = { + OperationLog.setCurrentOperationLog(operationLog) + setState(OperationState.PENDING) + setHasResultSet(true) + } + + override protected def afterRun(): Unit = { + OperationLog.removeCurrentOperationLog() + } + + override protected def runInternal(): Unit = { + val trinoStatement = TrinoStatement(trinoContext, session.sessionManager.getConf, statement) + trino = trinoStatement.getTrinoClient + if (shouldRunAsync) { + val asyncOperation = new Runnable { + override def run(): Unit = { + OperationLog.setCurrentOperationLog(operationLog) + executeStatement(trinoStatement) + } + } + + try { + val trinoSessionManager = session.sessionManager + val backgroundHandle = trinoSessionManager.submitBackgroundOperation(asyncOperation) + setBackgroundHandle(backgroundHandle) + } catch { + case rejected: RejectedExecutionException => + setState(OperationState.ERROR) + val ke = + KyuubiSQLException("Error submitting query in background, query rejected", rejected) + setOperationException(ke) + throw ke + } + } else { + executeStatement(trinoStatement) + } + } + + private def executeStatement(trinoStatement: TrinoStatement): Unit = { + setState(OperationState.RUNNING) + try { + schema = trinoStatement.getColumns + val resultSet = trinoStatement.execute() + iter = + if (incrementalCollect) { + info("Execute in incremental collect mode") + new IterableFetchIterator(resultSet) + } else { + info("Execute in full collect mode") + new ArrayFetchIterator(resultSet.toArray) + } + setState(OperationState.FINISHED) + } catch { + onError(cancel = true) + } + } +} diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala new file mode 100644 index 000000000..298034b41 --- /dev/null +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperation.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.trino.operation + +import java.io.IOException + +import io.trino.client.Column +import io.trino.client.StatementClient +import org.apache.hive.service.rpc.thrift.TRowSet +import org.apache.hive.service.rpc.thrift.TTableSchema + +import org.apache.kyuubi.KyuubiSQLException +import org.apache.kyuubi.Utils +import org.apache.kyuubi.engine.trino.TrinoContext +import org.apache.kyuubi.engine.trino.schema.RowSet +import org.apache.kyuubi.engine.trino.schema.SchemaHelper +import org.apache.kyuubi.engine.trino.session.TrinoSessionImpl +import org.apache.kyuubi.operation.AbstractOperation +import org.apache.kyuubi.operation.FetchIterator +import org.apache.kyuubi.operation.FetchOrientation.FETCH_FIRST +import org.apache.kyuubi.operation.FetchOrientation.FETCH_NEXT +import org.apache.kyuubi.operation.FetchOrientation.FETCH_PRIOR +import org.apache.kyuubi.operation.FetchOrientation.FetchOrientation +import org.apache.kyuubi.operation.OperationState +import org.apache.kyuubi.operation.OperationState.OperationState +import org.apache.kyuubi.operation.OperationType.OperationType +import org.apache.kyuubi.operation.log.OperationLog +import org.apache.kyuubi.session.Session + +abstract class TrinoOperation(opType: OperationType, session: Session) + extends AbstractOperation(opType, session) { + + protected val trinoContext: TrinoContext = session.asInstanceOf[TrinoSessionImpl].trinoContext + + protected var trino: StatementClient = _ + + protected var schema: List[Column] = _ + + protected var iter: FetchIterator[List[Any]] = _ + + override def getResultSetSchema: TTableSchema = SchemaHelper.toTTableSchema(schema) + + override def getNextRowSet(order: FetchOrientation, rowSetSize: Int): TRowSet = { + validateDefaultFetchOrientation(order) + assertState(OperationState.FINISHED) + setHasResultSet(true) + order match { + case FETCH_NEXT => iter.fetchNext() + case FETCH_PRIOR => iter.fetchPrior(rowSetSize); + case FETCH_FIRST => iter.fetchAbsolute(0); + } + val taken = iter.take(rowSetSize) + val resultRowSet = RowSet.toTRowSet(taken.toList, schema, getProtocolVersion) + resultRowSet.setStartRowOffset(iter.getPosition) + resultRowSet + } + + override protected def beforeRun(): Unit = { + setHasResultSet(true) + setState(OperationState.RUNNING) + } + + override protected def afterRun(): Unit = { + state.synchronized { + if (!isTerminalState(state)) { + setState(OperationState.FINISHED) + } + } + OperationLog.removeCurrentOperationLog() + } + + override def cancel(): Unit = { + cleanup(OperationState.CANCELED) + } + + protected def cleanup(targetState: OperationState): Unit = state.synchronized { + if (!isTerminalState(state)) { + setState(targetState) + Option(getBackgroundHandle).foreach(_.cancel(true)) + } + } + + override def close(): Unit = { + cleanup(OperationState.CLOSED) + try { + trino.close() + getOperationLog.foreach(_.close()) + } catch { + case e: IOException => + error(e.getMessage, e) + } + } + + override def shouldRunAsync: Boolean = false + + protected def onError(cancel: Boolean = false): PartialFunction[Throwable, Unit] = { + // We should use Throwable instead of Exception since `java.lang.NoClassDefFoundError` + // could be thrown. + case e: Throwable => + if (cancel && trino.isRunning) trino.cancelLeafStage() + state.synchronized { + val errMsg = Utils.stringifyException(e) + if (state == OperationState.TIMEOUT) { + val ke = KyuubiSQLException(s"Timeout operating $opType: $errMsg") + setOperationException(ke) + throw ke + } else if (isTerminalState(state)) { + setOperationException(KyuubiSQLException(errMsg)) + warn(s"Ignore exception in terminal state with $statementId: $errMsg") + } else { + error(s"Error operating $opType: $errMsg", e) + val ke = KyuubiSQLException(s"Error operating $opType: $errMsg", e) + setOperationException(ke) + setState(OperationState.ERROR) + throw ke + } + } + } +} diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationManager.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationManager.scala new file mode 100644 index 000000000..0f3a12d5b --- /dev/null +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationManager.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.trino.operation + +import java.util + +import org.apache.kyuubi.config.KyuubiConf.OPERATION_INCREMENTAL_COLLECT +import org.apache.kyuubi.operation.Operation +import org.apache.kyuubi.operation.OperationManager +import org.apache.kyuubi.session.Session + +class TrinoOperationManager extends OperationManager("TrinoOperationManager") { + + def newExecuteStatementOperation( + session: Session, + statement: String, + confOverlay: Map[String, String], + runAsync: Boolean, + queryTimeout: Long): Operation = { + val incrementalCollect = session.sessionManager.getConf.get(OPERATION_INCREMENTAL_COLLECT) + val operation = new ExecuteStatement(session, statement, runAsync, incrementalCollect) + addOperation(operation) + } + + override def newGetTypeInfoOperation(session: Session): Operation = null + + override def newGetCatalogsOperation(session: Session): Operation = null + + override def newGetSchemasOperation( + session: Session, + catalog: String, + schema: String): Operation = null + + override def newGetTablesOperation( + session: Session, + catalogName: String, + schemaName: String, + tableName: String, + tableTypes: util.List[String]): Operation = null + + override def newGetTableTypesOperation(session: Session): Operation = null + + override def newGetColumnsOperation( + session: Session, + catalogName: String, + schemaName: String, + tableName: String, + columnName: String): Operation = null + + override def newGetFunctionsOperation( + session: Session, + catalogName: String, + schemaName: String, + functionName: String): Operation = null +} diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/session/TrinoSessionImpl.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/session/TrinoSessionImpl.scala new file mode 100644 index 000000000..1b68fc779 --- /dev/null +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/session/TrinoSessionImpl.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.trino.session + +import java.net.URI +import java.time.ZoneId +import java.util.Collections +import java.util.Locale +import java.util.Optional +import java.util.concurrent.TimeUnit + +import io.airlift.units.Duration +import io.trino.client.ClientSession +import okhttp3.OkHttpClient +import org.apache.hive.service.rpc.thrift.TProtocolVersion + +import org.apache.kyuubi.KyuubiSQLException +import org.apache.kyuubi.Utils.currentUser +import org.apache.kyuubi.config.KyuubiConf +import org.apache.kyuubi.engine.trino.TrinoConf +import org.apache.kyuubi.engine.trino.TrinoContext +import org.apache.kyuubi.session.AbstractSession +import org.apache.kyuubi.session.SessionHandle +import org.apache.kyuubi.session.SessionManager + +class TrinoSessionImpl( + protocol: TProtocolVersion, + user: String, + password: String, + ipAddress: String, + conf: Map[String, String], + sessionManager: SessionManager) + extends AbstractSession(protocol, user, password, ipAddress, conf, sessionManager) { + + var trinoContext: TrinoContext = _ + private var clientSession: ClientSession = _ + + override val handle: SessionHandle = SessionHandle(protocol) + + override def open(): Unit = { + normalizedConf.foreach { + case ("use:database", database) => clientSession = createClientSession(database) + case _ => // do nothing + } + + val httpClient = new OkHttpClient.Builder().build() + + if (clientSession == null) { + clientSession = createClientSession() + } + trinoContext = TrinoContext(httpClient, clientSession) + + super.open() + } + + private def createClientSession(schema: String = null): ClientSession = { + val sessionConf = sessionManager.getConf + val connectionUrl = sessionConf.get(KyuubiConf.ENGINE_TRINO_CONNECTION_URL).getOrElse( + throw KyuubiSQLException("Trino server url can not be null!")) + val catalog = sessionConf.get(KyuubiConf.ENGINE_TRINO_CONNECTION_CATALOG).getOrElse( + throw KyuubiSQLException("Trino default catalog can not be null!")) + val user = sessionConf.getOption("kyuubi.trino.user").getOrElse(currentUser) + val clientRequestTimeout = sessionConf.get(TrinoConf.CLIENT_REQUEST_TIMEOUT) + + new ClientSession( + URI.create(connectionUrl), + user, + Optional.empty(), + "kyuubi", + Optional.empty(), + Collections.emptySet(), + null, + catalog, + schema, + null, + ZoneId.systemDefault(), + Locale.getDefault, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + null, + new Duration(clientRequestTimeout, TimeUnit.MILLISECONDS), + true) + } +} diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/session/TrinoSessionManager.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/session/TrinoSessionManager.scala new file mode 100644 index 000000000..83af04a03 --- /dev/null +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/session/TrinoSessionManager.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.trino.session + +import org.apache.hive.service.rpc.thrift.TProtocolVersion + +import org.apache.kyuubi.KyuubiSQLException +import org.apache.kyuubi.Utils +import org.apache.kyuubi.config.KyuubiConf +import org.apache.kyuubi.config.KyuubiConf.ENGINE_OPERATION_LOG_DIR_ROOT +import org.apache.kyuubi.config.KyuubiConf.ENGINE_SHARE_LEVEL +import org.apache.kyuubi.engine.ShareLevel +import org.apache.kyuubi.engine.trino.TrinoSqlEngine +import org.apache.kyuubi.engine.trino.operation.TrinoOperationManager +import org.apache.kyuubi.session.SessionHandle +import org.apache.kyuubi.session.SessionManager + +class TrinoSessionManager + extends SessionManager("TrinoSessionManager") { + + val operationManager = new TrinoOperationManager() + + override def initialize(conf: KyuubiConf): Unit = { + val absPath = Utils.getAbsolutePathFromWork(conf.get(ENGINE_OPERATION_LOG_DIR_ROOT)) + _operationLogRoot = Some(absPath.toAbsolutePath.toString) + super.initialize(conf) + } + + override def openSession( + protocol: TProtocolVersion, + user: String, + password: String, + ipAddress: String, + conf: Map[String, String]): SessionHandle = { + info(s"Opening session for $user@$ipAddress") + val sessionImpl = + new TrinoSessionImpl(protocol, user, password, ipAddress, conf, this) + + try { + val handle = sessionImpl.handle + sessionImpl.open() + setSession(handle, sessionImpl) + info(s"$user's trino session with $handle is opened, current opening sessions" + + s" $getOpenSessionCount") + handle + } catch { + case e: Exception => + sessionImpl.close() + throw KyuubiSQLException(e) + } + } + + override def closeSession(sessionHandle: SessionHandle): Unit = { + super.closeSession(sessionHandle) + if (conf.get(ENGINE_SHARE_LEVEL) == ShareLevel.CONNECTION.toString) { + info("Session stopped due to shared level is Connection.") + stopSession() + } + } + + private def stopSession(): Unit = { + TrinoSqlEngine.currentEngine.foreach(_.stop()) + } + + override protected def isServer: Boolean = false +} diff --git a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/WithTrinoEngine.scala b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/WithTrinoEngine.scala new file mode 100644 index 000000000..14fe3e13e --- /dev/null +++ b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/WithTrinoEngine.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.trino + +import org.apache.kyuubi.KyuubiFunSuite +import org.apache.kyuubi.config.KyuubiConf + +trait WithTrinoEngine extends KyuubiFunSuite with WithTrinoContainerServer { + + protected var engine: TrinoSqlEngine = _ + protected var connectionUrl: String = _ + + override val kyuubiConf: KyuubiConf = TrinoSqlEngine.kyuubiConf + + def withKyuubiConf: Map[String, String] + + override def beforeAll(): Unit = { + withContainers { trinoContainer => + val containerConnectionUrl = trinoContainer.jdbcUrl.replace("jdbc:trino", "http") + startTrinoEngine(containerConnectionUrl) + super.beforeAll() + } + } + + def startTrinoEngine(containerConnectionUrl: String): Unit = { + kyuubiConf.set(KyuubiConf.ENGINE_TRINO_CONNECTION_URL, containerConnectionUrl) + + withKyuubiConf.foreach { case (k, v) => + System.setProperty(k, v) + kyuubiConf.set(k, v) + } + + TrinoSqlEngine.startEngine() + engine = TrinoSqlEngine.currentEngine.get + connectionUrl = engine.frontendServices.head.connectionUrl + } + + override def afterAll(): Unit = { + super.afterAll() + stopTrinoEngine() + } + + def stopTrinoEngine(): Unit = { + if (engine != null) { + engine.stop() + engine = null + } + } + + protected def getJdbcUrl: String = s"jdbc:hive2://$connectionUrl/$schema;" +} diff --git a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala new file mode 100644 index 000000000..4e2289f8e --- /dev/null +++ b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.trino.operation + +import scala.collection.JavaConverters._ + +import org.apache.hive.service.rpc.thrift.TCancelOperationReq +import org.apache.hive.service.rpc.thrift.TCloseOperationReq +import org.apache.hive.service.rpc.thrift.TCloseSessionReq +import org.apache.hive.service.rpc.thrift.TExecuteStatementReq +import org.apache.hive.service.rpc.thrift.TFetchOrientation +import org.apache.hive.service.rpc.thrift.TFetchResultsReq +import org.apache.hive.service.rpc.thrift.TGetOperationStatusReq +import org.apache.hive.service.rpc.thrift.TOpenSessionReq +import org.apache.hive.service.rpc.thrift.TOperationState +import org.apache.hive.service.rpc.thrift.TStatusCode + +import org.apache.kyuubi.config.KyuubiConf.ENGINE_TRINO_CONNECTION_CATALOG +import org.apache.kyuubi.engine.trino.WithTrinoEngine +import org.apache.kyuubi.operation.HiveJDBCTestHelper + +class TrinoOperationSuite extends WithTrinoEngine with HiveJDBCTestHelper { + override def withKyuubiConf: Map[String, String] = Map( + ENGINE_TRINO_CONNECTION_CATALOG.key -> "memory") + + // use default schema, do not set to 'default', since withSessionHandle strip suffix '/;' + override protected val schema = "" + + override protected def jdbcUrl: String = getJdbcUrl + + test("execute statement - select decimal") { + withJdbcStatement() { statement => + val resultSet = statement.executeQuery("SELECT DECIMAL '1.2' as col1, DECIMAL '1.23' AS col2") + assert(resultSet.next()) + assert(resultSet.getBigDecimal("col1") === new java.math.BigDecimal("1.2")) + assert(resultSet.getBigDecimal("col2") === new java.math.BigDecimal("1.23")) + val metaData = resultSet.getMetaData + assert(metaData.getColumnType(1) === java.sql.Types.DECIMAL) + assert(metaData.getColumnType(2) === java.sql.Types.DECIMAL) + assert(metaData.getPrecision(1) == 2) + assert(metaData.getPrecision(2) == 3) + assert(metaData.getScale(1) == 1) + assert(metaData.getScale(2) == 2) + } + } + + test("test fetch orientation") { + val sql = "SELECT id FROM (VALUES 0, 1) as t(id)" + + withSessionHandle { (client, handle) => + val req = new TExecuteStatementReq() + req.setSessionHandle(handle) + req.setStatement(sql) + val tExecuteStatementResp = client.ExecuteStatement(req) + val opHandle = tExecuteStatementResp.getOperationHandle + waitForOperationToComplete(client, opHandle) + + // fetch next from before first row + val tFetchResultsReq1 = new TFetchResultsReq(opHandle, TFetchOrientation.FETCH_NEXT, 1) + val tFetchResultsResp1 = client.FetchResults(tFetchResultsReq1) + assert(tFetchResultsResp1.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS) + val idSeq1 = tFetchResultsResp1.getResults.getColumns.get(0).getI32Val.getValues.asScala.toSeq + assertResult(Seq(0L))(idSeq1) + + // fetch next from first row + val tFetchResultsReq2 = new TFetchResultsReq(opHandle, TFetchOrientation.FETCH_NEXT, 1) + val tFetchResultsResp2 = client.FetchResults(tFetchResultsReq2) + assert(tFetchResultsResp2.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS) + val idSeq2 = tFetchResultsResp2.getResults.getColumns.get(0).getI32Val.getValues.asScala.toSeq + assertResult(Seq(1L))(idSeq2) + + // fetch prior from second row, expected got first row + val tFetchResultsReq3 = new TFetchResultsReq(opHandle, TFetchOrientation.FETCH_PRIOR, 1) + val tFetchResultsResp3 = client.FetchResults(tFetchResultsReq3) + assert(tFetchResultsResp3.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS) + val idSeq3 = tFetchResultsResp3.getResults.getColumns.get(0).getI32Val.getValues.asScala.toSeq + assertResult(Seq(0L))(idSeq3) + + // fetch first + val tFetchResultsReq4 = new TFetchResultsReq(opHandle, TFetchOrientation.FETCH_FIRST, 3) + val tFetchResultsResp4 = client.FetchResults(tFetchResultsReq4) + assert(tFetchResultsResp4.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS) + val idSeq4 = tFetchResultsResp4.getResults.getColumns.get(0).getI32Val.getValues.asScala.toSeq + assertResult(Seq(0L, 1L))(idSeq4) + } + } + + test("get operation status") { + val sql = "select date '2011-11-11' - interval '1' day" + + withSessionHandle { (client, handle) => + val req = new TExecuteStatementReq() + req.setSessionHandle(handle) + req.setStatement(sql) + val tExecuteStatementResp = client.ExecuteStatement(req) + val opHandle = tExecuteStatementResp.getOperationHandle + val tGetOperationStatusReq = new TGetOperationStatusReq() + tGetOperationStatusReq.setOperationHandle(opHandle) + val resp = client.GetOperationStatus(tGetOperationStatusReq) + val status = resp.getStatus + assert(status.getStatusCode === TStatusCode.SUCCESS_STATUS) + assert(resp.getOperationState === TOperationState.FINISHED_STATE) + assert(resp.isHasResultSet) + } + } + + test("basic open | execute | close") { + withThriftClient { client => + val req = new TOpenSessionReq() + req.setUsername("hongdd") + req.setPassword("anonymous") + val tOpenSessionResp = client.OpenSession(req) + + val tExecuteStatementReq = new TExecuteStatementReq() + tExecuteStatementReq.setSessionHandle(tOpenSessionResp.getSessionHandle) + tExecuteStatementReq.setRunAsync(true) + tExecuteStatementReq.setStatement("show session") + val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq) + + val operationHandle = tExecuteStatementResp.getOperationHandle + waitForOperationToComplete(client, operationHandle) + val tFetchResultsReq = new TFetchResultsReq() + tFetchResultsReq.setOperationHandle(operationHandle) + tFetchResultsReq.setFetchType(1) + tFetchResultsReq.setMaxRows(1000) + val tFetchResultsResp = client.FetchResults(tFetchResultsReq) + val logs = tFetchResultsResp.getResults.getColumns.get(0).getStringVal.getValues.asScala + assert(logs.exists(_.contains(classOf[ExecuteStatement].getCanonicalName))) + + tFetchResultsReq.setFetchType(0) + val tFetchResultsResp1 = client.FetchResults(tFetchResultsReq) + val rs = tFetchResultsResp1.getResults.getColumns.get(0).getStringVal.getValues.asScala + assert(rs.contains("aggregation_operator_unspill_memory_limit")) + + val tCloseSessionReq = new TCloseSessionReq() + tCloseSessionReq.setSessionHandle(tOpenSessionResp.getSessionHandle) + val tCloseSessionResp = client.CloseSession(tCloseSessionReq) + assert(tCloseSessionResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS) + } + } + + test("not allow to operate closed session or operation") { + withThriftClient { client => + val req = new TOpenSessionReq() + req.setUsername("hongdd") + req.setPassword("anonymous") + val tOpenSessionResp = client.OpenSession(req) + + val tExecuteStatementReq = new TExecuteStatementReq() + tExecuteStatementReq.setSessionHandle(tOpenSessionResp.getSessionHandle) + tExecuteStatementReq.setStatement("show session") + val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq) + + val tCloseOperationReq = new TCloseOperationReq(tExecuteStatementResp.getOperationHandle) + val tCloseOperationResp = client.CloseOperation(tCloseOperationReq) + assert(tCloseOperationResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS) + + val tFetchResultsReq = new TFetchResultsReq() + tFetchResultsReq.setOperationHandle(tExecuteStatementResp.getOperationHandle) + tFetchResultsReq.setFetchType(0) + tFetchResultsReq.setMaxRows(1000) + val tFetchResultsResp = client.FetchResults(tFetchResultsReq) + assert(tFetchResultsResp.getStatus.getStatusCode === TStatusCode.ERROR_STATUS) + assert(tFetchResultsResp.getStatus.getErrorMessage startsWith "Invalid OperationHandle" + + " [type=EXECUTE_STATEMENT, identifier:") + + val tCloseSessionReq = new TCloseSessionReq() + tCloseSessionReq.setSessionHandle(tOpenSessionResp.getSessionHandle) + val tCloseSessionResp = client.CloseSession(tCloseSessionReq) + assert(tCloseSessionResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS) + val tExecuteStatementResp1 = client.ExecuteStatement(tExecuteStatementReq) + + val status = tExecuteStatementResp1.getStatus + assert(status.getStatusCode === TStatusCode.ERROR_STATUS) + assert(status.getErrorMessage startsWith s"Invalid SessionHandle [") + } + } + + test("cancel operation") { + withThriftClient { client => + val req = new TOpenSessionReq() + req.setUsername("hongdd") + req.setPassword("anonymous") + val tOpenSessionResp = client.OpenSession(req) + + val tExecuteStatementReq = new TExecuteStatementReq() + tExecuteStatementReq.setSessionHandle(tOpenSessionResp.getSessionHandle) + tExecuteStatementReq.setStatement("show session") + val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq) + val tCancelOperationReq = new TCancelOperationReq(tExecuteStatementResp.getOperationHandle) + val tCancelOperationResp = client.CancelOperation(tCancelOperationReq) + assert(tCancelOperationResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS) + val tFetchResultsReq = new TFetchResultsReq() + tFetchResultsReq.setOperationHandle(tExecuteStatementResp.getOperationHandle) + tFetchResultsReq.setFetchType(0) + tFetchResultsReq.setMaxRows(1000) + val tFetchResultsResp = client.FetchResults(tFetchResultsReq) + assert(tFetchResultsResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS) + } + } +} diff --git a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/session/SessionSuite.scala b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/session/SessionSuite.scala new file mode 100644 index 000000000..f7637efd8 --- /dev/null +++ b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/session/SessionSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.trino.session + +import org.apache.kyuubi.config.KyuubiConf.ENGINE_SHARE_LEVEL +import org.apache.kyuubi.config.KyuubiConf.ENGINE_TRINO_CONNECTION_CATALOG +import org.apache.kyuubi.engine.trino.WithTrinoEngine +import org.apache.kyuubi.operation.HiveJDBCTestHelper + +class SessionSuite extends WithTrinoEngine with HiveJDBCTestHelper { + override def withKyuubiConf: Map[String, String] = Map( + ENGINE_TRINO_CONNECTION_CATALOG.key -> "memory", + ENGINE_SHARE_LEVEL.key -> "SERVER") + + override protected val schema = "default" + + override protected def jdbcUrl: String = getJdbcUrl + + test("test session") { + withJdbcStatement() { statement => + statement.executeQuery("create or replace view temp_view as select 1 as id") + val resultSet = statement.executeQuery("select * from temp_view") + assert(resultSet.next()) + } + } + +} diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/Utils.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/Utils.scala index 018c59c02..938219a83 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/Utils.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/Utils.scala @@ -203,6 +203,7 @@ object Utils extends Logging { // Hooks need to be invoked before the SparkContext stopped shall use a higher priority. val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 val FLINK_ENGINE_SHUTDOWN_PRIORITY = 50 + val TRINO_ENGINE_SHUTDOWN_PRIORITY = 50 /** * Add some operations that you want into ShutdownHook