[KYUUBI #3934] Compatiable with Trino rest dto
### _Why are the changes needed?_ close #3934 ### _How was this patch tested?_ - [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible - [ ] Add screenshots for manual tests if appropriate - [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before make a pull request Closes #4182 from yehere/kyuubi-3934. Closes #3934 ced64e2c [yehere] [KYUUBI #3934] Add more result types support 3ef85230 [yehere] [KYUUBI #3934] Optimization for code review 69e1f442 [yehere] [KYUUBI #3934] Merge the test class to TrinoContextSuite 4f0a0152 [yehere] [KYUUBI #3934] Merge the class to TrinoContext 7c9473f6 [yehere] [KYUUBI #3934] Format style, with Copyright Profiles 2023f3ce [yehere] [KYUUBI #3934] Format and add test case a2243b46 [yehere] [KYUUBI #3934] Compatiable with Trino rest dto Authored-by: yehere <867171931@qq.com> Signed-off-by: ulyssesyou <ulyssesyou@apache.org>
This commit is contained in:
parent
f62a7ac587
commit
1b92d80678
@ -18,34 +18,38 @@
|
||||
package org.apache.kyuubi.server.trino.api
|
||||
|
||||
import java.io.UnsupportedEncodingException
|
||||
import java.net.{URLDecoder, URLEncoder}
|
||||
import java.net.{URI, URLDecoder, URLEncoder}
|
||||
import java.util
|
||||
import javax.ws.rs.core.{HttpHeaders, Response}
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import io.trino.client.{ClientStandardTypes, ClientTypeSignature, Column, QueryError, QueryResults, StatementStats, Warning}
|
||||
import io.trino.client.ProtocolHeaders.TRINO_HEADERS
|
||||
import io.trino.client.QueryResults
|
||||
import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, TTypeId}
|
||||
|
||||
import org.apache.kyuubi.operation.OperationStatus
|
||||
|
||||
/**
|
||||
* The description and functionality of trino request
|
||||
* and response's context
|
||||
*
|
||||
* @param user Specifies the session user, must be supplied with every query
|
||||
* @param timeZone The timezone for query processing
|
||||
* @param user Specifies the session user, must be supplied with every query
|
||||
* @param timeZone The timezone for query processing
|
||||
* @param clientCapabilities Exclusive for trino server
|
||||
* @param source This supplies the name of the software that submitted the query,
|
||||
* e.g. `trino-jdbc` or `trino-cli` by default
|
||||
* @param catalog The catalog context for query processing, will be set response
|
||||
* @param schema The schema context for query processing
|
||||
* @param language The language to use when processing the query and formatting results,
|
||||
* formatted as a Java Locale string, e.g., en-US for US English
|
||||
* @param traceToken Trace token for correlating requests across systems
|
||||
* @param clientInfo Extra information about the client
|
||||
* @param clientTags Client tags for selecting resource groups. Example: abc,xyz
|
||||
* @param preparedStatement `preparedStatement` are kv pairs, where the names
|
||||
* are names of previously prepared SQL statements,
|
||||
* and the values are keys that identify the
|
||||
* executable form of the named prepared statements
|
||||
* @param source This supplies the name of the software that submitted the query,
|
||||
* e.g. `trino-jdbc` or `trino-cli` by default
|
||||
* @param catalog The catalog context for query processing, will be set response
|
||||
* @param schema The schema context for query processing
|
||||
* @param language The language to use when processing the query and formatting results,
|
||||
* formatted as a Java Locale string, e.g., en-US for US English
|
||||
* @param traceToken Trace token for correlating requests across systems
|
||||
* @param clientInfo Extra information about the client
|
||||
* @param clientTags Client tags for selecting resource groups. Example: abc,xyz
|
||||
* @param preparedStatement `preparedStatement` are kv pairs, where the names
|
||||
* are names of previously prepared SQL statements,
|
||||
* and the values are keys that identify the
|
||||
* executable form of the named prepared statements
|
||||
*/
|
||||
case class TrinoContext(
|
||||
user: String,
|
||||
@ -63,6 +67,11 @@ case class TrinoContext(
|
||||
|
||||
object TrinoContext {
|
||||
|
||||
private val defaultWarning: util.List[Warning] = new util.ArrayList[Warning]()
|
||||
private val GENERIC_INTERNAL_ERROR_CODE = 65536
|
||||
private val GENERIC_INTERNAL_ERROR_NAME = "GENERIC_INTERNAL_ERROR_NAME"
|
||||
private val GENERIC_INTERNAL_ERROR_TYPE = "INTERNAL_ERROR"
|
||||
|
||||
def apply(headers: HttpHeaders): TrinoContext = {
|
||||
apply(headers.getRequestHeaders.asScala.toMap.map {
|
||||
case (k, v) => (k, v.asScala.toList)
|
||||
@ -166,4 +175,196 @@ object TrinoContext {
|
||||
throw new AssertionError(e)
|
||||
}
|
||||
|
||||
def createQueryResults(
|
||||
queryId: String,
|
||||
nextUri: URI,
|
||||
queryHtmlUri: URI,
|
||||
queryStatus: OperationStatus,
|
||||
columns: Option[TGetResultSetMetadataResp] = None,
|
||||
data: Option[TRowSet] = None): QueryResults = {
|
||||
|
||||
val columnList = columns match {
|
||||
case Some(value) => convertTColumn(value)
|
||||
case None => null
|
||||
}
|
||||
val rowList = data match {
|
||||
case Some(value) => convertTRowSet(value)
|
||||
case None => null
|
||||
}
|
||||
|
||||
new QueryResults(
|
||||
queryId,
|
||||
queryHtmlUri,
|
||||
nextUri,
|
||||
nextUri,
|
||||
columnList,
|
||||
rowList,
|
||||
StatementStats.builder.setState(queryStatus.state.name()).setQueued(false)
|
||||
.setElapsedTimeMillis(0).setQueuedTimeMillis(0).build(),
|
||||
toQueryError(queryStatus),
|
||||
defaultWarning,
|
||||
null,
|
||||
0L)
|
||||
}
|
||||
|
||||
def convertTColumn(columns: TGetResultSetMetadataResp): util.List[Column] = {
|
||||
columns.getSchema.getColumns.asScala.map(c => {
|
||||
val tp = c.getTypeDesc.getTypes.get(0).getPrimitiveEntry.getType match {
|
||||
case TTypeId.BOOLEAN_TYPE => ClientStandardTypes.BOOLEAN
|
||||
case TTypeId.TINYINT_TYPE => ClientStandardTypes.TINYINT
|
||||
case TTypeId.SMALLINT_TYPE => ClientStandardTypes.SMALLINT
|
||||
case TTypeId.INT_TYPE => ClientStandardTypes.INTEGER
|
||||
case TTypeId.BIGINT_TYPE => ClientStandardTypes.BIGINT
|
||||
case TTypeId.FLOAT_TYPE => ClientStandardTypes.DOUBLE
|
||||
case TTypeId.DOUBLE_TYPE => ClientStandardTypes.DOUBLE
|
||||
case TTypeId.STRING_TYPE => ClientStandardTypes.VARCHAR
|
||||
case TTypeId.TIMESTAMP_TYPE => ClientStandardTypes.TIMESTAMP
|
||||
case TTypeId.BINARY_TYPE => ClientStandardTypes.VARBINARY
|
||||
case TTypeId.DECIMAL_TYPE => ClientStandardTypes.DECIMAL
|
||||
case TTypeId.DATE_TYPE => ClientStandardTypes.DATE
|
||||
case TTypeId.VARCHAR_TYPE => ClientStandardTypes.VARCHAR
|
||||
case TTypeId.CHAR_TYPE => ClientStandardTypes.CHAR
|
||||
case TTypeId.INTERVAL_YEAR_MONTH_TYPE => ClientStandardTypes.INTERVAL_YEAR_TO_MONTH
|
||||
case TTypeId.INTERVAL_DAY_TIME_TYPE => ClientStandardTypes.TIME_WITH_TIME_ZONE
|
||||
case TTypeId.TIMESTAMPLOCALTZ_TYPE => ClientStandardTypes.TIMESTAMP_WITH_TIME_ZONE
|
||||
case _ => ClientStandardTypes.VARCHAR
|
||||
}
|
||||
new Column(c.getColumnName, tp, new ClientTypeSignature(tp))
|
||||
}).toList.asJava
|
||||
}
|
||||
|
||||
def convertTRowSet(rowSet: TRowSet): util.List[util.List[Object]] = {
|
||||
val dataResult = new util.LinkedList[util.List[Object]]
|
||||
|
||||
if (rowSet.getColumns == null) {
|
||||
return rowSet.getRows.asScala
|
||||
.map(t => t.getColVals.asScala.map(v => v.getFieldValue.asInstanceOf[Object]).asJava)
|
||||
.asJava
|
||||
}
|
||||
|
||||
rowSet.getColumns.asScala.foreach {
|
||||
case tColumn if tColumn.isSetBoolVal =>
|
||||
val nulls = util.BitSet.valueOf(tColumn.getBoolVal.getNulls)
|
||||
if (dataResult.isEmpty) {
|
||||
(1 to tColumn.getBoolVal.getValuesSize).foreach(_ =>
|
||||
dataResult.add(new util.LinkedList[Object]()))
|
||||
}
|
||||
|
||||
tColumn.getBoolVal.getValues.asScala.zipWithIndex.foreach {
|
||||
case (_, rowIdx) if nulls.get(rowIdx) =>
|
||||
dataResult.get(rowIdx).add(null)
|
||||
case (v, rowIdx) =>
|
||||
dataResult.get(rowIdx).add(v)
|
||||
}
|
||||
case tColumn if tColumn.isSetByteVal =>
|
||||
val nulls = util.BitSet.valueOf(tColumn.getByteVal.getNulls)
|
||||
if (dataResult.isEmpty) {
|
||||
(1 to tColumn.getByteVal.getValuesSize).foreach(_ =>
|
||||
dataResult.add(new util.LinkedList[Object]()))
|
||||
}
|
||||
|
||||
tColumn.getByteVal.getValues.asScala.zipWithIndex.foreach {
|
||||
case (_, rowIdx) if nulls.get(rowIdx) =>
|
||||
dataResult.get(rowIdx).add(null)
|
||||
case (v, rowIdx) =>
|
||||
dataResult.get(rowIdx).add(v)
|
||||
}
|
||||
case tColumn if tColumn.isSetI16Val =>
|
||||
val nulls = util.BitSet.valueOf(tColumn.getI16Val.getNulls)
|
||||
if (dataResult.isEmpty) {
|
||||
(1 to tColumn.getI16Val.getValuesSize).foreach(_ =>
|
||||
dataResult.add(new util.LinkedList[Object]()))
|
||||
}
|
||||
|
||||
tColumn.getI16Val.getValues.asScala.zipWithIndex.foreach {
|
||||
case (_, rowIdx) if nulls.get(rowIdx) =>
|
||||
dataResult.get(rowIdx).add(null)
|
||||
case (v, rowIdx) =>
|
||||
dataResult.get(rowIdx).add(v)
|
||||
}
|
||||
case tColumn if tColumn.isSetI32Val =>
|
||||
val nulls = util.BitSet.valueOf(tColumn.getI32Val.getNulls)
|
||||
if (dataResult.isEmpty) {
|
||||
(1 to tColumn.getI32Val.getValuesSize).foreach(_ =>
|
||||
dataResult.add(new util.LinkedList[Object]()))
|
||||
}
|
||||
|
||||
tColumn.getI32Val.getValues.asScala.zipWithIndex.foreach {
|
||||
case (_, rowIdx) if nulls.get(rowIdx) =>
|
||||
dataResult.get(rowIdx).add(null)
|
||||
case (v, rowIdx) =>
|
||||
dataResult.get(rowIdx).add(v)
|
||||
}
|
||||
case tColumn if tColumn.isSetI64Val =>
|
||||
val nulls = util.BitSet.valueOf(tColumn.getI64Val.getNulls)
|
||||
if (dataResult.isEmpty) {
|
||||
(1 to tColumn.getI64Val.getValuesSize).foreach(_ =>
|
||||
dataResult.add(new util.LinkedList[Object]()))
|
||||
}
|
||||
|
||||
tColumn.getI64Val.getValues.asScala.zipWithIndex.foreach {
|
||||
case (_, rowIdx) if nulls.get(rowIdx) =>
|
||||
dataResult.get(rowIdx).add(null)
|
||||
case (v, rowIdx) =>
|
||||
dataResult.get(rowIdx).add(v)
|
||||
}
|
||||
case tColumn if tColumn.isSetDoubleVal =>
|
||||
val nulls = util.BitSet.valueOf(tColumn.getDoubleVal.getNulls)
|
||||
if (dataResult.isEmpty) {
|
||||
(1 to tColumn.getDoubleVal.getValuesSize).foreach(_ =>
|
||||
dataResult.add(new util.LinkedList[Object]()))
|
||||
}
|
||||
|
||||
tColumn.getDoubleVal.getValues.asScala.zipWithIndex.foreach {
|
||||
case (_, rowIdx) if nulls.get(rowIdx) =>
|
||||
dataResult.get(rowIdx).add(null)
|
||||
case (v, rowIdx) =>
|
||||
dataResult.get(rowIdx).add(v)
|
||||
}
|
||||
case tColumn if tColumn.isSetBinaryVal =>
|
||||
val nulls = util.BitSet.valueOf(tColumn.getBinaryVal.getNulls)
|
||||
if (dataResult.isEmpty) {
|
||||
(1 to tColumn.getBinaryVal.getValuesSize).foreach(_ =>
|
||||
dataResult.add(new util.LinkedList[Object]()))
|
||||
}
|
||||
|
||||
tColumn.getBinaryVal.getValues.asScala.zipWithIndex.foreach {
|
||||
case (_, rowIdx) if nulls.get(rowIdx) =>
|
||||
dataResult.get(rowIdx).add(null)
|
||||
case (v, rowIdx) =>
|
||||
dataResult.get(rowIdx).add(v)
|
||||
}
|
||||
case tColumn =>
|
||||
val nulls = util.BitSet.valueOf(tColumn.getStringVal.getNulls)
|
||||
if (dataResult.isEmpty) {
|
||||
(1 to tColumn.getStringVal.getValuesSize).foreach(_ =>
|
||||
dataResult.add(new util.LinkedList[Object]()))
|
||||
}
|
||||
|
||||
tColumn.getStringVal.getValues.asScala.zipWithIndex.foreach {
|
||||
case (_, rowIdx) if nulls.get(rowIdx) =>
|
||||
dataResult.get(rowIdx).add(null)
|
||||
case (v, rowIdx) =>
|
||||
dataResult.get(rowIdx).add(v)
|
||||
}
|
||||
}
|
||||
dataResult
|
||||
}
|
||||
|
||||
def toQueryError(queryStatus: OperationStatus): QueryError = {
|
||||
val exception = queryStatus.exception
|
||||
if (exception.isEmpty) {
|
||||
null
|
||||
} else {
|
||||
new QueryError(
|
||||
exception.get.getMessage,
|
||||
queryStatus.state.name(),
|
||||
GENERIC_INTERNAL_ERROR_CODE,
|
||||
GENERIC_INTERNAL_ERROR_NAME,
|
||||
GENERIC_INTERNAL_ERROR_TYPE,
|
||||
null,
|
||||
null)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -17,13 +17,24 @@
|
||||
|
||||
package org.apache.kyuubi.server.trino.api
|
||||
|
||||
import java.net.URI
|
||||
import java.time.ZoneId
|
||||
import javax.ws.rs.core.MediaType
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import io.trino.client.ProtocolHeaders.TRINO_HEADERS
|
||||
import org.apache.hive.service.rpc.thrift.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V9
|
||||
import org.scalatest.concurrent.PatienceConfiguration.Timeout
|
||||
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
|
||||
|
||||
import org.apache.kyuubi.KyuubiFunSuite
|
||||
import org.apache.kyuubi.{KyuubiFunSuite, RestFrontendTestHelper}
|
||||
import org.apache.kyuubi.events.KyuubiOperationEvent
|
||||
import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle}
|
||||
import org.apache.kyuubi.operation.OperationState.{FINISHED, OperationState}
|
||||
|
||||
class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper {
|
||||
|
||||
class TrinoContextSuite extends KyuubiFunSuite {
|
||||
import TrinoContext._
|
||||
|
||||
test("create trino request context with header") {
|
||||
@ -67,4 +78,83 @@ class TrinoContextSuite extends KyuubiFunSuite {
|
||||
assert(actual == expectedTrinoContext)
|
||||
}
|
||||
|
||||
test("test convert") {
|
||||
val opHandle = getOpHandle("select 1")
|
||||
val opHandleStr = opHandle.identifier.toString
|
||||
checkOpState(opHandleStr, FINISHED)
|
||||
|
||||
val metadataResp = fe.be.getResultSetMetadata(opHandle)
|
||||
val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
|
||||
val status = fe.be.getOperationStatus(opHandle)
|
||||
|
||||
val uri = new URI("sfdsfsdfdsf")
|
||||
val results = TrinoContext
|
||||
.createQueryResults("/xdfd/xdf", uri, uri, status, Option(metadataResp), Option(tRowSet))
|
||||
|
||||
print(results.toString)
|
||||
assert(results.getColumns.get(0).getType.equals("integer"))
|
||||
assert(results.getData.asScala.last.get(0) == 1)
|
||||
}
|
||||
|
||||
test("test convert from table") {
|
||||
initSql("CREATE DATABASE IF NOT EXISTS INIT_DB")
|
||||
initSql(
|
||||
"CREATE TABLE IF NOT EXISTS INIT_DB.test(a int, b double, c String," +
|
||||
"d BOOLEAN,e DATE,f TIMESTAMP,g ARRAY<String>,h DECIMAL," +
|
||||
"i MAP<String,String>) USING PARQUET;")
|
||||
initSql(
|
||||
"INSERT INTO INIT_DB.test VALUES (1,2.2,'3',true,current_date()," +
|
||||
"current_timestamp(),array('1','2'),2.0, map('m','p') )")
|
||||
|
||||
val opHandle = getOpHandle("SELECT * FROM INIT_DB.test")
|
||||
val opHandleStr = opHandle.identifier.toString
|
||||
checkOpState(opHandleStr, FINISHED)
|
||||
|
||||
val metadataResp = fe.be.getResultSetMetadata(opHandle)
|
||||
val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
|
||||
val status = fe.be.getOperationStatus(opHandle)
|
||||
|
||||
val uri = new URI("sfdsfsdfdsf")
|
||||
val results = TrinoContext
|
||||
.createQueryResults("/xdfd/xdf", uri, uri, status, Option(metadataResp), Option(tRowSet))
|
||||
|
||||
print(results.toString)
|
||||
assert(results.getColumns.get(0).getType.equals("integer"))
|
||||
assert(results.getData.asScala.last.get(0) != null)
|
||||
}
|
||||
|
||||
def getOpHandleStr(statement: String = "show tables"): String = {
|
||||
getOpHandle(statement).identifier.toString
|
||||
}
|
||||
|
||||
def getOpHandle(statement: String = "show tables"): OperationHandle = {
|
||||
val sessionHandle = fe.be.openSession(
|
||||
HIVE_CLI_SERVICE_PROTOCOL_V9,
|
||||
"admin",
|
||||
"123456",
|
||||
"localhost",
|
||||
Map("testConfig" -> "testValue"))
|
||||
|
||||
if (statement.nonEmpty) {
|
||||
fe.be.executeStatement(sessionHandle, statement, Map.empty, runAsync = false, 30000)
|
||||
} else {
|
||||
fe.be.getCatalogs(sessionHandle)
|
||||
}
|
||||
}
|
||||
|
||||
private def checkOpState(opHandleStr: String, state: OperationState): Unit = {
|
||||
eventually(Timeout(30.seconds)) {
|
||||
val response = webTarget.path(s"api/v1/operations/$opHandleStr/event")
|
||||
.request(MediaType.APPLICATION_JSON_TYPE).get()
|
||||
assert(response.getStatus === 200)
|
||||
val operationEvent = response.readEntity(classOf[KyuubiOperationEvent])
|
||||
assert(operationEvent.state === state.name())
|
||||
}
|
||||
}
|
||||
|
||||
private def initSql(sql: String): Unit = {
|
||||
val initOpHandle = getOpHandle(sql)
|
||||
val initOpHandleStr = initOpHandle.identifier.toString
|
||||
checkOpState(initOpHandleStr, FINISHED)
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user