[KYUUBI #3935] Support use Trino client to submit SQL

### _Why are the changes needed?_

Close #3935

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.readthedocs.io/en/master/develop_tools/testing.html#running-tests) locally before make a pull request

Closes #4232 from iodone/kyuubi-3935.

Closes #3935

936ea1f8 [odone] address
e7bd01a1 [odone] support trino client connect kyuubi trino server
9ea8b6af [odone] [WIP] trion request/response implementation

Authored-by: odone <odone.zhang@gmail.com>
Signed-off-by: ulyssesyou <ulyssesyou@apache.org>
This commit is contained in:
odone 2023-02-13 19:28:14 +08:00 committed by ulyssesyou
parent 3b0137ae78
commit 41f08059f0
18 changed files with 661 additions and 134 deletions

View File

@ -69,7 +69,7 @@ jackson-annotations/2.14.2//jackson-annotations-2.14.2.jar
jackson-core/2.14.2//jackson-core-2.14.2.jar
jackson-databind/2.14.2//jackson-databind-2.14.2.jar
jackson-dataformat-yaml/2.14.2//jackson-dataformat-yaml-2.14.2.jar
jackson-datatype-jdk8/2.12.3//jackson-datatype-jdk8-2.12.3.jar
jackson-datatype-jdk8/2.14.2//jackson-datatype-jdk8-2.14.2.jar
jackson-datatype-jsr310/2.14.2//jackson-datatype-jsr310-2.14.2.jar
jackson-jaxrs-base/2.14.2//jackson-jaxrs-base-2.14.2.jar
jackson-jaxrs-json-provider/2.14.2//jackson-jaxrs-json-provider-2.14.2.jar

View File

@ -156,11 +156,14 @@ abstract class AbstractBackendService(name: String)
queryId
}
override def getOperationStatus(operationHandle: OperationHandle): OperationStatus = {
override def getOperationStatus(
operationHandle: OperationHandle,
maxWait: Option[Long]): OperationStatus = {
val operation = sessionManager.operationManager.getOperation(operationHandle)
if (operation.shouldRunAsync) {
try {
operation.getBackgroundHandle.get(timeout, TimeUnit.MILLISECONDS)
val waitTime = maxWait.getOrElse(timeout)
operation.getBackgroundHandle.get(waitTime, TimeUnit.MILLISECONDS)
} catch {
case e: TimeoutException =>
debug(s"$operationHandle: Long polling timed out, ${e.getMessage}")

View File

@ -91,7 +91,9 @@ trait BackendService {
foreignTable: String): OperationHandle
def getQueryId(operationHandle: OperationHandle): String
def getOperationStatus(operationHandle: OperationHandle): OperationStatus
def getOperationStatus(
operationHandle: OperationHandle,
maxWait: Option[Long] = None): OperationStatus
def cancelOperation(operationHandle: OperationHandle): Unit
def closeOperation(operationHandle: OperationHandle): Unit
def getResultSetMetadata(operationHandle: OperationHandle): TGetResultSetMetadataResp

View File

@ -221,6 +221,16 @@
<artifactId>jersey-media-multipart</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
</dependency>
<dependency>
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId>

View File

@ -152,9 +152,11 @@ trait BackendServiceMetric extends BackendService {
}
}
abstract override def getOperationStatus(operationHandle: OperationHandle): OperationStatus = {
abstract override def getOperationStatus(
operationHandle: OperationHandle,
maxWait: Option[Long] = None): OperationStatus = {
MetricsSystem.timerTracing(MetricsConstants.BS_GET_OPERATION_STATUS) {
super.getOperationStatus(operationHandle)
super.getOperationStatus(operationHandle, maxWait)
}
}

View File

@ -19,10 +19,9 @@ package org.apache.kyuubi.server.trino.api
import scala.collection.JavaConverters._
import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.kyuubi.operation.OperationHandle
import org.apache.kyuubi.service.BackendService
import org.apache.kyuubi.session.SessionHandle
import org.apache.kyuubi.sql.parser.trino.KyuubiTrinoFeParser
import org.apache.kyuubi.sql.plan.PassThroughNode
import org.apache.kyuubi.sql.plan.trino.{GetCatalogs, GetColumns, GetSchemas, GetTables, GetTableTypes, GetTypeInfo}
@ -32,17 +31,10 @@ class KyuubiTrinoOperationTranslator(backendService: BackendService) {
def transform(
statement: String,
user: String,
ipAddress: String,
sessionHandle: SessionHandle,
configs: Map[String, String],
runAsync: Boolean,
queryTimeout: Long): OperationHandle = {
val sessionHandle = backendService.openSession(
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
user,
"",
ipAddress,
configs)
parser.parsePlan(statement) match {
case GetSchemas(catalogName, schemaPattern) =>
backendService.getSchemas(sessionHandle, catalogName, schemaPattern)

View File

@ -0,0 +1,206 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.server.trino.api
import java.net.URI
import java.security.SecureRandom
import java.util.Objects.requireNonNull
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong
import javax.ws.rs.WebApplicationException
import javax.ws.rs.core.{Response, UriInfo}
import Slug.Context.{EXECUTING_QUERY, QUEUED_QUERY}
import com.google.common.hash.Hashing
import io.trino.client.QueryResults
import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.kyuubi.operation.{FetchOrientation, OperationHandle}
import org.apache.kyuubi.operation.OperationState.{FINISHED, INITIALIZED, OperationState, PENDING}
import org.apache.kyuubi.service.BackendService
import org.apache.kyuubi.session.SessionHandle
case class Query(
queryId: QueryId,
context: TrinoContext,
be: BackendService) {
private val QUEUED_QUERY_PATH = "/v1/statement/queued/"
private val EXECUTING_QUERY_PATH = "/v1/statement/executing"
private val slug: Slug = Slug.createNewWithUUID(queryId.getQueryId)
private val lastToken = new AtomicLong
private val defaultMaxRows = 1000
private val defaultFetchOrientation = FetchOrientation.withName("FETCH_NEXT")
def getQueryResults(token: Long, uriInfo: UriInfo, maxWait: Long = 0): QueryResults = {
val status =
be.getOperationStatus(queryId.operationHandle, Some(maxWait))
val nextUri = if (status.exception.isEmpty) {
getNextUri(token + 1, uriInfo, toSlugContext(status.state))
} else null
val queryHtmlUri = uriInfo.getRequestUriBuilder
.replacePath("ui/query.html").replaceQuery(queryId.getQueryId).build()
status.state match {
case FINISHED =>
val metaData = be.getResultSetMetadata(queryId.operationHandle)
val resultSet = be.fetchResults(
queryId.operationHandle,
defaultFetchOrientation,
defaultMaxRows,
false)
TrinoContext.createQueryResults(
queryId.getQueryId,
nextUri,
queryHtmlUri,
status,
Option(metaData),
Option(resultSet))
case _ =>
TrinoContext.createQueryResults(
queryId.getQueryId,
nextUri,
queryHtmlUri,
status)
}
}
def getLastToken: Long = this.lastToken.get()
def getSlug: Slug = this.slug
def cancel: Unit = clear
private def clear = {
be.closeOperation(queryId.operationHandle)
context.session.get("sessionId").foreach { id =>
be.closeSession(SessionHandle.fromUUID(id))
}
}
private def setToken(token: Long): Unit = {
val lastToken = this.lastToken.get
if (token != lastToken && token != lastToken + 1) {
throw new WebApplicationException(Response.Status.GONE)
}
this.lastToken.compareAndSet(lastToken, token)
}
private def getNextUri(token: Long, uriInfo: UriInfo, slugContext: Slug.Context.Context): URI = {
val path = slugContext match {
case QUEUED_QUERY => QUEUED_QUERY_PATH
case EXECUTING_QUERY => EXECUTING_QUERY_PATH
}
uriInfo.getBaseUriBuilder.replacePath(path)
.path(queryId.getQueryId)
.path(slug.makeSlug(slugContext, token))
.path(String.valueOf(token))
.replaceQuery("")
.build()
}
private def toSlugContext(state: OperationState): Slug.Context.Context = {
state match {
case INITIALIZED | PENDING => Slug.Context.QUEUED_QUERY
case _ => Slug.Context.EXECUTING_QUERY
}
}
}
object Query {
def apply(
statement: String,
context: TrinoContext,
translator: KyuubiTrinoOperationTranslator,
backendService: BackendService,
queryTimeout: Long = 0): Query = {
val sessionHandle = createSession(context, backendService)
val operationHandle = translator.transform(
statement,
sessionHandle,
context.session,
true,
queryTimeout)
val newSessionProperties =
context.session + ("sessionId" -> sessionHandle.identifier.toString)
val updatedContext = context.copy(session = newSessionProperties)
Query(QueryId(operationHandle), updatedContext, backendService)
}
def apply(id: String, context: TrinoContext, backendService: BackendService): Query = {
Query(QueryId(id), context, backendService)
}
private def createSession(
context: TrinoContext,
backendService: BackendService): SessionHandle = {
backendService.openSession(
TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V11,
context.user,
"",
context.remoteUserAddress.getOrElse(""),
context.session)
}
}
case class QueryId(operationHandle: OperationHandle) {
def getQueryId: String = operationHandle.identifier.toString
}
object QueryId {
def apply(id: String): QueryId = QueryId(OperationHandle(id))
}
object Slug {
object Context extends Enumeration {
type Context = Value
val QUEUED_QUERY, EXECUTING_QUERY = Value
}
private val RANDOM = new SecureRandom
def createNew: Slug = {
val randomBytes = new Array[Byte](16)
RANDOM.nextBytes(randomBytes)
new Slug(randomBytes)
}
def createNewWithUUID(uuid: String): Slug = {
val uuidBytes = UUID.fromString(uuid).toString.getBytes("UTF-8")
new Slug(uuidBytes)
}
}
case class Slug(slugKey: Array[Byte]) {
val hmac = Hashing.hmacSha1(requireNonNull(slugKey, "slugKey is null"))
def makeSlug(context: Slug.Context.Context, token: Long): String = {
"y" + hmac.newHasher.putInt(context.id).putLong(token).hash.toString
}
def isValid(context: Slug.Context.Context, slug: String, token: Long): Boolean =
makeSlug(context, token) == slug
}

View File

@ -28,6 +28,7 @@ import io.trino.client.{ClientStandardTypes, ClientTypeSignature, Column, QueryE
import io.trino.client.ProtocolHeaders.TRINO_HEADERS
import org.apache.hive.service.rpc.thrift.{TGetResultSetMetadataResp, TRowSet, TTypeId}
import org.apache.kyuubi.operation.OperationState.FINISHED
import org.apache.kyuubi.operation.OperationStatus
/**
@ -58,6 +59,7 @@ case class TrinoContext(
source: Option[String] = None,
catalog: Option[String] = None,
schema: Option[String] = None,
remoteUserAddress: Option[String] = None,
language: Option[String] = None,
traceToken: Option[String] = None,
clientInfo: Option[String] = None,
@ -72,10 +74,11 @@ object TrinoContext {
private val GENERIC_INTERNAL_ERROR_NAME = "GENERIC_INTERNAL_ERROR_NAME"
private val GENERIC_INTERNAL_ERROR_TYPE = "INTERNAL_ERROR"
def apply(headers: HttpHeaders): TrinoContext = {
apply(headers.getRequestHeaders.asScala.toMap.map {
def apply(headers: HttpHeaders, remoteAddress: Option[String]): TrinoContext = {
val context = apply(headers.getRequestHeaders.asScala.toMap.map {
case (k, v) => (k, v.asScala.toList)
})
context.copy(remoteUserAddress = remoteAddress)
}
def apply(headers: Map[String, List[String]]): TrinoContext = {
@ -134,7 +137,6 @@ object TrinoContext {
}
}
// TODO: Building response with TrinoContext and other information
def buildTrinoResponse(qr: QueryResults, trinoContext: TrinoContext): Response = {
val responseBuilder = Response.ok(qr)
@ -156,8 +158,6 @@ object TrinoContext {
responseBuilder.header(TRINO_HEADERS.responseDeallocatedPrepare, urlEncode(v))
}
responseBuilder.header(TRINO_HEADERS.responseClearSession, s"responseClearSession")
responseBuilder.header(TRINO_HEADERS.responseClearTransactionId, "false")
responseBuilder.build()
}
@ -192,11 +192,16 @@ object TrinoContext {
case None => null
}
val updatedNextUri = queryStatus.state match {
case FINISHED if rowList == null || rowList.isEmpty || rowList.get(0).isEmpty => null
case _ => nextUri
}
new QueryResults(
queryId,
queryHtmlUri,
nextUri,
nextUri,
updatedNextUri,
columnList,
rowList,
StatementStats.builder.setState(queryStatus.state.name()).setQueued(false)

View File

@ -19,11 +19,14 @@ package org.apache.kyuubi.server.trino.api
import javax.ws.rs.ext.ContextResolver
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module
class KyuubiScalaObjectMapper extends ContextResolver[ObjectMapper] {
private val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
class TrinoScalaObjectMapper extends ContextResolver[ObjectMapper] {
private lazy val mapper = new ObjectMapper()
.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
.registerModule(new Jdk8Module)
override def getContext(aClass: Class[_]): ObjectMapper = mapper
}

View File

@ -21,6 +21,6 @@ import org.glassfish.jersey.server.ResourceConfig
class TrinoServerConfig extends ResourceConfig {
packages("org.apache.kyuubi.server.trino.api.v1")
register(classOf[KyuubiScalaObjectMapper])
register(classOf[TrinoScalaObjectMapper])
register(classOf[RestExceptionMapper])
}

View File

@ -18,16 +18,24 @@
package org.apache.kyuubi.server.trino.api.v1
import javax.ws.rs._
import javax.ws.rs.core.{Context, HttpHeaders, MediaType}
import javax.ws.rs.core.{Context, HttpHeaders, MediaType, Response, UriInfo}
import javax.ws.rs.core.MediaType.TEXT_PLAIN_TYPE
import javax.ws.rs.core.Response.Status.{BAD_REQUEST, NOT_FOUND}
import scala.util.Try
import scala.util.control.NonFatal
import io.airlift.units.Duration
import io.swagger.v3.oas.annotations.media.{Content, Schema}
import io.swagger.v3.oas.annotations.responses.ApiResponse
import io.swagger.v3.oas.annotations.tags.Tag
import io.trino.client.QueryResults
import org.apache.kyuubi.Logging
import org.apache.kyuubi.server.trino.api.{ApiRequestContext, KyuubiTrinoOperationTranslator}
import org.apache.kyuubi.server.trino.api.{ApiRequestContext, KyuubiTrinoOperationTranslator, Query, QueryId, Slug, TrinoContext}
import org.apache.kyuubi.server.trino.api.Slug.Context.{EXECUTING_QUERY, QUEUED_QUERY}
import org.apache.kyuubi.server.trino.api.v1.dto.Ok
import org.apache.kyuubi.service.BackendService
@Tag(name = "Statement")
@Produces(Array(MediaType.APPLICATION_JSON))
@ -50,11 +58,32 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
schema = new Schema(implementation = classOf[QueryResults]))),
description =
"Create a query")
@GET
@POST
@Path("/")
@Consumes(Array(MediaType.TEXT_PLAIN))
def query(statement: String, @Context headers: HttpHeaders): QueryResults = {
throw new UnsupportedOperationException
def query(
statement: String,
@Context headers: HttpHeaders,
@Context uriInfo: UriInfo): Response = {
if (statement == null || statement.isEmpty) {
throw badRequest(BAD_REQUEST, "SQL statement is empty")
}
val remoteAddr = Option(httpRequest.getRemoteAddr)
val trinoContext = TrinoContext(headers, remoteAddr)
try {
val query = Query(statement, trinoContext, translator, fe.be)
val qr = query.getQueryResults(query.getLastToken, uriInfo)
TrinoContext.buildTrinoResponse(qr, query.context)
} catch {
case e: Exception =>
val errorMsg =
s"Error submitting sql"
e.printStackTrace()
error(errorMsg, e)
throw badRequest(BAD_REQUEST, errorMsg)
}
}
@ApiResponse(
@ -65,11 +94,31 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@GET
@Path("/queued/{queryId}/{slug}/{token}")
def getQueuedStatementStatus(
@Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
@PathParam("token") token: Long): QueryResults = {
throw new UnsupportedOperationException
@PathParam("token") token: Long,
@QueryParam("maxWait") maxWait: Duration,
@Context headers: HttpHeaders,
@Context uriInfo: UriInfo): Response = {
val remoteAddr = Option(httpRequest.getRemoteAddr)
val trinoContext = TrinoContext(headers, remoteAddr)
val waitTime = if (maxWait == null) 0 else maxWait.toMillis
getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, QUEUED_QUERY)
.flatMap(query =>
Try(TrinoContext.buildTrinoResponse(
query.getQueryResults(
token,
uriInfo,
waitTime),
query.context)))
.recover {
case NonFatal(e) =>
val errorMsg =
s"Error executing for query id $queryId"
error(errorMsg, e)
throw badRequest(NOT_FOUND, "Query not found")
}.get
}
@ApiResponse(
@ -80,11 +129,28 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@GET
@Path("/executing/{queryId}/{slug}/{token}")
def getExecutingStatementStatus(
@Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
@PathParam("token") token: Long): QueryResults = {
throw new UnsupportedOperationException
@PathParam("token") token: Long,
@QueryParam("maxWait") maxWait: Duration,
@Context headers: HttpHeaders,
@Context uriInfo: UriInfo): Response = {
val remoteAddr = Option(httpRequest.getRemoteAddr)
val trinoContext = TrinoContext(headers, remoteAddr)
val waitTime = if (maxWait == null) 0 else maxWait.toMillis
getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, EXECUTING_QUERY)
.flatMap(query =>
Try(TrinoContext.buildTrinoResponse(
query.getQueryResults(token, uriInfo, waitTime),
query.context)))
.recover {
case NonFatal(e) =>
val errorMsg =
s"Error executing for query id $queryId"
error(errorMsg, e)
throw badRequest(NOT_FOUND, "Query not found")
}.get
}
@ApiResponse(
@ -95,11 +161,23 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@DELETE
@Path("/queued/{queryId}/{slug}/{token}")
def cancelQueuedStatement(
@Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
@PathParam("token") token: Long): QueryResults = {
throw new UnsupportedOperationException
@PathParam("token") token: Long,
@Context headers: HttpHeaders): Response = {
val remoteAddr = Option(httpRequest.getRemoteAddr)
val trinoContext = TrinoContext(headers, remoteAddr)
getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, QUEUED_QUERY)
.flatMap(query => Try(query.cancel))
.recover {
case NonFatal(e) =>
val errorMsg =
s"Error executing for query id $queryId"
error(errorMsg, e)
throw badRequest(NOT_FOUND, "Query not found")
}.get
Response.noContent.build
}
@ApiResponse(
@ -110,11 +188,44 @@ private[v1] class StatementResource extends ApiRequestContext with Logging {
@DELETE
@Path("/executing/{queryId}/{slug}/{token}")
def cancelExecutingStatementStatus(
@Context headers: HttpHeaders,
@PathParam("queryId") queryId: String,
@PathParam("slug") slug: String,
@PathParam("token") token: Long): QueryResults = {
throw new UnsupportedOperationException
@PathParam("token") token: Long,
@Context headers: HttpHeaders): Response = {
val remoteAddr = Option(httpRequest.getRemoteAddr)
val trinoContext = TrinoContext(headers, remoteAddr)
getQuery(fe.be, trinoContext, QueryId(queryId), slug, token, EXECUTING_QUERY)
.flatMap(query => Try(query.cancel))
.recover {
case NonFatal(e) =>
val errorMsg =
s"Error executing for query id $queryId"
error(errorMsg, e)
throw badRequest(NOT_FOUND, "Query not found")
}.get
Response.noContent.build
}
private def getQuery(
be: BackendService,
context: TrinoContext,
queryId: QueryId,
slug: String,
token: Long,
slugContext: Slug.Context.Context): Try[Query] = {
Try(be.sessionManager.operationManager.getOperation(queryId.operationHandle)).map { _ =>
Query(queryId, context, be)
}.filter(_.getSlug.isValid(slugContext, slug, token))
}
private def badRequest(status: Response.Status, message: String) =
new WebApplicationException(
Response.status(status)
.`type`(TEXT_PLAIN_TYPE)
.entity(message)
.build)
}

View File

@ -37,7 +37,7 @@ import org.apache.kyuubi.service.AbstractFrontendService
object RestFrontendTestHelper {
private class RestApiBaseSuite extends JerseyTest {
class RestApiBaseSuite extends JerseyTest {
override def configure: Application = new ResourceConfig(getClass)
.register(classOf[MultiPartFeature])
@ -58,7 +58,7 @@ trait RestFrontendTestHelper extends WithKyuubiServer {
override protected val frontendProtocols: Seq[FrontendProtocol] =
FrontendProtocols.REST :: Nil
private val restApiBaseSuite = new RestApiBaseSuite
protected val restApiBaseSuite: JerseyTest = new RestApiBaseSuite
override def beforeAll(): Unit = {
super.beforeAll()

View File

@ -1,80 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi
import java.net.URI
import java.time.ZoneId
import java.util.{Locale, Optional}
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import io.airlift.units.Duration
import io.trino.client.{ClientSelectedRole, ClientSession, StatementClient, StatementClientFactory}
import okhttp3.OkHttpClient
trait TrinoClientTestHelper extends RestFrontendTestHelper {
override def afterAll(): Unit = {
super.afterAll()
}
private val httpClient = new OkHttpClient.Builder().build()
protected val clientSession = createClientSession(baseUri: URI)
def getTrinoStatementClient(sql: String): StatementClient = {
StatementClientFactory.newStatementClient(httpClient, clientSession, sql)
}
def createClientSession(connectUrl: URI): ClientSession = {
new ClientSession(
connectUrl,
"kyuubi_test",
Optional.of("test_user"),
"kyuubi",
Optional.of("test_token_tracing"),
Set[String]().asJava,
"test_client_info",
"test_catalog",
"test_schema",
"test_path",
ZoneId.systemDefault(),
Locale.getDefault,
Map[String, String](
"test_resource_key0" -> "test_resource_value0",
"test_resource_key1" -> "test_resource_value1").asJava,
Map[String, String](
"test_property_key0" -> "test_property_value0",
"test_property_key1" -> "test_propert_value1").asJava,
Map[String, String](
"test_statement_key0" -> "select 1",
"test_statement_key1" -> "select 2").asJava,
Map[String, ClientSelectedRole](
"test_role_key0" -> ClientSelectedRole.valueOf("ROLE"),
"test_role_key2" -> ClientSelectedRole.valueOf("ALL")).asJava,
Map[String, String](
"test_credentials_key0" -> "test_credentials_value0",
"test_credentials_key1" -> "test_credentials_value1").asJava,
"test_transaction_id",
new Duration(2, TimeUnit.MINUTES),
true)
}
}

View File

@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi
import org.glassfish.jersey.client.ClientConfig
import org.glassfish.jersey.test.JerseyTest
import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols
import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols.FrontendProtocol
import org.apache.kyuubi.server.trino.api.TrinoScalaObjectMapper
trait TrinoRestFrontendTestHelper extends RestFrontendTestHelper {
private class TrinoRestBaseSuite extends RestFrontendTestHelper.RestApiBaseSuite {
override def configureClient(config: ClientConfig): Unit = {
config.register(classOf[TrinoScalaObjectMapper])
}
}
override protected val frontendProtocols: Seq[FrontendProtocol] =
FrontendProtocols.TRINO :: Nil
override protected val restApiBaseSuite: JerseyTest = new TrinoRestBaseSuite
}

View File

@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kyuubi.server.trino.api
import java.net.URI
import java.time.ZoneId
import java.util.{Collections, Locale, Optional}
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec
import scala.collection.JavaConverters._
import com.google.common.base.Verify
import io.airlift.units.Duration
import io.trino.client.{ClientSession, StatementClient, StatementClientFactory}
import okhttp3.OkHttpClient
import org.apache.kyuubi.{KyuubiFunSuite, KyuubiSQLException, TrinoRestFrontendTestHelper}
class TrinoClientApiSuite extends KyuubiFunSuite with TrinoRestFrontendTestHelper {
private val httpClient =
new OkHttpClient.Builder()
.readTimeout(5, TimeUnit.MINUTES)
.build()
private lazy val clientSession =
new AtomicReference[ClientSession](createTestClientSession(baseUri))
test("submit query with trino client api") {
val trino = getTrinoStatementClient("select 1")
val result = execute(trino)
val sessionId = trino.getSetSessionProperties.asScala.get("sessionId")
assert(result == List(List(1)))
updateClientSession(trino)
val trino1 = getTrinoStatementClient("select 2")
val result1 = execute(trino1)
val sessionId1 = trino1.getSetSessionProperties.asScala.get("sessionId")
assert(result1 == List(List(2)))
assert(sessionId != sessionId1)
trino.close()
}
private def updateClientSession(trino: StatementClient): Unit = {
val session = clientSession.get
var builder = ClientSession.builder(session)
// update catalog and schema
if (trino.getSetCatalog.isPresent || trino.getSetSchema.isPresent) {
builder = builder
.withCatalog(trino.getSetCatalog.orElse(session.getCatalog))
.withSchema(trino.getSetSchema.orElse(session.getSchema))
}
// update path if present
if (trino.getSetPath.isPresent) {
builder = builder.withPath(trino.getSetPath.get)
}
// update session properties if present
if (!trino.getSetSessionProperties.isEmpty || !trino.getResetSessionProperties.isEmpty) {
val properties = session.getProperties.asScala.clone()
properties ++= trino.getSetSessionProperties.asScala
properties --= trino.getResetSessionProperties.asScala
builder = builder.withProperties(properties.asJava)
}
clientSession.set(builder.build())
}
private def execute(trino: StatementClient): List[List[Any]] = {
@tailrec
def getData(trino: StatementClient): (Boolean, List[List[Any]]) = {
if (trino.isRunning) {
val data = trino.currentData().getData()
trino.advance()
if (data != null) {
(true, data.asScala.toList.map(_.asScala.toList))
} else {
getData(trino)
}
} else {
Verify.verify(trino.isFinished)
val finalStatus = trino.finalStatusInfo()
if (finalStatus.getError() != null) {
throw KyuubiSQLException(
s"Query ${finalStatus.getId} failed: ${finalStatus.getError.getMessage}")
}
(false, List[List[Any]]())
}
}
Iterator.continually(getData(trino)).takeWhile(_._1).flatMap(_._2).toList
}
private def getTrinoStatementClient(sql: String): StatementClient = {
StatementClientFactory.newStatementClient(httpClient, clientSession.get, sql)
}
private def createTestClientSession(connectUrl: URI): ClientSession = {
new ClientSession(
connectUrl,
"kyuubi_test",
Optional.of("test_user"),
"kyuubi",
Optional.of("test_token_tracing"),
Set[String]().asJava,
"test_client_info",
"test_catalog",
"test_schema",
null,
ZoneId.systemDefault(),
Locale.getDefault,
Collections.emptyMap(),
Map[String, String](
"test_property_key0" -> "test_property_value0",
"test_property_key1" -> "test_propert_value1").asJava,
Map[String, String](
"test_statement_key0" -> "select 1",
"test_statement_key1" -> "select 2").asJava,
Collections.emptyMap(),
Collections.emptyMap(),
null,
new Duration(2, TimeUnit.MINUTES),
true)
}
}

View File

@ -85,7 +85,7 @@ class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper {
val metadataResp = fe.be.getResultSetMetadata(opHandle)
val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
val status = fe.be.getOperationStatus(opHandle)
val status = fe.be.getOperationStatus(opHandle, Some(0))
val uri = new URI("sfdsfsdfdsf")
val results = TrinoContext
@ -112,7 +112,7 @@ class TrinoContextSuite extends KyuubiFunSuite with RestFrontendTestHelper {
val metadataResp = fe.be.getResultSetMetadata(opHandle)
val tRowSet = fe.be.fetchResults(opHandle, FetchOrientation.FETCH_NEXT, 1000, false)
val status = fe.be.getOperationStatus(opHandle)
val status = fe.be.getOperationStatus(opHandle, Some(0))
val uri = new URI("sfdsfsdfdsf")
val results = TrinoContext

View File

@ -17,15 +17,27 @@
package org.apache.kyuubi.server.trino.api.v1
import org.apache.kyuubi.{KyuubiFunSuite, RestFrontendTestHelper}
import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols
import org.apache.kyuubi.config.KyuubiConf.FrontendProtocols.FrontendProtocol
import javax.ws.rs.client.Entity
import javax.ws.rs.core.{MediaType, Response}
import scala.collection.JavaConverters._
import io.trino.client.{QueryError, QueryResults}
import io.trino.client.ProtocolHeaders.TRINO_HEADERS
import org.apache.kyuubi.{KyuubiFunSuite, KyuubiSQLException, TrinoRestFrontendTestHelper}
import org.apache.kyuubi.operation.{OperationHandle, OperationState}
import org.apache.kyuubi.server.trino.api.TrinoContext
import org.apache.kyuubi.server.trino.api.v1.dto.Ok
import org.apache.kyuubi.session.SessionHandle
class StatementResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper {
class StatementResourceSuite extends KyuubiFunSuite with TrinoRestFrontendTestHelper {
override protected val frontendProtocols: Seq[FrontendProtocol] =
FrontendProtocols.TRINO :: Nil
case class TrinoResponse(
response: Option[Response] = None,
queryError: Option[QueryError] = None,
data: List[List[Any]] = List[List[Any]](),
isEnd: Boolean = false)
test("statement test") {
val response = webTarget.path("v1/statement/test").request().get()
@ -33,4 +45,74 @@ class StatementResourceSuite extends KyuubiFunSuite with RestFrontendTestHelper
assert(result == new Ok("trino server is running"))
}
test("statement submit for query error") {
val response = webTarget.path("v1/statement")
.request().post(Entity.entity("select a", MediaType.TEXT_PLAIN_TYPE))
val trinoResponseIter = Iterator.iterate(TrinoResponse(response = Option(response)))(getData)
val isErr = trinoResponseIter.takeWhile(_.isEnd == false).exists { t =>
t.queryError != None && t.response == None
}
assert(isErr == true)
}
test("statement submit and get result") {
val response = webTarget.path("v1/statement")
.request().post(Entity.entity("select 1", MediaType.TEXT_PLAIN_TYPE))
val trinoResponseIter = Iterator.iterate(TrinoResponse(response = Option(response)))(getData)
val dataSet = trinoResponseIter
.takeWhile(_.isEnd == false)
.map(_.data)
.flatten.toList
assert(dataSet == List(List(1)))
}
test("query cancel") {
val response = webTarget.path("v1/statement")
.request().post(Entity.entity("select 1", MediaType.TEXT_PLAIN_TYPE))
val qr = response.readEntity(classOf[QueryResults])
val sessionManager = fe.be.sessionManager
val sessionHandle =
response.getStringHeaders.get(TRINO_HEADERS.responseSetSession).asScala
.map(_.split("="))
.find {
case Array("sessionId", _) => true
}
.map {
case Array(_, value) => SessionHandle.fromUUID(TrinoContext.urlDecode(value))
}.get
sessionManager.getSession(sessionHandle)
val operationHandle = OperationHandle(qr.getId)
val operation = sessionManager.operationManager.getOperation(operationHandle)
assert(response.getStatus == 200)
val path = qr.getNextUri.getPath
val nextResponse = webTarget.path(path).request().header(
TRINO_HEADERS.requestSession(),
s"sessionId=${TrinoContext.urlEncode(sessionHandle.identifier.toString)}").delete()
assert(nextResponse.getStatus == 204)
assert(operation.getStatus.state == OperationState.CLOSED)
val exception = intercept[KyuubiSQLException](sessionManager.getSession(sessionHandle))
assert(exception.getMessage === s"Invalid $sessionHandle")
}
private def getData(current: TrinoResponse): TrinoResponse = {
current.response.map { response =>
assert(response.getStatus == 200)
val qr = response.readEntity(classOf[QueryResults])
val nextData = Option(qr.getData)
.map(_.asScala.toList.map(_.asScala.toList))
.getOrElse(List[List[Any]]())
val nextResponse = Option(qr.getNextUri).map {
uri =>
val path = uri.getPath
val headers = response.getHeaders
webTarget.path(path).request().headers(headers).get()
}
TrinoResponse(nextResponse, Option(qr.getError), nextData)
}.getOrElse(TrinoResponse(isEnd = true))
}
}

View File

@ -778,6 +778,12 @@
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.jaxrs</groupId>
<artifactId>jackson-jaxrs-base</artifactId>