diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/GetSchemas.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/GetSchemas.scala new file mode 100644 index 000000000..b4547ec83 --- /dev/null +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/GetSchemas.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kyuubi.engine.trino.operation + +import scala.collection.mutable.ArrayBuffer + +import org.apache.commons.lang3.StringUtils + +import org.apache.kyuubi.engine.trino.TrinoStatement +import org.apache.kyuubi.operation.IterableFetchIterator +import org.apache.kyuubi.operation.OperationType +import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant.TABLE_CATALOG +import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant.TABLE_SCHEM +import org.apache.kyuubi.session.Session + +class GetSchemas(session: Session, catalogName: String, schemaPattern: String) + extends TrinoOperation(OperationType.GET_SCHEMAS, session) { + + private val SEARCH_STRING_ESCAPE: String = "\\" + + override protected def runInternal(): Unit = { + val query = new StringBuilder("SELECT TABLE_SCHEM, TABLE_CATALOG FROM system.jdbc.schemas") + + val filters = ArrayBuffer[String]() + if (StringUtils.isNotEmpty(catalogName)) { + filters += s"$TABLE_CATALOG = '$catalogName'" + } + if (StringUtils.isNotEmpty(schemaPattern)) { + filters += s"$TABLE_SCHEM LIKE '$schemaPattern' ESCAPE '$SEARCH_STRING_ESCAPE'" + } + + if (filters.nonEmpty) { + query.append(" WHERE ") + query.append(filters.mkString(" AND ")) + } + + try { + val trinoStatement = TrinoStatement( + trinoContext, + session.sessionManager.getConf, + query.toString) + schema = trinoStatement.getColumns + val resultSet = trinoStatement.execute() + iter = new IterableFetchIterator(resultSet) + } catch onError() + } +} diff --git a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationManager.scala b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationManager.scala index 8f1505cbb..939fbe3ec 100644 --- a/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationManager.scala +++ b/externals/kyuubi-trino-engine/src/main/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationManager.scala @@ -47,7 +47,10 @@ class TrinoOperationManager extends OperationManager("TrinoOperationManager") { override def newGetSchemasOperation( session: Session, catalog: String, - schema: String): Operation = null + schema: String): Operation = { + val op = new GetSchemas(session, catalog, schema) + addOperation(op) + } override def newGetTablesOperation( session: Session, diff --git a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala index 7020a8202..ac56352da 100644 --- a/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala +++ b/externals/kyuubi-trino-engine/src/test/scala/org/apache/kyuubi/engine/trino/operation/TrinoOperationSuite.scala @@ -34,7 +34,7 @@ import org.apache.hive.service.rpc.thrift.TStatusCode import org.apache.kyuubi.config.KyuubiConf.ENGINE_TRINO_CONNECTION_CATALOG import org.apache.kyuubi.engine.trino.WithTrinoEngine import org.apache.kyuubi.operation.HiveJDBCTestHelper -import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant.{TABLE_CAT, TABLE_TYPE} +import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._ class TrinoOperationSuite extends WithTrinoEngine with HiveJDBCTestHelper { override def withKyuubiConf: Map[String, String] = Map( @@ -71,6 +71,72 @@ class TrinoOperationSuite extends WithTrinoEngine with HiveJDBCTestHelper { } } + test("trino - get schemas") { + case class SchemaWithCatalog(catalog: String, schema: String) + + withJdbcStatement() { statement => + statement.execute("CREATE SCHEMA IF NOT EXISTS memory.test_escape_1") + statement.execute("CREATE SCHEMA IF NOT EXISTS memory.test2escape_1") + statement.execute("CREATE SCHEMA IF NOT EXISTS memory.test_escape11") + + val meta = statement.getConnection.getMetaData + val resultSetBuffer = ArrayBuffer[SchemaWithCatalog]() + + val schemas1 = meta.getSchemas(null, null) + while (schemas1.next()) { + resultSetBuffer += + SchemaWithCatalog(schemas1.getString(TABLE_CATALOG), schemas1.getString(TABLE_SCHEM)) + } + assert(resultSetBuffer.contains(SchemaWithCatalog("memory", "information_schema"))) + assert(resultSetBuffer.contains(SchemaWithCatalog("system", "information_schema"))) + + val schemas2 = meta.getSchemas("memory", null) + resultSetBuffer.clear() + while (schemas2.next()) { + resultSetBuffer += + SchemaWithCatalog(schemas2.getString(TABLE_CATALOG), schemas2.getString(TABLE_SCHEM)) + } + assert(resultSetBuffer.contains(SchemaWithCatalog("memory", "default"))) + assert(resultSetBuffer.contains(SchemaWithCatalog("memory", "information_schema"))) + assert(!resultSetBuffer.exists(f => f.catalog == "system")) + + val schemas3 = meta.getSchemas(null, "sf_") + resultSetBuffer.clear() + while (schemas3.next()) { + resultSetBuffer += + SchemaWithCatalog(schemas3.getString(TABLE_CATALOG), schemas3.getString(TABLE_SCHEM)) + } + assert(resultSetBuffer.contains(SchemaWithCatalog("tpcds", "sf1"))) + assert(!resultSetBuffer.contains(SchemaWithCatalog("tpcds", "sf10"))) + + val schemas4 = meta.getSchemas(null, "sf%") + resultSetBuffer.clear() + while (schemas4.next()) { + resultSetBuffer += + SchemaWithCatalog(schemas4.getString(TABLE_CATALOG), schemas4.getString(TABLE_SCHEM)) + } + assert(resultSetBuffer.contains(SchemaWithCatalog("tpcds", "sf1"))) + assert(resultSetBuffer.contains(SchemaWithCatalog("tpcds", "sf10"))) + assert(resultSetBuffer.contains(SchemaWithCatalog("tpcds", "sf100"))) + assert(resultSetBuffer.contains(SchemaWithCatalog("tpcds", "sf1000"))) + + // test escape the second '_' + val schemas5 = meta.getSchemas("memory", "test_escape\\_1") + resultSetBuffer.clear() + while (schemas5.next()) { + resultSetBuffer += + SchemaWithCatalog(schemas5.getString(TABLE_CATALOG), schemas5.getString(TABLE_SCHEM)) + } + assert(resultSetBuffer.contains(SchemaWithCatalog("memory", "test_escape_1"))) + assert(resultSetBuffer.contains(SchemaWithCatalog("memory", "test2escape_1"))) + assert(!resultSetBuffer.contains(SchemaWithCatalog("memory", "test_escape11"))) + + statement.execute("DROP SCHEMA memory.test_escape_1") + statement.execute("DROP SCHEMA memory.test2escape_1") + statement.execute("DROP SCHEMA memory.test_escape11") + } + } + test("execute statement - select decimal") { withJdbcStatement() { statement => val resultSet = statement.executeQuery("SELECT DECIMAL '1.2' as col1, DECIMAL '1.23' AS col2")