diff --git a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala index 8f3131f61..4a0736ddb 100644 --- a/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala +++ b/kyuubi-server/src/main/scala/org/apache/kyuubi/server/trino/api/TrinoContext.scala @@ -18,34 +18,38 @@ package org.apache.kyuubi.server.trino.api import java.io.UnsupportedEncodingException -import java.net.{URLDecoder, URLEncoder} +import java.net.{URI, URLDecoder, URLEncoder} +import java.util import javax.ws.rs.core.{HttpHeaders, Response} import scala.collection.JavaConverters._ +import io.trino.client.{ClientStandardTypes, ClientTypeSignature, Column, QueryError, QueryResults, StatementStats, Warning} import io.trino.client.ProtocolHeaders.TRINO_HEADERS -import io.trino.client.QueryResults +import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, TTypeId} + +import org.apache.kyuubi.operation.OperationStatus /** * The description and functionality of trino request * and response's context * - * @param user Specifies the session user, must be supplied with every query - * @param timeZone The timezone for query processing + * @param user Specifies the session user, must be supplied with every query + * @param timeZone The timezone for query processing * @param clientCapabilities Exclusive for trino server - * @param source This supplies the name of the software that submitted the query, - * e.g. 
`trino-jdbc` or `trino-cli` by default - * @param catalog The catalog context for query processing, will be set response - * @param schema The schema context for query processing - * @param language The language to use when processing the query and formatting results, - * formatted as a Java Locale string, e.g., en-US for US English - * @param traceToken Trace token for correlating requests across systems - * @param clientInfo Extra information about the client - * @param clientTags Client tags for selecting resource groups. Example: abc,xyz - * @param preparedStatement `preparedStatement` are kv pairs, where the names - * are names of previously prepared SQL statements, - * and the values are keys that identify the - * executable form of the named prepared statements + * @param source This supplies the name of the software that submitted the query, + * e.g. `trino-jdbc` or `trino-cli` by default + * @param catalog The catalog context for query processing, will be set response + * @param schema The schema context for query processing + * @param language The language to use when processing the query and formatting results, + * formatted as a Java Locale string, e.g., en-US for US English + * @param traceToken Trace token for correlating requests across systems + * @param clientInfo Extra information about the client + * @param clientTags Client tags for selecting resource groups. 
Example: abc,xyz
+ * @param preparedStatement  `preparedStatement` are kv pairs, where the names
+ *                           are names of previously prepared SQL statements,
+ *                           and the values are keys that identify the
+ *                           executable form of the named prepared statements
  */
 case class TrinoContext(
     user: String,
@@ -63,6 +67,13 @@ case class TrinoContext(
 
 object TrinoContext {
 
+  // Immutable and shared by every response built by createQueryResults.
+  private val defaultWarning: util.List[Warning] = util.Collections.emptyList[Warning]()
+  // Error code, name and type attached to every QueryError built by toQueryError.
+  private val GENERIC_INTERNAL_ERROR_CODE = 65536
+  private val GENERIC_INTERNAL_ERROR_NAME = "GENERIC_INTERNAL_ERROR"
+  private val GENERIC_INTERNAL_ERROR_TYPE = "INTERNAL_ERROR"
+
   def apply(headers: HttpHeaders): TrinoContext = {
     apply(headers.getRequestHeaders.asScala.toMap.map {
       case (k, v) => (k, v.asScala.toList)
@@ -166,4 +177,159 @@ object TrinoContext {
       throw new AssertionError(e)
   }
 
+  /**
+   * Builds a Trino [[QueryResults]] payload from the state of a Kyuubi operation.
+   *
+   * @param queryId      identifier passed through to the results
+   * @param nextUri      URI passed through to the results
+   * @param queryHtmlUri URI passed through to the results
+   * @param queryStatus  operation status supplying the state name and the optional error
+   * @param columns      optional result set metadata, converted with [[convertTColumn]]
+   * @param data         optional result rows, converted with [[convertTRowSet]]
+   * @return the assembled results; columns, data and error are null when absent
+   */
+  def createQueryResults(
+      queryId: String,
+      nextUri: URI,
+      queryHtmlUri: URI,
+      queryStatus: OperationStatus,
+      columns: Option[TGetResultSetMetadataResp] = None,
+      data: Option[TRowSet] = None): QueryResults = {
+    val columnList = columns.map(convertTColumn).orNull
+    val rowList = data.map(convertTRowSet).orNull
+    val stats = StatementStats.builder
+      .setState(queryStatus.state.name())
+      .setQueued(false)
+      .setElapsedTimeMillis(0)
+      .setQueuedTimeMillis(0)
+      .build()
+
+    // NOTE(review): nextUri is passed for both the third and the fourth argument, exactly as
+    // before; confirm against the QueryResults constructor that this is intended.
+    new QueryResults(
+      queryId,
+      queryHtmlUri,
+      nextUri,
+      nextUri,
+      columnList,
+      rowList,
+      stats,
+      toQueryError(queryStatus),
+      defaultWarning,
+      null,
+      0L)
+  }
+
+  /**
+   * Maps the columns of a Thrift result set schema to Trino client [[Column]]s.
+   *
+   * Only the first type entry of each column is inspected; any type id without an explicit
+   * mapping is reported as VARCHAR.
+   *
+   * @param columns result set metadata whose schema is converted
+   * @return one [[Column]] per schema column, in schema order
+   */
+  def convertTColumn(columns: TGetResultSetMetadataResp): util.List[Column] = {
+    columns.getSchema.getColumns.asScala.map { c =>
+      val tp = c.getTypeDesc.getTypes.get(0).getPrimitiveEntry.getType match {
+        case TTypeId.BOOLEAN_TYPE => ClientStandardTypes.BOOLEAN
+        case TTypeId.TINYINT_TYPE => ClientStandardTypes.TINYINT
+        case TTypeId.SMALLINT_TYPE => ClientStandardTypes.SMALLINT
+        case TTypeId.INT_TYPE => ClientStandardTypes.INTEGER
+        case TTypeId.BIGINT_TYPE => ClientStandardTypes.BIGINT
+        case TTypeId.FLOAT_TYPE => ClientStandardTypes.DOUBLE
+        case TTypeId.DOUBLE_TYPE => ClientStandardTypes.DOUBLE
+        case TTypeId.STRING_TYPE => ClientStandardTypes.VARCHAR
+        case TTypeId.TIMESTAMP_TYPE => ClientStandardTypes.TIMESTAMP
+        case TTypeId.BINARY_TYPE => ClientStandardTypes.VARBINARY
+        case TTypeId.DECIMAL_TYPE => ClientStandardTypes.DECIMAL
+        case TTypeId.DATE_TYPE => ClientStandardTypes.DATE
+        case TTypeId.VARCHAR_TYPE => ClientStandardTypes.VARCHAR
+        case TTypeId.CHAR_TYPE => ClientStandardTypes.CHAR
+        case TTypeId.INTERVAL_YEAR_MONTH_TYPE => ClientStandardTypes.INTERVAL_YEAR_TO_MONTH
+        case TTypeId.INTERVAL_DAY_TIME_TYPE => ClientStandardTypes.INTERVAL_DAY_TO_SECOND
+        case TTypeId.TIMESTAMPLOCALTZ_TYPE => ClientStandardTypes.TIMESTAMP_WITH_TIME_ZONE
+        case _ => ClientStandardTypes.VARCHAR
+      }
+      new Column(c.getColumnName, tp, new ClientTypeSignature(tp))
+    }.toList.asJava
+  }
+
+  /**
+   * Converts a Thrift [[TRowSet]] into row-major lists for a Trino response.
+   *
+   * When the row set carries no columns, its rows are converted cell by cell. Otherwise the
+   * columns are transposed into rows, and cells flagged in a column's null bitmap become null.
+   *
+   * @param rowSet the Thrift row set to convert
+   * @return one list per row, holding that row's cells in column order
+   */
+  def convertTRowSet(rowSet: TRowSet): util.List[util.List[Object]] = {
+    if (rowSet.getColumns == null) {
+      // NOTE(review): getFieldValue is forwarded as-is; confirm that it yields the plain cell
+      // value rather than a Thrift wrapper object.
+      rowSet.getRows.asScala
+        .map(row => row.getColVals.asScala.map(_.getFieldValue.asInstanceOf[Object]).asJava)
+        .asJava
+    } else {
+      // ArrayList keeps the per-row get(rowIdx) lookups in appendColumn constant-time.
+      val rows = new util.ArrayList[util.List[Object]]()
+      rowSet.getColumns.asScala.foreach {
+        case c if c.isSetBoolVal =>
+          appendColumn(rows, c.getBoolVal.getValues, c.getBoolVal.getNulls)
+        case c if c.isSetByteVal =>
+          appendColumn(rows, c.getByteVal.getValues, c.getByteVal.getNulls)
+        case c if c.isSetI16Val =>
+          appendColumn(rows, c.getI16Val.getValues, c.getI16Val.getNulls)
+        case c if c.isSetI32Val =>
+          appendColumn(rows, c.getI32Val.getValues, c.getI32Val.getNulls)
+        case c if c.isSetI64Val =>
+          appendColumn(rows, c.getI64Val.getValues, c.getI64Val.getNulls)
+        case c if c.isSetDoubleVal =>
+          appendColumn(rows, c.getDoubleVal.getValues, c.getDoubleVal.getNulls)
+        case c if c.isSetBinaryVal =>
+          appendColumn(rows, c.getBinaryVal.getValues, c.getBinaryVal.getNulls)
+        case c =>
+          appendColumn(rows, c.getStringVal.getValues, c.getStringVal.getNulls)
+      }
+      rows
+    }
+  }
+
+  /**
+   * Appends one column's cells to `rows`, creating the row lists on the first column seen.
+   * A cell whose index is set in the `nulls` bitmap is appended as null.
+   */
+  private def appendColumn(
+      rows: util.List[util.List[Object]],
+      values: util.List[_ <: Object],
+      nulls: Array[Byte]): Unit = {
+    val nullBits = util.BitSet.valueOf(nulls)
+    if (rows.isEmpty) {
+      (1 to values.size).foreach(_ => rows.add(new util.ArrayList[Object]()))
+    }
+    values.asScala.zipWithIndex.foreach { case (value, rowIdx) =>
+      rows.get(rowIdx).add(if (nullBits.get(rowIdx)) null else value)
+    }
+  }
+
+  /**
+   * Derives a Trino [[QueryError]] from an operation status.
+   *
+   * @param queryStatus status whose optional exception and state name are reported
+   * @return the error, or null when the status carries no exception
+   */
+  def toQueryError(queryStatus: OperationStatus): QueryError = {
+    queryStatus.exception.map { e =>
+      new QueryError(
+        e.getMessage,
+        queryStatus.state.name(),
+        GENERIC_INTERNAL_ERROR_CODE,
+        GENERIC_INTERNAL_ERROR_NAME,
+        GENERIC_INTERNAL_ERROR_TYPE,
+        null,
+        null)
+    }.orNull
+  }
+
 }
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
index 67a502288..8d7b2bf2c 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/server/trino/api/TrinoContextSuite.scala
@@ -17,13 +17,24 @@
 
 package org.apache.kyuubi.server.trino.api
 
+import java.net.URI
 import java.time.ZoneId
+import javax.ws.rs.core.MediaType
+
+import scala.collection.JavaConverters._
 
 import io.trino.client.ProtocolHeaders.TRINO_HEADERS
+import org.apache.hive.service.rpc.thrift.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V9
+import
org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
 
-import org.apache.kyuubi.KyuubiFunSuite
+import org.apache.kyuubi.{KyuubiFunSuite, RestFrontendTestHelper}
+import org.apache.kyuubi.events.KyuubiOperationEvent
+import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle}
+import org.apache.kyuubi.operation.OperationState.{FINISHED, OperationState}
+
+class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper {
 
-class TrinoContextSuite extends KyuubiFunSuite {
   import TrinoContext._
 
   test("create trino request context with header") {
@@ -67,4 +78,85 @@ class TrinoContextSuite extends KyuubiFunSuite {
     assert(actual == expectedTrinoContext)
   }
 
+  test("test convert") {
+    val results = runAndConvert("select 1")
+
+    assert(results.getColumns.get(0).getType == "integer")
+    assert(results.getData.asScala.last.get(0) == 1)
+  }
+
+  test("test convert from table") {
+    initSql("CREATE DATABASE IF NOT EXISTS INIT_DB")
+    // ARRAY and MAP are parameterized types in Spark SQL, so the element types are spelled out.
+    initSql(
+      "CREATE TABLE IF NOT EXISTS INIT_DB.test(a int, b double, c String," +
+        "d BOOLEAN,e DATE,f TIMESTAMP,g ARRAY<STRING>,h DECIMAL," +
+        "i MAP<STRING,STRING>) USING PARQUET;")
+    initSql(
+      "INSERT INTO INIT_DB.test VALUES (1,2.2,'3',true,current_date()," +
+        "current_timestamp(),array('1','2'),2.0, map('m','p') )")
+
+    val results = runAndConvert("SELECT * FROM INIT_DB.test")
+
+    assert(results.getColumns.get(0).getType == "integer")
+    assert(results.getData.asScala.last.get(0) != null)
+  }
+
+  /**
+   * Runs `statement`, waits for it to reach FINISHED, and converts its metadata, up to 1000
+   * fetched rows and its status into Trino query results.
+   */
+  private def runAndConvert(statement: String): io.trino.client.QueryResults = {
+    val opHandle = getOpHandle(statement)
+    checkOpState(opHandle.identifier.toString, FINISHED)
+
+    val metadataResp = fe.be.getResultSetMetadata(opHandle)
+    val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
+    val status = fe.be.getOperationStatus(opHandle)
+
+    // The URI and query id are placeholders; the assertions never inspect them.
+    val uri = new URI("sfdsfsdfdsf")
+    TrinoContext
+      .createQueryResults("/xdfd/xdf", uri, uri, status, Option(metadataResp), Option(tRowSet))
+  }
+
+  /** Executes `statement` and returns the string form of the resulting operation identifier. */
+  def getOpHandleStr(statement: String = "show tables"): String = {
+    getOpHandle(statement).identifier.toString
+  }
+
+  /**
+   * Opens a new session and runs `statement` in it with `runAsync = false`; an empty statement
+   * requests the catalogs instead.
+   */
+  def getOpHandle(statement: String = "show tables"): OperationHandle = {
+    val sessionHandle = fe.be.openSession(
+      HIVE_CLI_SERVICE_PROTOCOL_V9,
+      "admin",
+      "123456",
+      "localhost",
+      Map("testConfig" -> "testValue"))
+
+    if (statement.nonEmpty) {
+      fe.be.executeStatement(sessionHandle, statement, Map.empty, runAsync = false, 30000)
+    } else {
+      fe.be.getCatalogs(sessionHandle)
+    }
+  }
+
+  /** Checks the operation event endpoint until it reports `state`, for at most 30 seconds. */
+  private def checkOpState(opHandleStr: String, state: OperationState): Unit = {
+    eventually(Timeout(30.seconds)) {
+      val response = webTarget.path(s"api/v1/operations/$opHandleStr/event")
+        .request(MediaType.APPLICATION_JSON_TYPE).get()
+      assert(response.getStatus === 200)
+      val operationEvent = response.readEntity(classOf[KyuubiOperationEvent])
+      assert(operationEvent.state === state.name())
+    }
+  }
+
+  /** Runs `sql` and waits for the operation to reach FINISHED. */
+  private def initSql(sql: String): Unit = {
+    checkOpState(getOpHandleStr(sql), FINISHED)
+  }
 }