[KYUUBI #3934] Compatible with Trino REST DTO

### _Why are the changes needed?_

close #3934

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before making a pull request

Closes #4182 from yehere/kyuubi-3934.

Closes #3934

ced64e2c [yehere] [KYUUBI #3934] Add more result types support
3ef85230 [yehere] [KYUUBI #3934] Optimization for code review
69e1f442 [yehere] [KYUUBI #3934] Merge the test class to TrinoContextSuite
4f0a0152 [yehere] [KYUUBI #3934] Merge the class to TrinoContext
7c9473f6 [yehere] [KYUUBI #3934] Format style, with Copyright  Profiles
2023f3ce [yehere] [KYUUBI #3934] Format and add test case
a2243b46 [yehere] [KYUUBI #3934] Compatiable with Trino rest dto

Authored-by: yehere <867171931@qq.com>
Signed-off-by: ulyssesyou <ulyssesyou@apache.org>
This commit is contained in:
yehere 2023-02-07 11:08:48 +08:00 committed by ulyssesyou
parent f62a7ac587
commit 1b92d80678
2 changed files with 310 additions and 19 deletions

View File

@ -18,34 +18,38 @@
package org.apache.kyuubi.server.trino.api
import java.io.UnsupportedEncodingException
import java.net.{URLDecoder, URLEncoder}
import java.net.{URI, URLDecoder, URLEncoder}
import java.util
import javax.ws.rs.core.{HttpHeaders, Response}
import scala.collection.JavaConverters._
import io.trino.client.{ClientStandardTypes, ClientTypeSignature, Column, QueryError, QueryResults, StatementStats, Warning}
import io.trino.client.ProtocolHeaders.TRINO_HEADERS
import io.trino.client.QueryResults
import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, TTypeId}
import org.apache.kyuubi.operation.OperationStatus
/**
* The description and functionality of trino request
* and response's context
*
* @param user Specifies the session user, must be supplied with every query
* @param timeZone The timezone for query processing
* @param user Specifies the session user, must be supplied with every query
* @param timeZone The timezone for query processing
* @param clientCapabilities Exclusive for trino server
* @param source This supplies the name of the software that submitted the query,
* e.g. `trino-jdbc` or `trino-cli` by default
* @param catalog The catalog context for query processing, will be set response
* @param schema The schema context for query processing
* @param language The language to use when processing the query and formatting results,
* formatted as a Java Locale string, e.g., en-US for US English
* @param traceToken Trace token for correlating requests across systems
* @param clientInfo Extra information about the client
* @param clientTags Client tags for selecting resource groups. Example: abc,xyz
* @param preparedStatement `preparedStatement` are kv pairs, where the names
* are names of previously prepared SQL statements,
* and the values are keys that identify the
* executable form of the named prepared statements
* @param source This supplies the name of the software that submitted the query,
* e.g. `trino-jdbc` or `trino-cli` by default
* @param catalog The catalog context for query processing, will be set response
* @param schema The schema context for query processing
* @param language The language to use when processing the query and formatting results,
* formatted as a Java Locale string, e.g., en-US for US English
* @param traceToken Trace token for correlating requests across systems
* @param clientInfo Extra information about the client
* @param clientTags Client tags for selecting resource groups. Example: abc,xyz
* @param preparedStatement `preparedStatement` are kv pairs, where the names
* are names of previously prepared SQL statements,
* and the values are keys that identify the
* executable form of the named prepared statements
*/
case class TrinoContext(
user: String,
@ -63,6 +67,11 @@ case class TrinoContext(
object TrinoContext {
private val defaultWarning: util.List[Warning] = new util.ArrayList[Warning]()
private val GENERIC_INTERNAL_ERROR_CODE = 65536
private val GENERIC_INTERNAL_ERROR_NAME = "GENERIC_INTERNAL_ERROR_NAME"
private val GENERIC_INTERNAL_ERROR_TYPE = "INTERNAL_ERROR"
def apply(headers: HttpHeaders): TrinoContext = {
apply(headers.getRequestHeaders.asScala.toMap.map {
case (k, v) => (k, v.asScala.toList)
@ -166,4 +175,196 @@ object TrinoContext {
throw new AssertionError(e)
}
/**
 * Assembles a Trino [[QueryResults]] DTO from Kyuubi operation output.
 *
 * @param queryId      identifier reported back to the client
 * @param nextUri      URI passed as both the third and the fourth constructor argument
 *                     NOTE(review): in the Trino client these positions appear to be
 *                     `partialCancelUri` and `nextUri` - confirm sharing one URI is intended
 * @param queryHtmlUri URI passed as the second constructor argument (info URI)
 * @param queryStatus  operation status; supplies the state name and the optional error
 * @param columns      result-set metadata, converted with `convertTColumn`; `None` yields `null`
 * @param data         fetched rows, converted with `convertTRowSet`; `None` yields `null`
 * @return the populated `QueryResults`; timings are reported as zero, the update type as
 *         `null` and the update count as `0`
 */
def createQueryResults(
    queryId: String,
    nextUri: URI,
    queryHtmlUri: URI,
    queryStatus: OperationStatus,
    columns: Option[TGetResultSetMetadataResp] = None,
    data: Option[TRowSet] = None): QueryResults = {
  // Explicit static types keep constructor overload resolution identical to a plain List.
  val resultColumns: util.List[Column] = columns.map(convertTColumn).orNull
  val resultRows: util.List[util.List[Object]] = data.map(convertTRowSet).orNull

  val stats = StatementStats.builder
    .setState(queryStatus.state.name())
    .setQueued(false)
    .setElapsedTimeMillis(0)
    .setQueuedTimeMillis(0)
    .build()
  val error = toQueryError(queryStatus)

  new QueryResults(
    queryId,
    queryHtmlUri,
    nextUri,
    nextUri,
    resultColumns,
    resultRows,
    stats,
    error,
    defaultWarning,
    null,
    0L)
}
/**
 * Translates Thrift result-set metadata into Trino client [[Column]] descriptors.
 *
 * Only the first type entry of each column descriptor is inspected, and it is read as a
 * primitive entry. Any `TTypeId` without an explicit mapping below is reported as `varchar`.
 *
 * @param columns the Thrift metadata response whose schema is converted
 * @return one `Column` per schema column, in schema order; each type signature carries the
 *         bare type name with no type arguments
 */
def convertTColumn(columns: TGetResultSetMetadataResp): util.List[Column] = {
  columns.getSchema.getColumns.asScala.map { c =>
    val tp = c.getTypeDesc.getTypes.get(0).getPrimitiveEntry.getType match {
      case TTypeId.BOOLEAN_TYPE => ClientStandardTypes.BOOLEAN
      case TTypeId.TINYINT_TYPE => ClientStandardTypes.TINYINT
      case TTypeId.SMALLINT_TYPE => ClientStandardTypes.SMALLINT
      case TTypeId.INT_TYPE => ClientStandardTypes.INTEGER
      case TTypeId.BIGINT_TYPE => ClientStandardTypes.BIGINT
      // NOTE(review): FLOAT is reported as DOUBLE rather than REAL - presumably because the
      // values travel in the Thrift double column; confirm this is intended.
      case TTypeId.FLOAT_TYPE => ClientStandardTypes.DOUBLE
      case TTypeId.DOUBLE_TYPE => ClientStandardTypes.DOUBLE
      case TTypeId.STRING_TYPE => ClientStandardTypes.VARCHAR
      case TTypeId.TIMESTAMP_TYPE => ClientStandardTypes.TIMESTAMP
      case TTypeId.BINARY_TYPE => ClientStandardTypes.VARBINARY
      case TTypeId.DECIMAL_TYPE => ClientStandardTypes.DECIMAL
      case TTypeId.DATE_TYPE => ClientStandardTypes.DATE
      case TTypeId.VARCHAR_TYPE => ClientStandardTypes.VARCHAR
      case TTypeId.CHAR_TYPE => ClientStandardTypes.CHAR
      case TTypeId.INTERVAL_YEAR_MONTH_TYPE => ClientStandardTypes.INTERVAL_YEAR_TO_MONTH
      // A day-time interval is an interval, not a time of day: use the interval type that
      // pairs with INTERVAL_YEAR_TO_MONTH instead of TIME_WITH_TIME_ZONE.
      case TTypeId.INTERVAL_DAY_TIME_TYPE => ClientStandardTypes.INTERVAL_DAY_TO_SECOND
      case TTypeId.TIMESTAMPLOCALTZ_TYPE => ClientStandardTypes.TIMESTAMP_WITH_TIME_ZONE
      case _ => ClientStandardTypes.VARCHAR
    }
    new Column(c.getColumnName, tp, new ClientTypeSignature(tp))
  }.toList.asJava
}
/**
 * Converts a Thrift [[TRowSet]] into the row-major list-of-lists layout used for Trino
 * result data.
 *
 * When the row set carries no columns, its row-oriented payload is mapped directly.
 * Otherwise each column is transposed into the rows: the first column with at least one
 * value determines how many rows are created, and every column appends one cell per row.
 * A cell is `null` when the column's null bitmask has the bit for that row set.
 *
 * @param rowSet the Thrift row set to convert
 * @return one list per row, each holding that row's values in column order
 */
def convertTRowSet(rowSet: TRowSet): util.List[util.List[Object]] = {
  if (rowSet.getColumns == null) {
    // Row-oriented payload.
    // NOTE(review): `getFieldValue` looks like it returns the Thrift value wrapper rather
    // than the unwrapped primitive - confirm this is what clients expect.
    rowSet.getRows.asScala
      .map(row => row.getColVals.asScala.map(_.getFieldValue.asInstanceOf[Object]).asJava)
      .asJava
  } else {
    // ArrayList gives constant-time get(rowIdx); indexed access on a LinkedList made the
    // transposition quadratic in the number of rows.
    val rows = new util.ArrayList[util.List[Object]]()

    // Appends one column's values to the rows, creating the rows on first use and
    // substituting null wherever the column's null bitmask marks the row.
    def appendColumn[T <: Object](values: util.List[T], nullMask: Array[Byte]): Unit = {
      val nulls = util.BitSet.valueOf(nullMask)
      if (rows.isEmpty) {
        (0 until values.size).foreach(_ => rows.add(new util.ArrayList[Object]()))
      }
      var rowIdx = 0
      val it = values.iterator()
      while (it.hasNext) {
        val value = it.next()
        val cell: Object = if (nulls.get(rowIdx)) null else value
        rows.get(rowIdx).add(cell)
        rowIdx += 1
      }
    }

    rowSet.getColumns.asScala.foreach {
      case c if c.isSetBoolVal =>
        appendColumn(c.getBoolVal.getValues, c.getBoolVal.getNulls)
      case c if c.isSetByteVal =>
        appendColumn(c.getByteVal.getValues, c.getByteVal.getNulls)
      case c if c.isSetI16Val =>
        appendColumn(c.getI16Val.getValues, c.getI16Val.getNulls)
      case c if c.isSetI32Val =>
        appendColumn(c.getI32Val.getValues, c.getI32Val.getNulls)
      case c if c.isSetI64Val =>
        appendColumn(c.getI64Val.getValues, c.getI64Val.getNulls)
      case c if c.isSetDoubleVal =>
        appendColumn(c.getDoubleVal.getValues, c.getDoubleVal.getNulls)
      case c if c.isSetBinaryVal =>
        appendColumn(c.getBinaryVal.getValues, c.getBinaryVal.getNulls)
      case c =>
        // Anything else is read as a string column.
        appendColumn(c.getStringVal.getValues, c.getStringVal.getNulls)
    }
    rows
  }
}
/**
 * Builds the Trino [[QueryError]] for an operation status.
 *
 * @param queryStatus the operation status to inspect
 * @return `null` when the status carries no exception; otherwise a `QueryError` holding the
 *         exception message, the operation state name and the generic internal error
 *         code/name/type constants, with the two trailing constructor arguments left `null`
 */
def toQueryError(queryStatus: OperationStatus): QueryError = {
  queryStatus.exception match {
    case Some(cause) =>
      new QueryError(
        cause.getMessage,
        queryStatus.state.name(),
        GENERIC_INTERNAL_ERROR_CODE,
        GENERIC_INTERNAL_ERROR_NAME,
        GENERIC_INTERNAL_ERROR_TYPE,
        null,
        null)
    case None => null
  }
}
}

View File

@ -17,13 +17,24 @@
package org.apache.kyuubi.server.trino.api
import java.net.URI
import java.time.ZoneId
import javax.ws.rs.core.MediaType
import scala.collection.JavaConverters._
import io.trino.client.ProtocolHeaders.TRINO_HEADERS
import org.apache.hive.service.rpc.thrift.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V9
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
import org.apache.kyuubi.KyuubiFunSuite
import org.apache.kyuubi.{KyuubiFunSuite, RestFrontendTestHelper}
import org.apache.kyuubi.events.KyuubiOperationEvent
import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle}
import org.apache.kyuubi.operation.OperationState.{FINISHED, OperationState}
class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper {
class TrinoContextSuite extends KyuubiFunSuite {
import TrinoContext._
test("create trino request context with header") {
@ -67,4 +78,83 @@ class TrinoContextSuite extends KyuubiFunSuite {
assert(actual == expectedTrinoContext)
}
// Runs `select 1` through the backend and checks that the converted Trino results report an
// integer first column whose last row holds the value 1.
test("test convert") {
  val handle = getOpHandle("select 1")
  val handleId = handle.identifier.toString
  checkOpState(handleId, FINISHED)

  val metadata = fe.be.getResultSetMetadata(handle)
  val fetched = fe.be.fetchResults(handle, FetchOrientation.FETCH_NEXT, 1000, false)
  val opStatus = fe.be.getOperationStatus(handle)

  // The same placeholder URI is used for both URI arguments.
  val placeholderUri = new URI("sfdsfsdfdsf")
  val results = TrinoContext.createQueryResults(
    "/xdfd/xdf",
    placeholderUri,
    placeholderUri,
    opStatus,
    Option(metadata),
    Option(fetched))

  print(results.toString)

  val firstColumnType = results.getColumns.get(0).getType
  assert(firstColumnType.equals("integer"))
  val lastRow = results.getData.asScala.last
  assert(lastRow.get(0) == 1)
}
// Creates and populates a table mixing primitive and complex column types, then checks that
// a full select converts to Trino results with an integer first column and a non-null value.
test("test convert from table") {
  initSql("CREATE DATABASE IF NOT EXISTS INIT_DB")
  initSql(
    "CREATE TABLE IF NOT EXISTS INIT_DB.test(a int, b double, c String," +
      "d BOOLEAN,e DATE,f TIMESTAMP,g ARRAY<String>,h DECIMAL," +
      "i MAP<String,String>) USING PARQUET;")
  initSql(
    "INSERT INTO INIT_DB.test VALUES (1,2.2,'3',true,current_date()," +
      "current_timestamp(),array('1','2'),2.0, map('m','p') )")

  val handle = getOpHandle("SELECT * FROM INIT_DB.test")
  val handleId = handle.identifier.toString
  checkOpState(handleId, FINISHED)

  val metadata = fe.be.getResultSetMetadata(handle)
  val fetched = fe.be.fetchResults(handle, FetchOrientation.FETCH_NEXT, 1000, false)
  val opStatus = fe.be.getOperationStatus(handle)

  // The same placeholder URI is used for both URI arguments.
  val placeholderUri = new URI("sfdsfsdfdsf")
  val results = TrinoContext.createQueryResults(
    "/xdfd/xdf",
    placeholderUri,
    placeholderUri,
    opStatus,
    Option(metadata),
    Option(fetched))

  print(results.toString)

  val firstColumnType = results.getColumns.get(0).getType
  assert(firstColumnType.equals("integer"))
  val lastRow = results.getData.asScala.last
  assert(lastRow.get(0) != null)
}
/**
 * Submits `statement` via `getOpHandle` and returns the resulting operation handle's
 * identifier rendered as a string.
 *
 * @param statement the statement to run; defaults to `show tables`
 */
def getOpHandleStr(statement: String = "show tables"): String = {
  val handle = getOpHandle(statement)
  handle.identifier.toString
}
/**
 * Opens a new backend session and starts an operation in it.
 *
 * @param statement the statement to execute synchronously (timeout argument 30000); when it
 *                  is empty, a get-catalogs operation is issued instead
 * @return the handle of the operation that was started
 */
def getOpHandle(statement: String = "show tables"): OperationHandle = {
  val session = fe.be.openSession(
    HIVE_CLI_SERVICE_PROTOCOL_V9,
    "admin",
    "123456",
    "localhost",
    Map("testConfig" -> "testValue"))

  if (statement.isEmpty) {
    fe.be.getCatalogs(session)
  } else {
    fe.be.executeStatement(session, statement, Map.empty, runAsync = false, 30000)
  }
}
/**
 * Polls the operation event REST endpoint, retrying for up to 30 seconds, until it answers
 * with HTTP 200 and an event whose state equals the expected state's name.
 *
 * @param opHandleStr operation handle identifier used in the endpoint path
 * @param state       the operation state the event is expected to report
 */
private def checkOpState(opHandleStr: String, state: OperationState): Unit = {
  val eventPath = s"api/v1/operations/$opHandleStr/event"
  eventually(Timeout(30.seconds)) {
    val request = webTarget.path(eventPath).request(MediaType.APPLICATION_JSON_TYPE)
    val reply = request.get()
    assert(reply.getStatus === 200)
    val event = reply.readEntity(classOf[KyuubiOperationEvent])
    assert(event.state === state.name())
  }
}
/**
 * Runs a setup statement and waits, via `checkOpState`, until its operation is reported
 * as FINISHED.
 *
 * @param sql the statement to execute
 */
private def initSql(sql: String): Unit = {
  val handleId = getOpHandle(sql).identifier.toString
  checkOpState(handleId, FINISHED)
}
}