[KYUUBI #1821] Add trino ExecuteStatement

<!--
Thanks for sending a pull request!

Here are some tips for you:
  1. If this is your first time, please read our contributor guidelines: https://kyuubi.readthedocs.io/en/latest/community/contributions.html
  2. If the PR is related to an issue in https://github.com/apache/incubator-kyuubi/issues, add '[KYUUBI #XXXX]' in your PR title, e.g., '[KYUUBI #XXXX] Your PR title ...'.
  3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][KYUUBI #XXXX] Your PR title ...'.
-->

### _Why are the changes needed?_
<!--
Please clarify why the changes are needed. For instance,
  1. If you add a feature, you can talk about the use case of it.
  2. If you fix a bug, you can clarify why it is a bug.
-->
Add trino ExecuteStatement

### _How was this patch tested?_
- [X] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [X] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before make a pull request

Closes #1830 from hddong/add-operation.

Closes #1821

067bde7a [hongdongdong] use flag instead breakable
f4d6cbb9 [hongdongdong] fix
351e2bc9 [hongdongdong] move context to impl
69d7d9b2 [hongdongdong] fix wrong func name
9cb757a9 [hongdongdong] fix
a20f2d0f [hongdongdong] fix time unit
c5072dbf [hongdongdong] [KYUUBI #1821] Add trino ExecuteStatement

Authored-by: hongdongdong <hongdongdong@cmss.chinamobile.com>
Signed-off-by: hongdongdong <hongdongdong@cmss.chinamobile.com>
This commit is contained in:
hongdongdong 2022-02-10 10:07:57 +08:00
parent 33eda2159f
commit b952b7b5d8
15 changed files with 927 additions and 64 deletions

View File

@ -82,6 +82,13 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.kyuubi</groupId>
<artifactId>kyuubi-hive-jdbc-shaded</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -17,16 +17,13 @@
package org.apache.kyuubi.engine.trino
class TrinoContextSuite extends WithTrinoContainerServer {
import org.apache.kyuubi.engine.trino.session.TrinoSessionManager
import org.apache.kyuubi.service.AbstractBackendService
import org.apache.kyuubi.session.SessionManager
test("set current schema") {
withTrinoContainer { trinoContext =>
val trinoStatement = TrinoStatement(trinoContext, kyuubiConf, "select 1")
assert("tiny" === trinoStatement.getCurrentDatabase)
class TrinoBackendService
extends AbstractBackendService("TrinoBackendService") {
override val sessionManager: SessionManager = new TrinoSessionManager()
trinoContext.setCurrentSchema("sf1")
val trinoStatement2 = TrinoStatement(trinoContext, kyuubiConf, "select 1")
assert("sf1" === trinoStatement2.getCurrentDatabase)
}
}
}

View File

@ -17,6 +17,8 @@
package org.apache.kyuubi.engine.trino
import java.time.Duration
import org.apache.kyuubi.config.ConfigBuilder
import org.apache.kyuubi.config.ConfigEntry
import org.apache.kyuubi.config.KyuubiConf
@ -30,4 +32,11 @@ object TrinoConf {
.version("1.5.0")
.intConf
.createWithDefault(3)
val CLIENT_REQUEST_TIMEOUT: ConfigEntry[Long] =
buildConf("trino.client.request.timeout")
.doc("Timeout for Trino client request to trino cluster")
.version("1.5.0")
.timeConf
.createWithDefault(Duration.ofMinutes(2).toMillis)
}

View File

@ -22,19 +22,11 @@ import java.util.concurrent.atomic.AtomicReference
import io.trino.client.ClientSession
import okhttp3.OkHttpClient
class TrinoContext(
val httpClient: OkHttpClient,
val clientSession: AtomicReference[ClientSession]) {
def getClientSession: ClientSession = clientSession.get
def setCurrentSchema(schema: String): Unit = {
clientSession.set(ClientSession.builder(clientSession.get).withSchema(schema).build())
}
}
case class TrinoContext(
httpClient: OkHttpClient,
clientSession: AtomicReference[ClientSession])
object TrinoContext {
def apply(httpClient: OkHttpClient, clientSession: ClientSession): TrinoContext =
new TrinoContext(httpClient, new AtomicReference(clientSession))
TrinoContext(httpClient, new AtomicReference(clientSession))
}

View File

@ -17,19 +17,74 @@
package org.apache.kyuubi.engine.trino
import java.util.concurrent.CountDownLatch
import org.apache.kyuubi.Logging
import org.apache.kyuubi.Utils.TRINO_ENGINE_SHUTDOWN_PRIORITY
import org.apache.kyuubi.Utils.addShutdownHook
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.trino.TrinoSqlEngine.countDownLatch
import org.apache.kyuubi.engine.trino.TrinoSqlEngine.currentEngine
import org.apache.kyuubi.ha.HighAvailabilityConf.HA_ZK_CONN_RETRY_POLICY
import org.apache.kyuubi.ha.client.RetryPolicies
import org.apache.kyuubi.service.Serverable
import org.apache.kyuubi.util.SignalRegister
case class TrinoSqlEngine()
extends Serverable("TrinoSQLEngine") {
override val backendService = new TrinoBackendService()
override val frontendServices = Seq(new TrinoTBinaryFrontendService(this))
override def start(): Unit = {
super.start()
// Start engine self-terminating checker after all services are ready and it can be reached by
// all servers in engine spaces.
backendService.sessionManager.startTerminatingChecker(() => {
assert(currentEngine.isDefined)
currentEngine.get.stop()
})
}
override protected def stopServer(): Unit = {
countDownLatch.countDown()
}
}
object TrinoSqlEngine extends Logging {
private val countDownLatch = new CountDownLatch(1)
val kyuubiConf: KyuubiConf = KyuubiConf()
var currentEngine: Option[TrinoSqlEngine] = None
def startEngine(): Unit = {
currentEngine = Some(new TrinoSqlEngine())
currentEngine.foreach { engine =>
engine.initialize(kyuubiConf)
engine.start()
addShutdownHook(() => engine.stop(), TRINO_ENGINE_SHUTDOWN_PRIORITY + 1)
}
}
def main(args: Array[String]): Unit = {
SignalRegister.registerLogger(logger)
// TODO start engine
warn("Trino engine under development...")
info(kyuubiConf.getAll)
try {
kyuubiConf.setIfMissing(KyuubiConf.FRONTEND_THRIFT_BINARY_BIND_PORT, 0)
kyuubiConf.setIfMissing(HA_ZK_CONN_RETRY_POLICY, RetryPolicies.N_TIME.toString)
startEngine()
// blocking main thread
countDownLatch.await()
} catch {
case t: Throwable if currentEngine.isDefined =>
currentEngine.foreach { engine =>
error(t)
engine.stop()
}
case t: Throwable => error("Create Trino Engine Failed", t)
}
}
}

View File

@ -27,7 +27,6 @@ import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.duration
import scala.concurrent.duration.Duration
import scala.util.control.Breaks._
import com.google.common.base.Verify
import io.trino.client.ClientSession
@ -46,7 +45,7 @@ import org.apache.kyuubi.engine.trino.TrinoStatement._
class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: String) {
private lazy val trino = StatementClientFactory
.newStatementClient(trinoContext.httpClient, trinoContext.getClientSession, sql)
.newStatementClient(trinoContext.httpClient, trinoContext.clientSession.get, sql)
private lazy val dataProcessingPoolSize = kyuubiConf.get(DATA_PROCESSING_POOL_SIZE)
@ -55,7 +54,7 @@ class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: St
def getTrinoClient: StatementClient = trino
def getCurrentDatabase: String = trinoContext.getClientSession.getSchema
def getCurrentDatabase: String = trinoContext.clientSession.get.getSchema
def getColumns: List[Column] = {
while (trino.isRunning) {
@ -99,50 +98,44 @@ class TrinoStatement(trinoContext: TrinoContext, kyuubiConf: KyuubiConf, sql: St
val rowBuffer = new ArrayList[List[Any]](MAX_BUFFERED_ROWS)
var bufferStart = System.nanoTime()
val result = ArrayBuffer[List[Any]]()
try {
breakable {
while (!dataProcessing.isCompleted) {
val atEnd = drainDetectingEnd(rowQueue, rowBuffer, MAX_BUFFERED_ROWS, END_TOKEN)
if (!atEnd) {
// Flush if needed
if (rowBuffer.size() >= MAX_BUFFERED_ROWS ||
Duration.fromNanos(bufferStart).compareTo(MAX_BUFFER_TIME) >= 0) {
result ++= rowBuffer.asScala
rowBuffer.clear()
bufferStart = System.nanoTime()
}
val row = rowQueue.poll(MAX_BUFFER_TIME.toMillis, duration.MILLISECONDS)
row match {
case END_TOKEN => break
case null =>
case _ => rowBuffer.add(row)
}
}
var getDataEnd = false
while (!dataProcessing.isCompleted && !getDataEnd) {
val atEnd = drainDetectingEnd(rowQueue, rowBuffer, MAX_BUFFERED_ROWS, END_TOKEN)
if (!atEnd) {
// Flush if needed
if (rowBuffer.size() >= MAX_BUFFERED_ROWS ||
Duration.fromNanos(bufferStart).compareTo(MAX_BUFFER_TIME) >= 0) {
result ++= rowBuffer.asScala
rowBuffer.clear()
bufferStart = System.nanoTime()
}
val row = rowQueue.poll(MAX_BUFFER_TIME.toMillis, duration.MILLISECONDS)
row match {
case END_TOKEN => getDataEnd = true
case null =>
case _ => rowBuffer.add(row)
}
}
if (!rowQueue.isEmpty()) {
drainDetectingEnd(rowQueue, rowBuffer, Integer.MAX_VALUE, END_TOKEN)
}
result ++= rowBuffer.asScala
val finalStatus = trino.finalStatusInfo()
if (finalStatus.getError() != null) {
val exception = KyuubiSQLException(
s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}")
throw exception
}
updateTrinoContext()
} catch {
case e: Exception =>
throw KyuubiSQLException(e)
}
if (!rowQueue.isEmpty()) {
drainDetectingEnd(rowQueue, rowBuffer, Integer.MAX_VALUE, END_TOKEN)
}
result ++= rowBuffer.asScala
val finalStatus = trino.finalStatusInfo()
if (finalStatus.getError() != null) {
throw KyuubiSQLException(
s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}")
}
updateTrinoContext()
result
}
def updateTrinoContext(): Unit = {
val session = trinoContext.getClientSession
val session = trinoContext.clientSession.get
var builder = ClientSession.builder(session)
// update catalog and schema

View File

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.engine.trino.operation
import java.util.concurrent.RejectedExecutionException
import org.apache.kyuubi.KyuubiSQLException
import org.apache.kyuubi.Logging
import org.apache.kyuubi.engine.trino.TrinoStatement
import org.apache.kyuubi.operation.ArrayFetchIterator
import org.apache.kyuubi.operation.IterableFetchIterator
import org.apache.kyuubi.operation.OperationState
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.log.OperationLog
import org.apache.kyuubi.session.Session
class ExecuteStatement(
session: Session,
override val statement: String,
override val shouldRunAsync: Boolean,
incrementalCollect: Boolean)
extends TrinoOperation(OperationType.EXECUTE_STATEMENT, session) with Logging {
private val operationLog: OperationLog = OperationLog.createOperationLog(session, getHandle)
override def getOperationLog: Option[OperationLog] = Option(operationLog)
override protected def beforeRun(): Unit = {
OperationLog.setCurrentOperationLog(operationLog)
setState(OperationState.PENDING)
setHasResultSet(true)
}
override protected def afterRun(): Unit = {
OperationLog.removeCurrentOperationLog()
}
override protected def runInternal(): Unit = {
val trinoStatement = TrinoStatement(trinoContext, session.sessionManager.getConf, statement)
trino = trinoStatement.getTrinoClient
if (shouldRunAsync) {
val asyncOperation = new Runnable {
override def run(): Unit = {
OperationLog.setCurrentOperationLog(operationLog)
executeStatement(trinoStatement)
}
}
try {
val trinoSessionManager = session.sessionManager
val backgroundHandle = trinoSessionManager.submitBackgroundOperation(asyncOperation)
setBackgroundHandle(backgroundHandle)
} catch {
case rejected: RejectedExecutionException =>
setState(OperationState.ERROR)
val ke =
KyuubiSQLException("Error submitting query in background, query rejected", rejected)
setOperationException(ke)
throw ke
}
} else {
executeStatement(trinoStatement)
}
}
private def executeStatement(trinoStatement: TrinoStatement): Unit = {
setState(OperationState.RUNNING)
try {
schema = trinoStatement.getColumns
val resultSet = trinoStatement.execute()
iter =
if (incrementalCollect) {
info("Execute in incremental collect mode")
new IterableFetchIterator(resultSet)
} else {
info("Execute in full collect mode")
new ArrayFetchIterator(resultSet.toArray)
}
setState(OperationState.FINISHED)
} catch {
onError(cancel = true)
}
}
}

View File

@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.engine.trino.operation
import java.io.IOException
import io.trino.client.Column
import io.trino.client.StatementClient
import org.apache.hive.service.rpc.thrift.TRowSet
import org.apache.hive.service.rpc.thrift.TTableSchema
import org.apache.kyuubi.KyuubiSQLException
import org.apache.kyuubi.Utils
import org.apache.kyuubi.engine.trino.TrinoContext
import org.apache.kyuubi.engine.trino.schema.RowSet
import org.apache.kyuubi.engine.trino.schema.SchemaHelper
import org.apache.kyuubi.engine.trino.session.TrinoSessionImpl
import org.apache.kyuubi.operation.AbstractOperation
import org.apache.kyuubi.operation.FetchIterator
import org.apache.kyuubi.operation.FetchOrientation.FETCH_FIRST
import org.apache.kyuubi.operation.FetchOrientation.FETCH_NEXT
import org.apache.kyuubi.operation.FetchOrientation.FETCH_PRIOR
import org.apache.kyuubi.operation.FetchOrientation.FetchOrientation
import org.apache.kyuubi.operation.OperationState
import org.apache.kyuubi.operation.OperationState.OperationState
import org.apache.kyuubi.operation.OperationType.OperationType
import org.apache.kyuubi.operation.log.OperationLog
import org.apache.kyuubi.session.Session
abstract class TrinoOperation(opType: OperationType, session: Session)
extends AbstractOperation(opType, session) {
protected val trinoContext: TrinoContext = session.asInstanceOf[TrinoSessionImpl].trinoContext
protected var trino: StatementClient = _
protected var schema: List[Column] = _
protected var iter: FetchIterator[List[Any]] = _
override def getResultSetSchema: TTableSchema = SchemaHelper.toTTableSchema(schema)
override def getNextRowSet(order: FetchOrientation, rowSetSize: Int): TRowSet = {
validateDefaultFetchOrientation(order)
assertState(OperationState.FINISHED)
setHasResultSet(true)
order match {
case FETCH_NEXT => iter.fetchNext()
case FETCH_PRIOR => iter.fetchPrior(rowSetSize);
case FETCH_FIRST => iter.fetchAbsolute(0);
}
val taken = iter.take(rowSetSize)
val resultRowSet = RowSet.toTRowSet(taken.toList, schema, getProtocolVersion)
resultRowSet.setStartRowOffset(iter.getPosition)
resultRowSet
}
override protected def beforeRun(): Unit = {
setHasResultSet(true)
setState(OperationState.RUNNING)
}
override protected def afterRun(): Unit = {
state.synchronized {
if (!isTerminalState(state)) {
setState(OperationState.FINISHED)
}
}
OperationLog.removeCurrentOperationLog()
}
override def cancel(): Unit = {
cleanup(OperationState.CANCELED)
}
protected def cleanup(targetState: OperationState): Unit = state.synchronized {
if (!isTerminalState(state)) {
setState(targetState)
Option(getBackgroundHandle).foreach(_.cancel(true))
}
}
override def close(): Unit = {
cleanup(OperationState.CLOSED)
try {
trino.close()
getOperationLog.foreach(_.close())
} catch {
case e: IOException =>
error(e.getMessage, e)
}
}
override def shouldRunAsync: Boolean = false
protected def onError(cancel: Boolean = false): PartialFunction[Throwable, Unit] = {
// We should use Throwable instead of Exception since `java.lang.NoClassDefFoundError`
// could be thrown.
case e: Throwable =>
if (cancel && trino.isRunning) trino.cancelLeafStage()
state.synchronized {
val errMsg = Utils.stringifyException(e)
if (state == OperationState.TIMEOUT) {
val ke = KyuubiSQLException(s"Timeout operating $opType: $errMsg")
setOperationException(ke)
throw ke
} else if (isTerminalState(state)) {
setOperationException(KyuubiSQLException(errMsg))
warn(s"Ignore exception in terminal state with $statementId: $errMsg")
} else {
error(s"Error operating $opType: $errMsg", e)
val ke = KyuubiSQLException(s"Error operating $opType: $errMsg", e)
setOperationException(ke)
setState(OperationState.ERROR)
throw ke
}
}
}
}

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.engine.trino.operation
import java.util
import org.apache.kyuubi.config.KyuubiConf.OPERATION_INCREMENTAL_COLLECT
import org.apache.kyuubi.operation.Operation
import org.apache.kyuubi.operation.OperationManager
import org.apache.kyuubi.session.Session
class TrinoOperationManager extends OperationManager("TrinoOperationManager") {
def newExecuteStatementOperation(
session: Session,
statement: String,
confOverlay: Map[String, String],
runAsync: Boolean,
queryTimeout: Long): Operation = {
val incrementalCollect = session.sessionManager.getConf.get(OPERATION_INCREMENTAL_COLLECT)
val operation = new ExecuteStatement(session, statement, runAsync, incrementalCollect)
addOperation(operation)
}
override def newGetTypeInfoOperation(session: Session): Operation = null
override def newGetCatalogsOperation(session: Session): Operation = null
override def newGetSchemasOperation(
session: Session,
catalog: String,
schema: String): Operation = null
override def newGetTablesOperation(
session: Session,
catalogName: String,
schemaName: String,
tableName: String,
tableTypes: util.List[String]): Operation = null
override def newGetTableTypesOperation(session: Session): Operation = null
override def newGetColumnsOperation(
session: Session,
catalogName: String,
schemaName: String,
tableName: String,
columnName: String): Operation = null
override def newGetFunctionsOperation(
session: Session,
catalogName: String,
schemaName: String,
functionName: String): Operation = null
}

View File

@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.engine.trino.session
import java.net.URI
import java.time.ZoneId
import java.util.Collections
import java.util.Locale
import java.util.Optional
import java.util.concurrent.TimeUnit
import io.airlift.units.Duration
import io.trino.client.ClientSession
import okhttp3.OkHttpClient
import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.kyuubi.KyuubiSQLException
import org.apache.kyuubi.Utils.currentUser
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.engine.trino.TrinoConf
import org.apache.kyuubi.engine.trino.TrinoContext
import org.apache.kyuubi.session.AbstractSession
import org.apache.kyuubi.session.SessionHandle
import org.apache.kyuubi.session.SessionManager
class TrinoSessionImpl(
protocol: TProtocolVersion,
user: String,
password: String,
ipAddress: String,
conf: Map[String, String],
sessionManager: SessionManager)
extends AbstractSession(protocol, user, password, ipAddress, conf, sessionManager) {
var trinoContext: TrinoContext = _
private var clientSession: ClientSession = _
override val handle: SessionHandle = SessionHandle(protocol)
override def open(): Unit = {
normalizedConf.foreach {
case ("use:database", database) => clientSession = createClientSession(database)
case _ => // do nothing
}
val httpClient = new OkHttpClient.Builder().build()
if (clientSession == null) {
clientSession = createClientSession()
}
trinoContext = TrinoContext(httpClient, clientSession)
super.open()
}
private def createClientSession(schema: String = null): ClientSession = {
val sessionConf = sessionManager.getConf
val connectionUrl = sessionConf.get(KyuubiConf.ENGINE_TRINO_CONNECTION_URL).getOrElse(
throw KyuubiSQLException("Trino server url can not be null!"))
val catalog = sessionConf.get(KyuubiConf.ENGINE_TRINO_CONNECTION_CATALOG).getOrElse(
throw KyuubiSQLException("Trino default catalog can not be null!"))
val user = sessionConf.getOption("kyuubi.trino.user").getOrElse(currentUser)
val clientRequestTimeout = sessionConf.get(TrinoConf.CLIENT_REQUEST_TIMEOUT)
new ClientSession(
URI.create(connectionUrl),
user,
Optional.empty(),
"kyuubi",
Optional.empty(),
Collections.emptySet(),
null,
catalog,
schema,
null,
ZoneId.systemDefault(),
Locale.getDefault,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap(),
null,
new Duration(clientRequestTimeout, TimeUnit.MILLISECONDS),
true)
}
}

View File

@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.engine.trino.session
import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.kyuubi.KyuubiSQLException
import org.apache.kyuubi.Utils
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.config.KyuubiConf.ENGINE_OPERATION_LOG_DIR_ROOT
import org.apache.kyuubi.config.KyuubiConf.ENGINE_SHARE_LEVEL
import org.apache.kyuubi.engine.ShareLevel
import org.apache.kyuubi.engine.trino.TrinoSqlEngine
import org.apache.kyuubi.engine.trino.operation.TrinoOperationManager
import org.apache.kyuubi.session.SessionHandle
import org.apache.kyuubi.session.SessionManager
class TrinoSessionManager
extends SessionManager("TrinoSessionManager") {
val operationManager = new TrinoOperationManager()
override def initialize(conf: KyuubiConf): Unit = {
val absPath = Utils.getAbsolutePathFromWork(conf.get(ENGINE_OPERATION_LOG_DIR_ROOT))
_operationLogRoot = Some(absPath.toAbsolutePath.toString)
super.initialize(conf)
}
override def openSession(
protocol: TProtocolVersion,
user: String,
password: String,
ipAddress: String,
conf: Map[String, String]): SessionHandle = {
info(s"Opening session for $user@$ipAddress")
val sessionImpl =
new TrinoSessionImpl(protocol, user, password, ipAddress, conf, this)
try {
val handle = sessionImpl.handle
sessionImpl.open()
setSession(handle, sessionImpl)
info(s"$user's trino session with $handle is opened, current opening sessions" +
s" $getOpenSessionCount")
handle
} catch {
case e: Exception =>
sessionImpl.close()
throw KyuubiSQLException(e)
}
}
override def closeSession(sessionHandle: SessionHandle): Unit = {
super.closeSession(sessionHandle)
if (conf.get(ENGINE_SHARE_LEVEL) == ShareLevel.CONNECTION.toString) {
info("Session stopped due to shared level is Connection.")
stopSession()
}
}
private def stopSession(): Unit = {
TrinoSqlEngine.currentEngine.foreach(_.stop())
}
override protected def isServer: Boolean = false
}

View File

@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.engine.trino
import org.apache.kyuubi.KyuubiFunSuite
import org.apache.kyuubi.config.KyuubiConf
trait WithTrinoEngine extends KyuubiFunSuite with WithTrinoContainerServer {
protected var engine: TrinoSqlEngine = _
protected var connectionUrl: String = _
override val kyuubiConf: KyuubiConf = TrinoSqlEngine.kyuubiConf
def withKyuubiConf: Map[String, String]
override def beforeAll(): Unit = {
withContainers { trinoContainer =>
val containerConnectionUrl = trinoContainer.jdbcUrl.replace("jdbc:trino", "http")
startTrinoEngine(containerConnectionUrl)
super.beforeAll()
}
}
def startTrinoEngine(containerConnectionUrl: String): Unit = {
kyuubiConf.set(KyuubiConf.ENGINE_TRINO_CONNECTION_URL, containerConnectionUrl)
withKyuubiConf.foreach { case (k, v) =>
System.setProperty(k, v)
kyuubiConf.set(k, v)
}
TrinoSqlEngine.startEngine()
engine = TrinoSqlEngine.currentEngine.get
connectionUrl = engine.frontendServices.head.connectionUrl
}
override def afterAll(): Unit = {
super.afterAll()
stopTrinoEngine()
}
def stopTrinoEngine(): Unit = {
if (engine != null) {
engine.stop()
engine = null
}
}
protected def getJdbcUrl: String = s"jdbc:hive2://$connectionUrl/$schema;"
}

View File

@ -0,0 +1,216 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.engine.trino.operation
import scala.collection.JavaConverters._
import org.apache.hive.service.rpc.thrift.TCancelOperationReq
import org.apache.hive.service.rpc.thrift.TCloseOperationReq
import org.apache.hive.service.rpc.thrift.TCloseSessionReq
import org.apache.hive.service.rpc.thrift.TExecuteStatementReq
import org.apache.hive.service.rpc.thrift.TFetchOrientation
import org.apache.hive.service.rpc.thrift.TFetchResultsReq
import org.apache.hive.service.rpc.thrift.TGetOperationStatusReq
import org.apache.hive.service.rpc.thrift.TOpenSessionReq
import org.apache.hive.service.rpc.thrift.TOperationState
import org.apache.hive.service.rpc.thrift.TStatusCode
import org.apache.kyuubi.config.KyuubiConf.ENGINE_TRINO_CONNECTION_CATALOG
import org.apache.kyuubi.engine.trino.WithTrinoEngine
import org.apache.kyuubi.operation.HiveJDBCTestHelper
class TrinoOperationSuite extends WithTrinoEngine with HiveJDBCTestHelper {
override def withKyuubiConf: Map[String, String] = Map(
ENGINE_TRINO_CONNECTION_CATALOG.key -> "memory")
// use default schema, do not set to 'default', since withSessionHandle strip suffix '/;'
override protected val schema = ""
override protected def jdbcUrl: String = getJdbcUrl
test("execute statement - select decimal") {
withJdbcStatement() { statement =>
val resultSet = statement.executeQuery("SELECT DECIMAL '1.2' as col1, DECIMAL '1.23' AS col2")
assert(resultSet.next())
assert(resultSet.getBigDecimal("col1") === new java.math.BigDecimal("1.2"))
assert(resultSet.getBigDecimal("col2") === new java.math.BigDecimal("1.23"))
val metaData = resultSet.getMetaData
assert(metaData.getColumnType(1) === java.sql.Types.DECIMAL)
assert(metaData.getColumnType(2) === java.sql.Types.DECIMAL)
assert(metaData.getPrecision(1) == 2)
assert(metaData.getPrecision(2) == 3)
assert(metaData.getScale(1) == 1)
assert(metaData.getScale(2) == 2)
}
}
test("test fetch orientation") {
val sql = "SELECT id FROM (VALUES 0, 1) as t(id)"
withSessionHandle { (client, handle) =>
val req = new TExecuteStatementReq()
req.setSessionHandle(handle)
req.setStatement(sql)
val tExecuteStatementResp = client.ExecuteStatement(req)
val opHandle = tExecuteStatementResp.getOperationHandle
waitForOperationToComplete(client, opHandle)
// fetch next from before first row
val tFetchResultsReq1 = new TFetchResultsReq(opHandle, TFetchOrientation.FETCH_NEXT, 1)
val tFetchResultsResp1 = client.FetchResults(tFetchResultsReq1)
assert(tFetchResultsResp1.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
val idSeq1 = tFetchResultsResp1.getResults.getColumns.get(0).getI32Val.getValues.asScala.toSeq
assertResult(Seq(0L))(idSeq1)
// fetch next from first row
val tFetchResultsReq2 = new TFetchResultsReq(opHandle, TFetchOrientation.FETCH_NEXT, 1)
val tFetchResultsResp2 = client.FetchResults(tFetchResultsReq2)
assert(tFetchResultsResp2.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
val idSeq2 = tFetchResultsResp2.getResults.getColumns.get(0).getI32Val.getValues.asScala.toSeq
assertResult(Seq(1L))(idSeq2)
// fetch prior from second row, expected got first row
val tFetchResultsReq3 = new TFetchResultsReq(opHandle, TFetchOrientation.FETCH_PRIOR, 1)
val tFetchResultsResp3 = client.FetchResults(tFetchResultsReq3)
assert(tFetchResultsResp3.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
val idSeq3 = tFetchResultsResp3.getResults.getColumns.get(0).getI32Val.getValues.asScala.toSeq
assertResult(Seq(0L))(idSeq3)
// fetch first
val tFetchResultsReq4 = new TFetchResultsReq(opHandle, TFetchOrientation.FETCH_FIRST, 3)
val tFetchResultsResp4 = client.FetchResults(tFetchResultsReq4)
assert(tFetchResultsResp4.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
val idSeq4 = tFetchResultsResp4.getResults.getColumns.get(0).getI32Val.getValues.asScala.toSeq
assertResult(Seq(0L, 1L))(idSeq4)
}
}
test("get operation status") {
val sql = "select date '2011-11-11' - interval '1' day"
withSessionHandle { (client, handle) =>
val req = new TExecuteStatementReq()
req.setSessionHandle(handle)
req.setStatement(sql)
val tExecuteStatementResp = client.ExecuteStatement(req)
val opHandle = tExecuteStatementResp.getOperationHandle
val tGetOperationStatusReq = new TGetOperationStatusReq()
tGetOperationStatusReq.setOperationHandle(opHandle)
val resp = client.GetOperationStatus(tGetOperationStatusReq)
val status = resp.getStatus
assert(status.getStatusCode === TStatusCode.SUCCESS_STATUS)
assert(resp.getOperationState === TOperationState.FINISHED_STATE)
assert(resp.isHasResultSet)
}
}
test("basic open | execute | close") {
withThriftClient { client =>
val req = new TOpenSessionReq()
req.setUsername("hongdd")
req.setPassword("anonymous")
val tOpenSessionResp = client.OpenSession(req)
val tExecuteStatementReq = new TExecuteStatementReq()
tExecuteStatementReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
tExecuteStatementReq.setRunAsync(true)
tExecuteStatementReq.setStatement("show session")
val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq)
val operationHandle = tExecuteStatementResp.getOperationHandle
waitForOperationToComplete(client, operationHandle)
val tFetchResultsReq = new TFetchResultsReq()
tFetchResultsReq.setOperationHandle(operationHandle)
tFetchResultsReq.setFetchType(1)
tFetchResultsReq.setMaxRows(1000)
val tFetchResultsResp = client.FetchResults(tFetchResultsReq)
val logs = tFetchResultsResp.getResults.getColumns.get(0).getStringVal.getValues.asScala
assert(logs.exists(_.contains(classOf[ExecuteStatement].getCanonicalName)))
tFetchResultsReq.setFetchType(0)
val tFetchResultsResp1 = client.FetchResults(tFetchResultsReq)
val rs = tFetchResultsResp1.getResults.getColumns.get(0).getStringVal.getValues.asScala
assert(rs.contains("aggregation_operator_unspill_memory_limit"))
val tCloseSessionReq = new TCloseSessionReq()
tCloseSessionReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
val tCloseSessionResp = client.CloseSession(tCloseSessionReq)
assert(tCloseSessionResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
}
}
test("not allow to operate closed session or operation") {
withThriftClient { client =>
val req = new TOpenSessionReq()
req.setUsername("hongdd")
req.setPassword("anonymous")
val tOpenSessionResp = client.OpenSession(req)
val tExecuteStatementReq = new TExecuteStatementReq()
tExecuteStatementReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
tExecuteStatementReq.setStatement("show session")
val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq)
val tCloseOperationReq = new TCloseOperationReq(tExecuteStatementResp.getOperationHandle)
val tCloseOperationResp = client.CloseOperation(tCloseOperationReq)
assert(tCloseOperationResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
val tFetchResultsReq = new TFetchResultsReq()
tFetchResultsReq.setOperationHandle(tExecuteStatementResp.getOperationHandle)
tFetchResultsReq.setFetchType(0)
tFetchResultsReq.setMaxRows(1000)
val tFetchResultsResp = client.FetchResults(tFetchResultsReq)
assert(tFetchResultsResp.getStatus.getStatusCode === TStatusCode.ERROR_STATUS)
assert(tFetchResultsResp.getStatus.getErrorMessage startsWith "Invalid OperationHandle" +
" [type=EXECUTE_STATEMENT, identifier:")
val tCloseSessionReq = new TCloseSessionReq()
tCloseSessionReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
val tCloseSessionResp = client.CloseSession(tCloseSessionReq)
assert(tCloseSessionResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
val tExecuteStatementResp1 = client.ExecuteStatement(tExecuteStatementReq)
val status = tExecuteStatementResp1.getStatus
assert(status.getStatusCode === TStatusCode.ERROR_STATUS)
assert(status.getErrorMessage startsWith s"Invalid SessionHandle [")
}
}
test("cancel operation") {
withThriftClient { client =>
val req = new TOpenSessionReq()
req.setUsername("hongdd")
req.setPassword("anonymous")
val tOpenSessionResp = client.OpenSession(req)
val tExecuteStatementReq = new TExecuteStatementReq()
tExecuteStatementReq.setSessionHandle(tOpenSessionResp.getSessionHandle)
tExecuteStatementReq.setStatement("show session")
val tExecuteStatementResp = client.ExecuteStatement(tExecuteStatementReq)
val tCancelOperationReq = new TCancelOperationReq(tExecuteStatementResp.getOperationHandle)
val tCancelOperationResp = client.CancelOperation(tCancelOperationReq)
assert(tCancelOperationResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
val tFetchResultsReq = new TFetchResultsReq()
tFetchResultsReq.setOperationHandle(tExecuteStatementResp.getOperationHandle)
tFetchResultsReq.setFetchType(0)
tFetchResultsReq.setMaxRows(1000)
val tFetchResultsResp = client.FetchResults(tFetchResultsReq)
assert(tFetchResultsResp.getStatus.getStatusCode === TStatusCode.SUCCESS_STATUS)
}
}
}

View File

@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.engine.trino.session
import org.apache.kyuubi.config.KyuubiConf.ENGINE_SHARE_LEVEL
import org.apache.kyuubi.config.KyuubiConf.ENGINE_TRINO_CONNECTION_CATALOG
import org.apache.kyuubi.engine.trino.WithTrinoEngine
import org.apache.kyuubi.operation.HiveJDBCTestHelper
class SessionSuite extends WithTrinoEngine with HiveJDBCTestHelper {
override def withKyuubiConf: Map[String, String] = Map(
ENGINE_TRINO_CONNECTION_CATALOG.key -> "memory",
ENGINE_SHARE_LEVEL.key -> "SERVER")
override protected val schema = "default"
override protected def jdbcUrl: String = getJdbcUrl
test("test session") {
withJdbcStatement() { statement =>
statement.executeQuery("create or replace view temp_view as select 1 as id")
val resultSet = statement.executeQuery("select * from temp_view")
assert(resultSet.next())
}
}
}

View File

@ -203,6 +203,7 @@ object Utils extends Logging {
// Hooks need to be invoked before the SparkContext stopped shall use a higher priority.
val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50
val FLINK_ENGINE_SHUTDOWN_PRIORITY = 50
val TRINO_ENGINE_SHUTDOWN_PRIORITY = 50
/**
* Add some operations that you want into ShutdownHook