diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetCatalogs.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetCatalogs.scala
index eec959e08..9d1414715 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetCatalogs.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetCatalogs.scala
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types.StructType
 
-import org.apache.kyuubi.engine.spark.shim.SparkShim
+import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant.TABLE_CAT
 import org.apache.kyuubi.session.Session
@@ -35,7 +35,7 @@ class GetCatalogs(spark: SparkSession, session: Session)
 
   override protected def runInternal(): Unit = {
     try {
-      iter = SparkShim().getCatalogs(spark).toIterator
+      iter = SparkCatalogShim().getCatalogs(spark).toIterator
     } catch onError()
   }
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetColumns.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetColumns.scala
index c2787609d..37cc2cae3 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetColumns.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetColumns.scala
@@ -17,14 +17,10 @@
 
 package org.apache.kyuubi.engine.spark.operation
 
-import java.util.regex.Pattern
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.commons.lang3.StringUtils
-import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, NullType, NumericType, ShortType, StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types._
 
+import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
 import org.apache.kyuubi.session.Session
@@ -46,95 +42,6 @@ class GetColumns(
       s" columnPattern: $columnName]"
   }
 
-  private def toJavaSQLType(typ: DataType): Int = typ match {
-    case NullType => java.sql.Types.NULL
-    case BooleanType => java.sql.Types.BOOLEAN
-    case ByteType => java.sql.Types.TINYINT
-    case ShortType => java.sql.Types.SMALLINT
-    case IntegerType => java.sql.Types.INTEGER
-    case LongType => java.sql.Types.BIGINT
-    case FloatType => java.sql.Types.FLOAT
-    case DoubleType => java.sql.Types.DOUBLE
-    case StringType => java.sql.Types.VARCHAR
-    case _: DecimalType => java.sql.Types.DECIMAL
-    case DateType => java.sql.Types.DATE
-    case TimestampType => java.sql.Types.TIMESTAMP
-    case BinaryType => java.sql.Types.BINARY
-    case _: ArrayType => java.sql.Types.ARRAY
-    case _: MapType => java.sql.Types.JAVA_OBJECT
-    case _: StructType => java.sql.Types.STRUCT
-    case _ => java.sql.Types.OTHER
-  }
-
-  /**
-   * For boolean, numeric and datetime types, it returns the default size of its catalyst type
-   * For struct type, when its elements are fixed-size, the summation of all element sizes will be
-   * returned.
-   * For array, map, string, and binaries, the column size is variable, return null as unknown.
-   */
-  private def getColumnSize(typ: DataType): Option[Int] = typ match {
-    case dt @ (BooleanType | _: NumericType | DateType | TimestampType |
-        CalendarIntervalType | NullType) =>
-      Some(dt.defaultSize)
-    case StructType(fields) =>
-      val sizeArr = fields.map(f => getColumnSize(f.dataType))
-      if (sizeArr.contains(None)) {
-        None
-      } else {
-        Some(sizeArr.map(_.get).sum)
-      }
-    case _ => None
-  }
-
-  /**
-   * The number of fractional digits for this type.
-   * Null is returned for data types where this is not applicable.
-   * For boolean and integrals, the decimal digits is 0
-   * For floating types, we follow the IEEE Standard for Floating-Point Arithmetic (IEEE 754)
-   * For timestamp values, we support microseconds
-   * For decimals, it returns the scale
-   */
-  private def getDecimalDigits(typ: DataType): Option[Int] = typ match {
-    case BooleanType | _: IntegerType => Some(0)
-    case FloatType => Some(7)
-    case DoubleType => Some(15)
-    case d: DecimalType => Some(d.scale)
-    case TimestampType => Some(6)
-    case _ => None
-  }
-
-  private def getNumPrecRadix(typ: DataType): Option[Int] = typ match {
-    case _: NumericType => Some(10)
-    case _ => None
-  }
-
-  private def toRow(db: String, table: String, col: StructField, pos: Int): Row = {
-    Row(
-      null, // TABLE_CAT
-      db, // TABLE_SCHEM
-      table, // TABLE_NAME
-      col.name, // COLUMN_NAME
-      toJavaSQLType(col.dataType), // DATA_TYPE
-      col.dataType.sql, // TYPE_NAME
-      getColumnSize(col.dataType).orNull, // COLUMN_SIZE
-      null, // BUFFER_LENGTH
-      getDecimalDigits(col.dataType).orNull, // DECIMAL_DIGITS
-      getNumPrecRadix(col.dataType).orNull, // NUM_PREC_RADIX
-      if (col.nullable) 1 else 0, // NULLABLE
-      col.getComment().getOrElse(""), // REMARKS
-      null, // COLUMN_DEF
-      null, // SQL_DATA_TYPE
-      null, // SQL_DATETIME_SUB
-      null, // CHAR_OCTET_LENGTH
-      pos, // ORDINAL_POSITION
-      "YES", // IS_NULLABLE
-      null, // SCOPE_CATALOG
-      null, // SCOPE_SCHEMA
-      null, // SCOPE_TABLE
-      null, // SOURCE_DATA_TYPE
-      "NO" // IS_AUTO_INCREMENT
-    )
-  }
   override protected def resultSchema: StructType = {
     new StructType()
       .add(TABLE_CAT, "string", nullable = true, "Catalog name. NULL if not applicable")
@@ -178,45 +85,12 @@ class GetColumns(
 
   override protected def runInternal(): Unit = {
     try {
-      val catalog = spark.sessionState.catalog
-      val schemaPattern = convertSchemaPattern(schemaName)
-      val tablePattern = convertIdentifierPattern(tableName, datanucleusFormat = true)
-      val columnPattern =
-        Pattern.compile(convertIdentifierPattern(columnName, datanucleusFormat = false))
-      val tables: Seq[Row] = catalog.listDatabases(schemaPattern).flatMap { db =>
-        val identifiers =
-          catalog.listTables(db, tablePattern, includeLocalTempViews = false)
-        catalog.getTablesByName(identifiers).flatMap { t =>
-          t.schema.zipWithIndex
-            .filter { f => columnPattern.matcher(f._1.name).matches() }
-            .map { case (f, i) => toRow(t.database, t.identifier.table, f, i)
-            }
-        }
-      }
-
-      val gviews = new ArrayBuffer[Row]()
-      val globalTmpDb = catalog.globalTempViewManager.database
-      if (StringUtils.isEmpty(schemaName) || schemaName == "*"
-        || Pattern.compile(convertSchemaPattern(schemaName, false))
-          .matcher(globalTmpDb).matches()) {
-        catalog.globalTempViewManager.listViewNames(tablePattern).foreach { v =>
-          catalog.globalTempViewManager.get(v).foreach { plan =>
-            plan.schema.zipWithIndex
-              .filter { f => columnPattern.matcher(f._1.name).matches() }
-              .foreach { case (f, i) => gviews += toRow(globalTmpDb, v, f, i) }
-          }
-        }
-      }
-
-      val views: Seq[Row] = catalog.listLocalTempViews(tablePattern)
-        .map(v => (v, catalog.getTempView(v.table).get))
-        .flatMap { case (v, plan) =>
-          plan.schema.zipWithIndex
-            .filter(f => columnPattern.matcher(f._1.name).matches())
-            .map { case (f, i) => toRow(null, v.table, f, i) }
-        }
-
-      iter = (tables ++ gviews ++ views).toList.iterator
+      val schemaPattern = toJavaRegex(schemaName)
+      val tablePattern = toJavaRegex(tableName)
+      val columnPattern = toJavaRegex(columnName)
+      iter = SparkCatalogShim()
+        .getColumns(spark, catalogName, schemaPattern, tablePattern, columnPattern)
+        .toList.iterator
     } catch {
       onError()
     }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetFunctions.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetFunctions.scala
index a65afd07d..c821f41fc 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetFunctions.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetFunctions.scala
@@ -55,8 +55,8 @@ class GetFunctions(
 
   override protected def runInternal(): Unit = {
     try {
-      val schemaPattern = convertSchemaPattern(schemaName)
-      val functionPattern = convertIdentifierPattern(functionName, datanucleusFormat = false)
+      val schemaPattern = toJavaRegex(schemaName)
+      val functionPattern = toJavaRegex(functionName)
       val catalog = spark.sessionState.catalog
       val a: Seq[Row] = catalog.listDatabases(schemaPattern).flatMap { db =>
         catalog.listFunctions(db, functionPattern).map { case (f, _) =>
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetSchemas.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetSchemas.scala
index e6e19a45d..7b2467969 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetSchemas.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetSchemas.scala
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types.StructType
 
-import org.apache.kyuubi.engine.spark.shim.SparkShim
+import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
 import org.apache.kyuubi.session.Session
@@ -40,8 +40,8 @@ class GetSchemas(spark: SparkSession, session: Session, catalogName: String, sch
 
   override protected def runInternal(): Unit = {
     try {
-      val schemaPattern = convertSchemaPattern(schema, datanucleusFormat = false)
-      val rows = SparkShim().getSchemas(spark, catalogName, schemaPattern)
+      val schemaPattern = toJavaRegex(schema)
+      val rows = SparkCatalogShim().getSchemas(spark, catalogName, schemaPattern)
       iter = rows.toList.toIterator
     } catch onError()
   }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTableTypes.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTableTypes.scala
index 6d746bdb7..9170234aa 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTableTypes.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTableTypes.scala
@@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.types.StructType
 
-import org.apache.kyuubi.engine.spark.shim.SparkShim
+import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
 import org.apache.kyuubi.session.Session
@@ -33,6 +33,6 @@ class GetTableTypes(spark: SparkSession, session: Session)
   }
 
   override protected def runInternal(): Unit = {
-    iter = SparkShim.sparkTableTypes.map(Row(_)).toList.iterator
+    iter = SparkCatalogShim.sparkTableTypes.map(Row(_)).toList.iterator
   }
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTables.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTables.scala
index b9dc7f6ea..4d252f38a 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTables.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/GetTables.scala
@@ -18,10 +18,9 @@
 package org.apache.kyuubi.engine.spark.operation
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.catalog.CatalogTableType
 import org.apache.spark.sql.types.StructType
 
-import org.apache.kyuubi.engine.spark.shim.SparkShim
+import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.OperationType
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
 import org.apache.kyuubi.session.Session
@@ -61,9 +60,9 @@ class GetTables(
 
   override protected def runInternal(): Unit = {
     try {
-      val schemaPattern = convertSchemaPattern(schema, datanucleusFormat = false)
-      val tablePattern = convertIdentifierPattern(tableName, datanucleusFormat = true)
-      val sparkShim = SparkShim()
+      val schemaPattern = toJavaRegex(schema)
+      val tablePattern = toJavaRegex(tableName)
+      val sparkShim = SparkCatalogShim()
       val catalogTablesAndViews =
         sparkShim.getCatalogTablesOrViews(spark, catalog, schemaPattern, tablePattern, tableTypes)
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
index 19078de7b..dade5f8bf 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkOperation.scala
@@ -17,6 +17,8 @@
 
 package org.apache.kyuubi.engine.spark.operation
 
+import java.util.regex.Pattern
+
 import org.apache.commons.lang3.StringUtils
 import org.apache.hive.service.rpc.thrift.{TRowSet, TTableSchema}
 import org.apache.spark.sql.{Row, SparkSession}
@@ -49,38 +51,33 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
     }
   }
 
-  private def convertPattern(pattern: String, datanucleusFormat: Boolean): String = {
-    val wStr = if (datanucleusFormat) "*" else ".*"
-    pattern
-      .replaceAll("([^\\\\])%", "$1" + wStr)
-      .replaceAll("\\\\%", "%")
-      .replaceAll("^%", wStr)
-      .replaceAll("([^\\\\])_", "$1.")
-      .replaceAll("\\\\_", "_")
-      .replaceAll("^_", ".")
-  }
-
   /**
-   * Convert wildcards and escape sequence of schema pattern from JDBC format to datanucleous/regex
-   * The schema pattern treats empty string also as wildcard
+   * Convert a SQL 'like' pattern to a Java regular expression.
+   *
+   * Underscores (_) are converted to '.' and percent signs (%) are converted to '.*'.
+   *
+   * @param input the SQL pattern to convert
+   * @return the equivalent Java regular expression of the pattern
    */
-  protected def convertSchemaPattern(pattern: String, datanucleusFormat: Boolean = true): String = {
-    if (StringUtils.isEmpty(pattern) || pattern == "*") {
-      convertPattern("%", datanucleusFormat)
+  def toJavaRegex(input: String): String = {
+    val res = if (StringUtils.isEmpty(input) || input == "*") {
+      "%"
     } else {
-      convertPattern(pattern, datanucleusFormat)
+      input
     }
-  }
+    val in = res.toIterator
+    val out = new StringBuilder()
 
-  /**
-   * Convert wildcards and escape sequence from JDBC format to datanucleous/regex
-   */
-  protected def convertIdentifierPattern(pattern: String, datanucleusFormat: Boolean): String = {
-    if (pattern == null) {
-      convertPattern("%", datanucleusFormat)
-    } else {
-      convertPattern(pattern, datanucleusFormat)
+    while (in.hasNext) {
+      in.next match {
+        case c if c == '\\' && in.hasNext => out ++= Pattern.quote(Character.toString(in.next()))
+        case c if c == '\\' && !in.hasNext => out ++= Pattern.quote(Character.toString(c))
+        case '_' => out ++= "."
+        case '%' => out ++= ".*"
+        case c => out ++= Character.toString(c)
+      }
     }
+    out.result()
   }
 
   protected def onError(cancel: Boolean = false): PartialFunction[Throwable, Unit] = {
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
index 1b4e0d8e3..38fdc23fb 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/SparkSQLOperationManager.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.SparkSession
 
 import org.apache.kyuubi.KyuubiSQLException
-import org.apache.kyuubi.engine.spark.shim.SparkShim
+import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.{Operation, OperationManager}
 import org.apache.kyuubi.session.{Session, SessionHandle}
 
@@ -91,7 +91,7 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
       tableTypes: java.util.List[String]): Operation = {
     val spark = getSparkSession(session.handle)
     val tTypes = if (tableTypes == null || tableTypes.isEmpty) {
-      SparkShim.sparkTableTypes
+      SparkCatalogShim.sparkTableTypes
     } else {
       tableTypes.asScala.toSet
     }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/Shim_v2_4.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/CatalogShim_v2_4.scala
similarity index 53%
rename from externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/Shim_v2_4.scala
rename to externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/CatalogShim_v2_4.scala
index 260d989e8..b443f2600 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/Shim_v2_4.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/CatalogShim_v2_4.scala
@@ -17,15 +17,16 @@
 
 package org.apache.kyuubi.engine.spark.shim
 
+import java.util.regex.Pattern
+
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.connector.catalog.CatalogPlugin
 
-class Shim_v2_4 extends SparkShim {
+class CatalogShim_v2_4 extends SparkCatalogShim {
 
-  override def getCatalogs(spark: SparkSession): Seq[Row] = Seq(Row(""))
-
-  override protected def getCatalog(spark: SparkSession, catalog: String): CatalogPlugin = null
+  override def getCatalogs(spark: SparkSession): Seq[Row] = {
+    Seq(Row(SparkCatalogShim.SESSION_CATALOG))
+  }
 
   override protected def catalogExists(spark: SparkSession, catalog: String): Boolean = false
 
@@ -86,4 +87,74 @@ class Shim_v2_4 extends SparkShim {
       spark.sessionState.catalog.listLocalTempViews(tablePattern)
     }
   }
+
+  override def getColumns(
+      spark: SparkSession,
+      catalogName: String,
+      schemaPattern: String,
+      tablePattern: String,
+      columnPattern: String): Seq[Row] = {
+
+    val cp = columnPattern.r.pattern
+    val byCatalog = getColumnsByCatalog(spark, catalogName, schemaPattern, tablePattern, cp)
+    val byGlobalTmpDB = getColumnsByGlobalTempViewManager(spark, schemaPattern, tablePattern, cp)
+    val byLocalTmp = getColumnsByLocalTempViews(spark, tablePattern, cp)
+
+    byCatalog ++ byGlobalTmpDB ++ byLocalTmp
+  }
+
+  protected def getColumnsByCatalog(
+      spark: SparkSession,
+      catalogName: String,
+      schemaPattern: String,
+      tablePattern: String,
+      columnPattern: Pattern): Seq[Row] = {
+    val catalog = spark.sessionState.catalog
+
+    val databases = catalog.listDatabases(schemaPattern)
+
+    databases.flatMap { db =>
+      val identifiers = catalog.listTables(db, tablePattern, includeLocalTempViews = true)
+      catalog.getTablesByName(identifiers).flatMap { t =>
+        t.schema.zipWithIndex.filter(f => columnPattern.matcher(f._1.name).matches())
+          .map { case (f, i) => toColumnResult(catalogName, t.database, t.identifier.table, f, i) }
+      }
+    }
+  }
+
+  protected def getColumnsByGlobalTempViewManager(
+      spark: SparkSession,
+      schemaPattern: String,
+      tablePattern: String,
+      columnPattern: Pattern): Seq[Row] = {
+    val catalog = spark.sessionState.catalog
+
+    getGlobalTempViewManager(spark, schemaPattern).flatMap { globalTmpDb =>
+      catalog.globalTempViewManager.listViewNames(tablePattern).flatMap { v =>
+        catalog.globalTempViewManager.get(v).map { plan =>
+          plan.schema.zipWithIndex.filter(f => columnPattern.matcher(f._1.name).matches())
+            .map { case (f, i) =>
+              toColumnResult(SparkCatalogShim.SESSION_CATALOG, globalTmpDb, v, f, i)
+            }
+        }
+      }.flatten
+    }
+  }
+
+  protected def getColumnsByLocalTempViews(
+      spark: SparkSession,
+      tablePattern: String,
+      columnPattern: Pattern): Seq[Row] = {
+    val catalog = spark.sessionState.catalog
+
+    catalog.listLocalTempViews(tablePattern)
+      .map(v => (v, catalog.getTempView(v.table).get))
+      .flatMap { case (v, plan) =>
+        plan.schema.zipWithIndex
+          .filter(f => columnPattern.matcher(f._1.name).matches())
+          .map { case (f, i) =>
+            toColumnResult(SparkCatalogShim.SESSION_CATALOG, null, v.table, f, i)
+          }
+      }
+  }
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/Shim_v3_0.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/CatalogShim_v3_0.scala
similarity index 73%
rename from externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/Shim_v3_0.scala
rename to externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/CatalogShim_v3_0.scala
index 8fc9a213e..d759af6ef 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/Shim_v3_0.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/CatalogShim_v3_0.scala
@@ -17,10 +17,14 @@
 
 package org.apache.kyuubi.engine.spark.shim
 
+import java.util.regex.Pattern
+
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.connector.catalog.{CatalogExtension, CatalogPlugin, SupportsNamespaces, TableCatalog}
 
-class Shim_v3_0 extends Shim_v2_4 {
+import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim.SESSION_CATALOG
+
+class CatalogShim_v3_0 extends CatalogShim_v2_4 {
 
   override def getCatalogs(spark: SparkSession): Seq[Row] = {
 
@@ -37,9 +41,9 @@ class Shim_v3_0 extends Shim_v2_4 {
     (catalogs.keys ++: defaults).distinct.map(Row(_))
   }
 
-  override def getCatalog(spark: SparkSession, catalogName: String): CatalogPlugin = {
+  private def getCatalog(spark: SparkSession, catalogName: String): CatalogPlugin = {
     val catalogManager = spark.sessionState.catalogManager
-    if (catalogName == null) {
+    if (catalogName == null || catalogName.isEmpty) {
       catalogManager.currentCatalog
     } else {
       catalogManager.catalog(catalogName)
@@ -124,17 +128,17 @@ class Shim_v3_0 extends Shim_v2_4 {
       tablePattern: String,
       tableTypes: Set[String]): Seq[Row] = {
     val catalog = getCatalog(spark, catalogName)
-    val schemas = listNamespacesWithPattern(catalog, schemaPattern)
+    val namespaces = listNamespacesWithPattern(catalog, schemaPattern)
     catalog match {
-      case catalog if catalog.name() == SESSION_CATALOG =>
+      case builtin if builtin.name() == SESSION_CATALOG =>
         super.getCatalogTablesOrViews(
           spark, SESSION_CATALOG, schemaPattern, tablePattern, tableTypes)
-      case ce: CatalogExtension =>
-        super.getCatalogTablesOrViews(spark, ce.name(), schemaPattern, tablePattern, tableTypes)
       case tc: TableCatalog =>
-        schemas.flatMap { ns =>
-          tc.listTables(ns)
-        }.map { ident =>
+        val tp = tablePattern.r.pattern
+        val identifiers = namespaces.flatMap { ns =>
+          tc.listTables(ns).filter(i => tp.matcher(quoteIfNeeded(i.name())).matches())
+        }
+        identifiers.map { ident =>
           val table = tc.loadTable(ident)
           // TODO: restore view type for session catalog
           val comment = table.properties().getOrDefault(TableCatalog.PROP_COMMENT, "")
@@ -145,4 +149,34 @@ class Shim_v3_0 extends Shim_v2_4 {
       case _ => Seq.empty[Row]
     }
   }
+
+  override protected def getColumnsByCatalog(
+      spark: SparkSession,
+      catalogName: String,
+      schemaPattern: String,
+      tablePattern: String,
+      columnPattern: Pattern): Seq[Row] = {
+    val catalog = getCatalog(spark, catalogName)
+
+    catalog match {
+      case builtin if builtin.name() == SESSION_CATALOG =>
+        super.getColumnsByCatalog(
+          spark, SESSION_CATALOG, schemaPattern, tablePattern, columnPattern)
+
+      case tc: TableCatalog =>
+        val namespaces = listNamespacesWithPattern(catalog, schemaPattern)
+        val tp = tablePattern.r.pattern
+        val identifiers = namespaces.flatMap { ns =>
+          tc.listTables(ns).filter(i => tp.matcher(quoteIfNeeded(i.name())).matches())
+        }
+        identifiers.flatMap { ident =>
+          val table = tc.loadTable(ident)
+          val namespace = ident.namespace().map(quoteIfNeeded).mkString(".")
+          val tableName = quoteIfNeeded(ident.name())
+
+          table.schema.zipWithIndex.filter(f => columnPattern.matcher(f._1.name).matches())
+            .map { case (f, i) => toColumnResult(tc.name(), namespace, tableName, f, i) }
+        }
+    }
+  }
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/SparkShim.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/SparkCatalogShim.scala
similarity index 62%
rename from externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/SparkShim.scala
rename to externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/SparkCatalogShim.scala
index ee6c44ee0..6919eb773 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/SparkShim.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/shim/SparkCatalogShim.scala
@@ -19,14 +19,15 @@ package org.apache.kyuubi.engine.spark.shim
 
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.connector.catalog.CatalogPlugin
+import org.apache.spark.sql.types.StructField
 
 import org.apache.kyuubi.{Logging, Utils}
+import org.apache.kyuubi.schema.SchemaHelper
 
 /**
  * A shim that defines the interface interact with Spark's catalogs
  */
-trait SparkShim extends Logging {
+trait SparkCatalogShim extends Logging {
 
   /////////////////////////////////////////////////////////////////////////////////////////////////
   //                                           Catalog                                            //
   /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -37,8 +38,6 @@ trait SparkShim extends Logging {
    */
   def getCatalogs(spark: SparkSession): Seq[Row]
 
-  protected def getCatalog(spark: SparkSession, catalog: String): CatalogPlugin
-
   protected def catalogExists(spark: SparkSession, catalog: String): Boolean
 
   /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -74,6 +73,47 @@ trait SparkShim extends Logging {
       schemaPattern: String,
       tablePattern: String): Seq[TableIdentifier]
 
+
+  /////////////////////////////////////////////////////////////////////////////////////////////////
+  //                                           Columns                                            //
+  /////////////////////////////////////////////////////////////////////////////////////////////////
+
+  def getColumns(
+      spark: SparkSession,
+      catalogName: String,
+      schemaPattern: String,
+      tablePattern: String,
+      columnPattern: String): Seq[Row]
+
+  protected def toColumnResult(
+      catalog: String, db: String, table: String, col: StructField, pos: Int): Row = {
+    Row(
+      catalog, // TABLE_CAT
+      db, // TABLE_SCHEM
+      table, // TABLE_NAME
+      col.name, // COLUMN_NAME
+      SchemaHelper.toJavaSQLType(col.dataType), // DATA_TYPE
+      col.dataType.sql, // TYPE_NAME
+      SchemaHelper.getColumnSize(col.dataType).orNull, // COLUMN_SIZE
+      null, // BUFFER_LENGTH
+      SchemaHelper.getDecimalDigits(col.dataType).orNull, // DECIMAL_DIGITS
+      SchemaHelper.getNumPrecRadix(col.dataType).orNull, // NUM_PREC_RADIX
+      if (col.nullable) 1 else 0, // NULLABLE
+      col.getComment().getOrElse(""), // REMARKS
+      null, // COLUMN_DEF
+      null, // SQL_DATA_TYPE
+      null, // SQL_DATETIME_SUB
+      null, // CHAR_OCTET_LENGTH
+      pos, // ORDINAL_POSITION
+      "YES", // IS_NULLABLE
+      null, // SCOPE_CATALOG
+      null, // SCOPE_SCHEMA
+      null, // SCOPE_TABLE
+      null, // SOURCE_DATA_TYPE
+      "NO" // IS_AUTO_INCREMENT
+    )
+  }
+
   /////////////////////////////////////////////////////////////////////////////////////////////////
   //                                        Miscellaneous                                          //
   /////////////////////////////////////////////////////////////////////////////////////////////////
@@ -110,19 +150,20 @@ trait SparkShim extends Logging {
     tableTypes.exists(typ.equalsIgnoreCase)
   }
 
-  protected val SESSION_CATALOG: String = "spark_catalog"
-
 }
 
-object SparkShim {
-  def apply(): SparkShim = {
+object SparkCatalogShim {
+  def apply(): SparkCatalogShim = {
     val runtimeSparkVer = org.apache.spark.SPARK_VERSION
     val (major, minor) = Utils.majorMinorVersion(runtimeSparkVer)
     (major, minor) match {
-      case (3, _) => new Shim_v3_0
-      case (2, _) => new Shim_v2_4
+      case (3, _) => new CatalogShim_v3_0
+      case (2, _) => new CatalogShim_v2_4
       case _ => throw new IllegalArgumentException(s"Not Support spark version $runtimeSparkVer")
     }
   }
+
+  val SESSION_CATALOG: String = "spark_catalog"
+
+  val sparkTableTypes = Set("VIEW", "TABLE")
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/schema/SchemaHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/schema/SchemaHelper.scala
index 8ff5bb78a..13a480273 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/schema/SchemaHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/schema/SchemaHelper.scala
@@ -85,4 +85,68 @@ object SchemaHelper {
     }
     tTableSchema
   }
+
+  def toJavaSQLType(sparkType: DataType): Int = sparkType match {
+    case NullType => java.sql.Types.NULL
+    case BooleanType => java.sql.Types.BOOLEAN
+    case ByteType => java.sql.Types.TINYINT
+    case ShortType => java.sql.Types.SMALLINT
+    case IntegerType => java.sql.Types.INTEGER
+    case LongType => java.sql.Types.BIGINT
+    case FloatType => java.sql.Types.FLOAT
+    case DoubleType => java.sql.Types.DOUBLE
+    case StringType => java.sql.Types.VARCHAR
+    case _: DecimalType => java.sql.Types.DECIMAL
+    case DateType => java.sql.Types.DATE
+    case TimestampType => java.sql.Types.TIMESTAMP
+    case BinaryType => java.sql.Types.BINARY
+    case _: ArrayType => java.sql.Types.ARRAY
+    case _: MapType => java.sql.Types.JAVA_OBJECT
+    case _: StructType => java.sql.Types.STRUCT
+    case _ => java.sql.Types.OTHER
+  }
+
+
+  /**
+   * For boolean, numeric and datetime types, it returns the default size of its catalyst type
+   * For struct type, when its elements are fixed-size, the summation of all element sizes will be
+   * returned.
+   * For array, map, string, and binaries, the column size is variable, return null as unknown.
+   */
+  def getColumnSize(sparkType: DataType): Option[Int] = sparkType match {
+    case dt @ (BooleanType | _: NumericType | DateType | TimestampType |
+        CalendarIntervalType | NullType) =>
+      Some(dt.defaultSize)
+    case StructType(fields) =>
+      val sizeArr = fields.map(f => getColumnSize(f.dataType))
+      if (sizeArr.contains(None)) {
+        None
+      } else {
+        Some(sizeArr.map(_.get).sum)
+      }
+    case _ => None
+  }
+
+
+  /**
+   * The number of fractional digits for this type.
+   * Null is returned for data types where this is not applicable.
+   * For boolean and integrals, the decimal digits is 0
+   * For floating types, we follow the IEEE Standard for Floating-Point Arithmetic (IEEE 754)
+   * For timestamp values, we support microseconds
+   * For decimals, it returns the scale
+   */
+  def getDecimalDigits(sparkType: DataType): Option[Int] = sparkType match {
+    case BooleanType | _: IntegerType => Some(0)
+    case FloatType => Some(7)
+    case DoubleType => Some(15)
+    case d: DecimalType => Some(d.scale)
+    case TimestampType => Some(6)
+    case _ => None
+  }
+
+  def getNumPrecRadix(typ: DataType): Option[Int] = typ match {
+    case _: NumericType => Some(10)
+    case _ => None
+  }
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala
index d745a27ba..47b904dcc 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/kyuubi/engine/spark/operation/SparkOperationSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
 
 import org.apache.kyuubi.Utils
 import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
-import org.apache.kyuubi.engine.spark.shim.SparkShim
+import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
 import org.apache.kyuubi.operation.JDBCTests
 import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
 
@@ -44,7 +44,7 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
     withJdbcStatement() { statement =>
       val meta = statement.getConnection.getMetaData
       val types = meta.getTableTypes
-      val expected = SparkShim.sparkTableTypes.toIterator
+      val expected = SparkCatalogShim.sparkTableTypes.toIterator
       while (types.next()) {
         assert(types.getString(TABLE_TYPE) === expected.next())
       }
@@ -98,7 +98,7 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
 
     var pos = 0
     while (rowSet.next()) {
-      assert(rowSet.getString(TABLE_CAT) === null)
+      assert(rowSet.getString(TABLE_CAT) === SparkCatalogShim.SESSION_CATALOG)
       assert(rowSet.getString(TABLE_SCHEM) === dftSchema)
      assert(rowSet.getString(TABLE_NAME) === tableName)
       assert(rowSet.getString(COLUMN_NAME) === schema(pos).name)
@@ -142,9 +142,6 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
 
       val rowSet = metaData.getColumns(null, "*", "not_exist", "not_exist")
       assert(!rowSet.next())
-
-      val e1 = intercept[HiveSQLException](metaData.getColumns(null, null, null, "*"))
-      assert(e1.getCause.getMessage contains "Dangling meta character '*' near index 0\n*\n^")
     }
   }
 
@@ -157,7 +154,7 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
       val data = statement.getConnection.getMetaData
       val rowSet = data.getColumns("", "global_temp", viewName, null)
       while (rowSet.next()) {
-        assert(rowSet.getString(TABLE_CAT) === null)
+        assert(rowSet.getString(TABLE_CAT) === SparkCatalogShim.SESSION_CATALOG)
         assert(rowSet.getString(TABLE_SCHEM) === "global_temp")
         assert(rowSet.getString(TABLE_NAME) === viewName)
         assert(rowSet.getString(COLUMN_NAME) === "i")
@@ -184,20 +181,20 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
       val data = statement.getConnection.getMetaData
       val rowSet = data.getColumns("", "global_temp", viewName, "n")
       while (rowSet.next()) {
-        assert(rowSet.getString("TABLE_CAT") === null)
-        assert(rowSet.getString("TABLE_SCHEM") === "global_temp")
-        assert(rowSet.getString("TABLE_NAME") === viewName)
-        assert(rowSet.getString("COLUMN_NAME") === "n")
-        assert(rowSet.getInt("DATA_TYPE") === java.sql.Types.NULL)
-        assert(rowSet.getString("TYPE_NAME").equalsIgnoreCase(NullType.sql))
-        assert(rowSet.getInt("COLUMN_SIZE") === 1)
-        assert(rowSet.getInt("DECIMAL_DIGITS") === 0)
-        assert(rowSet.getInt("NUM_PREC_RADIX") === 0)
-        assert(rowSet.getInt("NULLABLE") === 1)
-        assert(rowSet.getString("REMARKS") === "")
-        assert(rowSet.getInt("ORDINAL_POSITION") === 0)
-        assert(rowSet.getString("IS_NULLABLE") === "YES")
-        assert(rowSet.getString("IS_AUTO_INCREMENT") === "NO")
+        assert(rowSet.getString(TABLE_CAT) === SparkCatalogShim.SESSION_CATALOG)
+        assert(rowSet.getString(TABLE_SCHEM) === "global_temp")
+        assert(rowSet.getString(TABLE_NAME) === viewName)
+        assert(rowSet.getString(COLUMN_NAME) === "n")
+        assert(rowSet.getInt(DATA_TYPE) === java.sql.Types.NULL)
+        assert(rowSet.getString(TYPE_NAME).equalsIgnoreCase(NullType.sql))
+        assert(rowSet.getInt(COLUMN_SIZE) === 1)
+        assert(rowSet.getInt(DECIMAL_DIGITS) === 0)
+        assert(rowSet.getInt(NUM_PREC_RADIX) === 0)
+        assert(rowSet.getInt(NULLABLE) === 1)
+        assert(rowSet.getString(REMARKS) === "")
+        assert(rowSet.getInt(ORDINAL_POSITION) === 0)
+        assert(rowSet.getString(IS_NULLABLE) === "YES")
+        assert(rowSet.getString(IS_AUTO_INCREMENT) === "NO")
       }
     }
   }
diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/BasicIcebergJDBCTests.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/BasicIcebergJDBCTests.scala
index 027f00cdb..602dbad6f 100644
--- a/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/BasicIcebergJDBCTests.scala
+++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/operation/BasicIcebergJDBCTests.scala
@@ -152,4 +152,55 @@ trait BasicIcebergJDBCTests extends JDBCTestUtils {
       }
     }
   }
+
+  test("get columns operation") {
+    val dataTypes = Seq("boolean", "int", "bigint",
+      "float", "double", "decimal(38,20)", "decimal(10,2)",
+      "string", "array<bigint>", "array<string>", "map<int, bigint>",
+      "date", "timestamp", "struct<`X`: bigint, `Y`: double>", "binary", "struct<`X`: string>")
+    val cols = dataTypes.zipWithIndex.map { case (dt, idx) => s"c$idx" -> dt }
+    val (colNames, _) = cols.unzip
+
+    val tableName = "iceberg_get_col_operation"
+
+    val ddl =
+      s"""
+         |CREATE TABLE IF NOT EXISTS $catalog.$dftSchema.$tableName (
+         |  ${cols.map { case (cn, dt) => cn + " " + dt }.mkString(",\n")}
+         |)
+         |USING iceberg""".stripMargin
+
+    withJdbcStatement(tableName) { statement =>
+      statement.execute(ddl)
+
+      val metaData = statement.getConnection.getMetaData
+
+      Seq("%", null, ".*", "c.*") foreach { columnPattern =>
+        val rowSet = metaData.getColumns(catalog, dftSchema, tableName, columnPattern)
+
+        import java.sql.Types._
+        val expectedJavaTypes = Seq(BOOLEAN, INTEGER, BIGINT, FLOAT, DOUBLE,
+          DECIMAL, DECIMAL, VARCHAR, ARRAY, ARRAY, JAVA_OBJECT, DATE, TIMESTAMP, STRUCT, BINARY,
+          STRUCT)
+
+        var pos = 0
+
+        while (rowSet.next()) {
+          assert(rowSet.getString(TABLE_CAT) === catalog)
+          assert(rowSet.getString(TABLE_SCHEM) === dftSchema)
+          assert(rowSet.getString(TABLE_NAME) === tableName)
+          assert(rowSet.getString(COLUMN_NAME) === colNames(pos))
+          assert(rowSet.getInt(DATA_TYPE) === expectedJavaTypes(pos))
+          assert(rowSet.getString(TYPE_NAME) equalsIgnoreCase dataTypes(pos))
+          pos += 1
+        }
+
+        assert(pos === dataTypes.size, "all columns should have been verified")
+      }
+
+      val rowSet = metaData.getColumns(catalog, "*", "not_exist", "not_exist")
+      assert(!rowSet.next())
+    }
+  }
 }
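
Note (illustrative only, not part of the patch): the change above routes every JDBC metadata pattern through the new toJavaRegex helper in SparkOperation.scala, and the shims compile the resulting string with .r.pattern before matching schema, table, and column names. The standalone Scala sketch below mirrors that conversion loop to show how LIKE wildcards map onto Java regexes; the object name LikeToRegexSketch and the main demo are made up for this example and do not exist in the codebase.

    import java.util.regex.Pattern

    // Standalone sketch mirroring the toJavaRegex logic shown in the patch.
    object LikeToRegexSketch {
      def toJavaRegex(input: String): String = {
        // Null, empty, or "*" patterns mean "match everything", as in the patch.
        val res = if (input == null || input.isEmpty || input == "*") "%" else input
        val in = res.toIterator
        val out = new StringBuilder()
        while (in.hasNext) {
          in.next match {
            // A backslash escapes the following character, which is matched literally.
            case c if c == '\\' && in.hasNext => out ++= Pattern.quote(Character.toString(in.next()))
            case c if c == '\\' && !in.hasNext => out ++= Pattern.quote(Character.toString(c))
            case '_' => out ++= "."   // LIKE single-character wildcard
            case '%' => out ++= ".*"  // LIKE multi-character wildcard
            case c => out ++= Character.toString(c)
          }
        }
        out.result()
      }

      def main(args: Array[String]): Unit = {
        val p = Pattern.compile(toJavaRegex("db\\_%")) // escaped "_" stays literal
        println(p.matcher("db_prod").matches())        // true
        println(p.matcher("dbXprod").matches())        // false
        println(toJavaRegex(null))                     // ".*" -- empty pattern matches everything
      }
    }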