From 1ff8cc01ea1f9341b3fcedb6356e75ce1da3c4d4 Mon Sep 17 00:00:00 2001 From: yikf Date: Tue, 13 Sep 2022 19:49:22 +0800 Subject: [PATCH] [KYUUBI #3452] Implement GetInfo for Trino engine ### _Why are the changes needed?_ close https://github.com/apache/incubator-kyuubi/issues/3452, this pr aims to implement `GetInfo` for Trino engine ### _How was this patch tested?_ - [ ] Add some test cases that check the changes thoroughly including negative and positive cases if possible - [ ] Add screenshots for manual tests if appropriate - [ ] [Run test](https://kyuubi.apache.org/docs/latest/develop_tools/testing.html#running-tests) locally before make a pull request Closes #3477 from Yikf/trino-getinfo. Closes #3452 ab696a3d [yikf] Trino get info Authored-by: yikf Signed-off-by: Cheng Pan --- .../trino/session/TrinoSessionImpl.scala | 27 ++++++++-- .../trino/operation/TrinoOperationSuite.scala | 52 ++++++++++++++----- 2 files changed, 64 insertions(+), 15 deletions(-) diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/session/TrinoSessionImpl.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/session/TrinoSessionImpl.scala index c774a966f..ac7c246e8 100644 --- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/session/TrinoSessionImpl.scala +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/session/TrinoSessionImpl.scala @@ -27,13 +27,12 @@ import java.util.concurrent.TimeUnit import io.airlift.units.Duration import io.trino.client.ClientSession import okhttp3.OkHttpClient -import org.apache.hive.service.rpc.thrift.TProtocolVersion +import org.apache.hive.service.rpc.thrift.{TGetInfoType, TGetInfoValue, TProtocolVersion} import org.apache.kyuubi.KyuubiSQLException import org.apache.kyuubi.Utils.currentUser import org.apache.kyuubi.config.{KyuubiConf, KyuubiReservedKeys} -import org.apache.kyuubi.engine.trino.TrinoConf -import org.apache.kyuubi.engine.trino.TrinoContext +import org.apache.kyuubi.engine.trino.{TrinoConf, TrinoContext, TrinoStatement} import org.apache.kyuubi.engine.trino.event.TrinoSessionEvent import org.apache.kyuubi.events.EventBus import org.apache.kyuubi.operation.{Operation, OperationHandle} @@ -109,6 +108,28 @@ class TrinoSessionImpl( super.runOperation(operation) } + override def getInfo(infoType: TGetInfoType): TGetInfoValue = withAcquireRelease() { + infoType match { + case TGetInfoType.CLI_SERVER_NAME | TGetInfoType.CLI_DBMS_NAME => + TGetInfoValue.stringValue("Trino") + case TGetInfoType.CLI_DBMS_VER => TGetInfoValue.stringValue(getTrinoServerVersion) + case TGetInfoType.CLI_ODBC_KEYWORDS => TGetInfoValue.stringValue("Unimplemented") + case TGetInfoType.CLI_MAX_COLUMN_NAME_LEN | + TGetInfoType.CLI_MAX_SCHEMA_NAME_LEN | + TGetInfoType.CLI_MAX_TABLE_NAME_LEN => TGetInfoValue.lenValue(0) + case _ => throw KyuubiSQLException(s"Unrecognized GetInfoType value: $infoType") + } + } + + private def getTrinoServerVersion: String = { + val trinoStatement = + TrinoStatement(trinoContext, sessionManager.getConf, "SELECT version()") + val resultSet = trinoStatement.execute() + + assert(resultSet.hasNext) + resultSet.next().head.toString + } + override def close(): Unit = { sessionEvent.endTime = System.currentTimeMillis() EventBus.post(sessionEvent) diff --git a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala index 7b0117f70..a6f125af5 100644 --- a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala +++ b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala @@ -22,21 +22,12 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.Set import io.trino.client.ClientStandardTypes._ -import org.apache.hive.service.rpc.thrift.TCancelOperationReq -import org.apache.hive.service.rpc.thrift.TCloseOperationReq -import org.apache.hive.service.rpc.thrift.TCloseSessionReq -import org.apache.hive.service.rpc.thrift.TExecuteStatementReq -import org.apache.hive.service.rpc.thrift.TFetchOrientation -import org.apache.hive.service.rpc.thrift.TFetchResultsReq -import org.apache.hive.service.rpc.thrift.TGetOperationStatusReq -import org.apache.hive.service.rpc.thrift.TOpenSessionReq -import org.apache.hive.service.rpc.thrift.TOperationState -import org.apache.hive.service.rpc.thrift.TStatusCode +import org.apache.hive.service.rpc.thrift._ import org.apache.kyuubi.KyuubiSQLException +import org.apache.kyuubi.config.KyuubiConf import org.apache.kyuubi.config.KyuubiConf._ -import org.apache.kyuubi.engine.trino.TrinoQueryTests -import org.apache.kyuubi.engine.trino.WithTrinoEngine +import org.apache.kyuubi.engine.trino.{TrinoQueryTests, TrinoStatement, WithTrinoEngine} import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._ class TrinoOperationSuite extends WithTrinoEngine with TrinoQueryTests { @@ -768,4 +759,41 @@ class TrinoOperationSuite extends WithTrinoEngine with TrinoQueryTests { } } } + + test("[KYUUBI #3452] Implement GetInfo for Trino engine") { + def getTrinoVersion: String = { + var version: String = "Unknown" + withTrinoContainer { trinoContext => + val trinoStatement = TrinoStatement(trinoContext, kyuubiConf, "SELECT version()") + val schema = trinoStatement.getColumns + val resultSet = trinoStatement.execute() + + assert(schema.size === 1) + assert(schema(0).getName === "_col0") + + assert(resultSet.toIterator.hasNext) + version = resultSet.toIterator.next().head.toString + } + version + } + + withSessionConf(Map(KyuubiConf.SERVER_INFO_PROVIDER.key -> "ENGINE"))()() { + withSessionHandle { (client, handle) => + val req = new TGetInfoReq() + req.setSessionHandle(handle) + req.setInfoType(TGetInfoType.CLI_DBMS_NAME) + assert(client.GetInfo(req).getInfoValue.getStringValue === "Trino") + + val req2 = new TGetInfoReq() + req2.setSessionHandle(handle) + req2.setInfoType(TGetInfoType.CLI_DBMS_VER) + assert(client.GetInfo(req2).getInfoValue.getStringValue === getTrinoVersion) + + val req3 = new TGetInfoReq() + req3.setSessionHandle(handle) + req3.setInfoType(TGetInfoType.CLI_MAX_COLUMN_NAME_LEN) + assert(client.GetInfo(req3).getInfoValue.getLenValue === 0) + } + } + } }