[KYUUBI #310] GetColumns supports DSv2 and keeps its backward compatibility

![yaooqinn](https://badgen.net/badge/Hello/yaooqinn/green) [![Closes #358](https://badgen.net/badge/Preview/Closes%20%23358/blue)](https://github.com/yaooqinn/kyuubi/pull/358) ![351](https://badgen.net/badge/%2B/351/red) ![223](https://badgen.net/badge/-/223/green) ![3](https://badgen.net/badge/commits/3/yellow) [&#10088;?&#10089;](https://pullrequestbadge.com/?utm_medium=github&utm_source=yaooqinn&utm_campaign=badge_info)<!-- PR-BADGE: PLEASE DO NOT REMOVE THIS COMMENT -->

<!--
Thanks for sending a pull request!

Here are some tips for you:
  1. If this is your first time, please read our contributor guidelines: https://kyuubi.readthedocs.io/en/latest/community/contributions.html
  2. If the PR is related to an issue in https://github.com/yaooqinn/kyuubi/issues, add '[KYUUBI #XXXX]' in your PR title, e.g., '[KYUUBI #XXXX] Your PR title ...'.
  3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][KYUUBI #XXXX] Your PR title ...'.
-->
close #310
### _Why are the changes needed?_
<!--
Please clarify why the changes are needed. For instance,
  1. If you add a feature, you can talk about the use case of it.
  2. If you fix a bug, you can clarify why it is a bug.
-->

### _How was this patch tested?_
- [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible

- [ ] Add screenshots for manual tests if appropriate

- [x] [Run test](https://kyuubi.readthedocs.io/en/latest/tools/testing.html#running-tests) locally before making a pull request

Closes #358 from yaooqinn/310.

8cb30a4 [Kent Yao] sql wildcards to java regex
34d2c3a [Kent Yao] Merge branch 'master' into 310
d332be5 [Kent Yao] [KYUUBI #310] GetColumns supports DSv2 and keeps its backward compatibility

Authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Kent Yao <yao@apache.org>
This commit is contained in:
Kent Yao 2021-02-23 20:43:02 +08:00
parent 54c306b967
commit ee50890b00
No known key found for this signature in database
GPG Key ID: F7051850A0AF904D
14 changed files with 350 additions and 222 deletions

View File

@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType
import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant.TABLE_CAT
import org.apache.kyuubi.session.Session
@ -35,7 +35,7 @@ class GetCatalogs(spark: SparkSession, session: Session)
override protected def runInternal(): Unit = {
try {
iter = SparkShim().getCatalogs(spark).toIterator
iter = SparkCatalogShim().getCatalogs(spark).toIterator
} catch onError()
}
}

View File

@ -17,14 +17,10 @@
package org.apache.kyuubi.engine.spark.operation
import java.util.regex.Pattern
import scala.collection.mutable.ArrayBuffer
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, NullType, NumericType, ShortType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@ -46,95 +42,6 @@ class GetColumns(
s" columnPattern: $columnName]"
}
private def toJavaSQLType(typ: DataType): Int = typ match {
case NullType => java.sql.Types.NULL
case BooleanType => java.sql.Types.BOOLEAN
case ByteType => java.sql.Types.TINYINT
case ShortType => java.sql.Types.SMALLINT
case IntegerType => java.sql.Types.INTEGER
case LongType => java.sql.Types.BIGINT
case FloatType => java.sql.Types.FLOAT
case DoubleType => java.sql.Types.DOUBLE
case StringType => java.sql.Types.VARCHAR
case _: DecimalType => java.sql.Types.DECIMAL
case DateType => java.sql.Types.DATE
case TimestampType => java.sql.Types.TIMESTAMP
case BinaryType => java.sql.Types.BINARY
case _: ArrayType => java.sql.Types.ARRAY
case _: MapType => java.sql.Types.JAVA_OBJECT
case _: StructType => java.sql.Types.STRUCT
case _ => java.sql.Types.OTHER
}
/**
* For boolean, numeric and datetime types, it returns the default size of its catalyst type
* For struct type, when its elements are fixed-size, the summation of all element sizes will be
* returned.
* For array, map, string, and binaries, the column size is variable, return null as unknown.
*/
private def getColumnSize(typ: DataType): Option[Int] = typ match {
case dt @ (BooleanType | _: NumericType | DateType | TimestampType |
CalendarIntervalType | NullType) =>
Some(dt.defaultSize)
case StructType(fields) =>
val sizeArr = fields.map(f => getColumnSize(f.dataType))
if (sizeArr.contains(None)) {
None
} else {
Some(sizeArr.map(_.get).sum)
}
case _ => None
}
/**
* The number of fractional digits for this type.
* Null is returned for data types where this is not applicable.
* For boolean and integrals, the decimal digits is 0
* For floating types, we follow the IEEE Standard for Floating-Point Arithmetic (IEEE 754)
* For timestamp values, we support microseconds
* For decimals, it returns the scale
*/
private def getDecimalDigits(typ: DataType): Option[Int] = typ match {
case BooleanType | _: IntegerType => Some(0)
case FloatType => Some(7)
case DoubleType => Some(15)
case d: DecimalType => Some(d.scale)
case TimestampType => Some(6)
case _ => None
}
private def getNumPrecRadix(typ: DataType): Option[Int] = typ match {
case _: NumericType => Some(10)
case _ => None
}
private def toRow(db: String, table: String, col: StructField, pos: Int): Row = {
Row(
null, // TABLE_CAT
db, // TABLE_SCHEM
table, // TABLE_NAME
col.name, // COLUMN_NAME
toJavaSQLType(col.dataType), // DATA_TYPE
col.dataType.sql, // TYPE_NAME
getColumnSize(col.dataType).orNull, // COLUMN_SIZE
null, // BUFFER_LENGTH
getDecimalDigits(col.dataType).orNull, // DECIMAL_DIGITS
getNumPrecRadix(col.dataType).orNull, // NUM_PREC_RADIX
if (col.nullable) 1 else 0, // NULLABLE
col.getComment().getOrElse(""), // REMARKS
null, // COLUMN_DEF
null, // SQL_DATA_TYPE
null, // SQL_DATETIME_SUB
null, // CHAR_OCTET_LENGTH
pos, // ORDINAL_POSITION
"YES", // IS_NULLABLE
null, // SCOPE_CATALOG
null, // SCOPE_SCHEMA
null, // SCOPE_TABLE
null, // SOURCE_DATA_TYPE
"NO" // IS_AUTO_INCREMENT
)
}
override protected def resultSchema: StructType = {
new StructType()
.add(TABLE_CAT, "string", nullable = true, "Catalog name. NULL if not applicable")
@ -178,45 +85,12 @@ class GetColumns(
override protected def runInternal(): Unit = {
try {
val catalog = spark.sessionState.catalog
val schemaPattern = convertSchemaPattern(schemaName)
val tablePattern = convertIdentifierPattern(tableName, datanucleusFormat = true)
val columnPattern =
Pattern.compile(convertIdentifierPattern(columnName, datanucleusFormat = false))
val tables: Seq[Row] = catalog.listDatabases(schemaPattern).flatMap { db =>
val identifiers =
catalog.listTables(db, tablePattern, includeLocalTempViews = false)
catalog.getTablesByName(identifiers).flatMap { t =>
t.schema.zipWithIndex
.filter { f => columnPattern.matcher(f._1.name).matches() }
.map { case (f, i) => toRow(t.database, t.identifier.table, f, i)
}
}
}
val gviews = new ArrayBuffer[Row]()
val globalTmpDb = catalog.globalTempViewManager.database
if (StringUtils.isEmpty(schemaName) || schemaName == "*"
|| Pattern.compile(convertSchemaPattern(schemaName, false))
.matcher(globalTmpDb).matches()) {
catalog.globalTempViewManager.listViewNames(tablePattern).foreach { v =>
catalog.globalTempViewManager.get(v).foreach { plan =>
plan.schema.zipWithIndex
.filter { f => columnPattern.matcher(f._1.name).matches() }
.foreach { case (f, i) => gviews += toRow(globalTmpDb, v, f, i) }
}
}
}
val views: Seq[Row] = catalog.listLocalTempViews(tablePattern)
.map(v => (v, catalog.getTempView(v.table).get))
.flatMap { case (v, plan) =>
plan.schema.zipWithIndex
.filter(f => columnPattern.matcher(f._1.name).matches())
.map { case (f, i) => toRow(null, v.table, f, i) }
}
iter = (tables ++ gviews ++ views).toList.iterator
val schemaPattern = toJavaRegex(schemaName)
val tablePattern = toJavaRegex(tableName)
val columnPattern = toJavaRegex(columnName)
iter = SparkCatalogShim()
.getColumns(spark, catalogName, schemaPattern, tablePattern, columnPattern)
.toList.iterator
} catch {
onError()
}

View File

@ -55,8 +55,8 @@ class GetFunctions(
override protected def runInternal(): Unit = {
try {
val schemaPattern = convertSchemaPattern(schemaName)
val functionPattern = convertIdentifierPattern(functionName, datanucleusFormat = false)
val schemaPattern = toJavaRegex(schemaName)
val functionPattern = toJavaRegex(functionName)
val catalog = spark.sessionState.catalog
val a: Seq[Row] = catalog.listDatabases(schemaPattern).flatMap { db =>
catalog.listFunctions(db, functionPattern).map { case (f, _) =>

View File

@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType
import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@ -40,8 +40,8 @@ class GetSchemas(spark: SparkSession, session: Session, catalogName: String, sch
override protected def runInternal(): Unit = {
try {
val schemaPattern = convertSchemaPattern(schema, datanucleusFormat = false)
val rows = SparkShim().getSchemas(spark, catalogName, schemaPattern)
val schemaPattern = toJavaRegex(schema)
val rows = SparkCatalogShim().getSchemas(spark, catalogName, schemaPattern)
iter = rows.toList.toIterator
} catch onError()
}

View File

@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.StructType
import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@ -33,6 +33,6 @@ class GetTableTypes(spark: SparkSession, session: Session)
}
override protected def runInternal(): Unit = {
iter = SparkShim.sparkTableTypes.map(Row(_)).toList.iterator
iter = SparkCatalogShim.sparkTableTypes.map(Row(_)).toList.iterator
}
}

View File

@ -18,10 +18,9 @@
package org.apache.kyuubi.engine.spark.operation
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
import org.apache.spark.sql.types.StructType
import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.OperationType
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
import org.apache.kyuubi.session.Session
@ -61,9 +60,9 @@ class GetTables(
override protected def runInternal(): Unit = {
try {
val schemaPattern = convertSchemaPattern(schema, datanucleusFormat = false)
val tablePattern = convertIdentifierPattern(tableName, datanucleusFormat = true)
val sparkShim = SparkShim()
val schemaPattern = toJavaRegex(schema)
val tablePattern = toJavaRegex(tableName)
val sparkShim = SparkCatalogShim()
val catalogTablesAndViews =
sparkShim.getCatalogTablesOrViews(spark, catalog, schemaPattern, tablePattern, tableTypes)

View File

@ -17,6 +17,8 @@
package org.apache.kyuubi.engine.spark.operation
import java.util.regex.Pattern
import org.apache.commons.lang3.StringUtils
import org.apache.hive.service.rpc.thrift.{TRowSet, TTableSchema}
import org.apache.spark.sql.{Row, SparkSession}
@ -49,38 +51,33 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
}
}
private def convertPattern(pattern: String, datanucleusFormat: Boolean): String = {
val wStr = if (datanucleusFormat) "*" else ".*"
pattern
.replaceAll("([^\\\\])%", "$1" + wStr)
.replaceAll("\\\\%", "%")
.replaceAll("^%", wStr)
.replaceAll("([^\\\\])_", "$1.")
.replaceAll("\\\\_", "_")
.replaceAll("^_", ".")
}
/**
* Convert wildcards and escape sequence of schema pattern from JDBC format to datanucleous/regex
* The schema pattern treats empty string also as wildcard
* convert SQL 'like' pattern to a Java regular expression.
*
* Underscores (_) are converted to '.' and percent signs (%) are converted to '.*'.
*
* @param input the SQL pattern to convert
* @return the equivalent Java regular expression of the pattern
*/
protected def convertSchemaPattern(pattern: String, datanucleusFormat: Boolean = true): String = {
if (StringUtils.isEmpty(pattern) || pattern == "*") {
convertPattern("%", datanucleusFormat)
def toJavaRegex(input: String): String = {
val res = if (StringUtils.isEmpty(input) || input == "*") {
"%"
} else {
convertPattern(pattern, datanucleusFormat)
input
}
}
val in = res.toIterator
val out = new StringBuilder()
/**
* Convert wildcards and escape sequence from JDBC format to datanucleous/regex
*/
protected def convertIdentifierPattern(pattern: String, datanucleusFormat: Boolean): String = {
if (pattern == null) {
convertPattern("%", datanucleusFormat)
} else {
convertPattern(pattern, datanucleusFormat)
while (in.hasNext) {
in.next match {
case c if c == '\\' && in.hasNext => Pattern.quote(Character.toString(in.next()))
case c if c == '\\' && !in.hasNext => Pattern.quote(Character.toString(c))
case '_' => out ++= "."
case '%' => out ++= ".*"
case c => out ++= Character.toString(c)
}
}
out.result()
}
protected def onError(cancel: Boolean = false): PartialFunction[Throwable, Unit] = {

View File

@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.SparkSession
import org.apache.kyuubi.KyuubiSQLException
import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.{Operation, OperationManager}
import org.apache.kyuubi.session.{Session, SessionHandle}
@ -91,7 +91,7 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
tableTypes: java.util.List[String]): Operation = {
val spark = getSparkSession(session.handle)
val tTypes = if (tableTypes == null || tableTypes.isEmpty) {
SparkShim.sparkTableTypes
SparkCatalogShim.sparkTableTypes
} else {
tableTypes.asScala.toSet
}

View File

@ -17,15 +17,16 @@
package org.apache.kyuubi.engine.spark.shim
import java.util.regex.Pattern
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.connector.catalog.CatalogPlugin
class Shim_v2_4 extends SparkShim {
class CatalogShim_v2_4 extends SparkCatalogShim {
override def getCatalogs(spark: SparkSession): Seq[Row] = Seq(Row(""))
override protected def getCatalog(spark: SparkSession, catalog: String): CatalogPlugin = null
override def getCatalogs(spark: SparkSession): Seq[Row] = {
Seq(Row(SparkCatalogShim.SESSION_CATALOG))
}
override protected def catalogExists(spark: SparkSession, catalog: String): Boolean = false
@ -86,4 +87,74 @@ class Shim_v2_4 extends SparkShim {
spark.sessionState.catalog.listLocalTempViews(tablePattern)
}
}
override def getColumns(
spark: SparkSession,
catalogName: String,
schemaPattern: String,
tablePattern: String,
columnPattern: String): Seq[Row] = {
val cp = columnPattern.r.pattern
val byCatalog = getColumnsByCatalog(spark, catalogName, schemaPattern, tablePattern, cp)
val byGlobalTmpDB = getColumnsByGlobalTempViewManager(spark, schemaPattern, tablePattern, cp)
val byLocalTmp = getColumnsByLocalTempViews(spark, tablePattern, cp)
byCatalog ++ byGlobalTmpDB ++ byLocalTmp
}
protected def getColumnsByCatalog(
spark: SparkSession,
catalogName: String,
schemaPattern: String,
tablePattern: String,
columnPattern: Pattern): Seq[Row] = {
val catalog = spark.sessionState.catalog
val databases = catalog.listDatabases(schemaPattern)
databases.flatMap { db =>
val identifiers = catalog.listTables(db, tablePattern, includeLocalTempViews = true)
catalog.getTablesByName(identifiers).flatMap { t =>
t.schema.zipWithIndex.filter(f => columnPattern.matcher(f._1.name).matches())
.map { case (f, i) => toColumnResult(catalogName, t.database, t.identifier.table, f, i) }
}
}
}
protected def getColumnsByGlobalTempViewManager(
spark: SparkSession,
schemaPattern: String,
tablePattern: String,
columnPattern: Pattern): Seq[Row] = {
val catalog = spark.sessionState.catalog
getGlobalTempViewManager(spark, schemaPattern).flatMap { globalTmpDb =>
catalog.globalTempViewManager.listViewNames(tablePattern).flatMap { v =>
catalog.globalTempViewManager.get(v).map { plan =>
plan.schema.zipWithIndex.filter(f => columnPattern.matcher(f._1.name).matches())
.map { case (f, i) =>
toColumnResult(SparkCatalogShim.SESSION_CATALOG, globalTmpDb, v, f, i)
}
}
}.flatten
}
}
protected def getColumnsByLocalTempViews(
spark: SparkSession,
tablePattern: String,
columnPattern: Pattern): Seq[Row] = {
val catalog = spark.sessionState.catalog
catalog.listLocalTempViews(tablePattern)
.map(v => (v, catalog.getTempView(v.table).get))
.flatMap { case (v, plan) =>
plan.schema.zipWithIndex
.filter(f => columnPattern.matcher(f._1.name).matches())
.map { case (f, i) =>
toColumnResult(SparkCatalogShim.SESSION_CATALOG, null, v.table, f, i)
}
}
}
}

View File

@ -17,10 +17,14 @@
package org.apache.kyuubi.engine.spark.shim
import java.util.regex.Pattern
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.connector.catalog.{CatalogExtension, CatalogPlugin, SupportsNamespaces, TableCatalog}
class Shim_v3_0 extends Shim_v2_4 {
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim.SESSION_CATALOG
class CatalogShim_v3_0 extends CatalogShim_v2_4 {
override def getCatalogs(spark: SparkSession): Seq[Row] = {
@ -37,9 +41,9 @@ class Shim_v3_0 extends Shim_v2_4 {
(catalogs.keys ++: defaults).distinct.map(Row(_))
}
override def getCatalog(spark: SparkSession, catalogName: String): CatalogPlugin = {
private def getCatalog(spark: SparkSession, catalogName: String): CatalogPlugin = {
val catalogManager = spark.sessionState.catalogManager
if (catalogName == null) {
if (catalogName == null || catalogName.isEmpty) {
catalogManager.currentCatalog
} else {
catalogManager.catalog(catalogName)
@ -124,17 +128,17 @@ class Shim_v3_0 extends Shim_v2_4 {
tablePattern: String,
tableTypes: Set[String]): Seq[Row] = {
val catalog = getCatalog(spark, catalogName)
val schemas = listNamespacesWithPattern(catalog, schemaPattern)
val namespaces = listNamespacesWithPattern(catalog, schemaPattern)
catalog match {
case catalog if catalog.name() == SESSION_CATALOG =>
case builtin if builtin.name() == SESSION_CATALOG =>
super.getCatalogTablesOrViews(
spark, SESSION_CATALOG, schemaPattern, tablePattern, tableTypes)
case ce: CatalogExtension =>
super.getCatalogTablesOrViews(spark, ce.name(), schemaPattern, tablePattern, tableTypes)
case tc: TableCatalog =>
schemas.flatMap { ns =>
tc.listTables(ns)
}.map { ident =>
val tp = tablePattern.r.pattern
val identifiers = namespaces.flatMap { ns =>
tc.listTables(ns).filter(i => tp.matcher(quoteIfNeeded(i.name())).matches())
}
identifiers.map { ident =>
val table = tc.loadTable(ident)
// TODO: restore view type for session catalog
val comment = table.properties().getOrDefault(TableCatalog.PROP_COMMENT, "")
@ -145,4 +149,34 @@ class Shim_v3_0 extends Shim_v2_4 {
case _ => Seq.empty[Row]
}
}
override protected def getColumnsByCatalog(
spark: SparkSession,
catalogName: String,
schemaPattern: String,
tablePattern: String,
columnPattern: Pattern): Seq[Row] = {
val catalog = getCatalog(spark, catalogName)
catalog match {
case builtin if builtin.name() == SESSION_CATALOG =>
super.getColumnsByCatalog(
spark, SESSION_CATALOG, schemaPattern, tablePattern, columnPattern)
case tc: TableCatalog =>
val namespaces = listNamespacesWithPattern(catalog, schemaPattern)
val tp = tablePattern.r.pattern
val identifiers = namespaces.flatMap { ns =>
tc.listTables(ns).filter(i => tp.matcher(quoteIfNeeded(i.name())).matches())
}
identifiers.flatMap { ident =>
val table = tc.loadTable(ident)
val namespace = ident.namespace().map(quoteIfNeeded).mkString(".")
val tableName = quoteIfNeeded(ident.name())
table.schema.zipWithIndex.filter(f => columnPattern.matcher(f._1.name).matches())
.map { case (f, i) => toColumnResult(tc.name(), namespace, tableName, f, i) }
}
}
}
}

View File

@ -19,14 +19,15 @@ package org.apache.kyuubi.engine.spark.shim
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.connector.catalog.CatalogPlugin
import org.apache.spark.sql.types.StructField
import org.apache.kyuubi.{Logging, Utils}
import org.apache.kyuubi.schema.SchemaHelper
/**
* A shim that defines the interface interact with Spark's catalogs
*/
trait SparkShim extends Logging {
trait SparkCatalogShim extends Logging {
/////////////////////////////////////////////////////////////////////////////////////////////////
// Catalog //
@ -37,8 +38,6 @@ trait SparkShim extends Logging {
*/
def getCatalogs(spark: SparkSession): Seq[Row]
protected def getCatalog(spark: SparkSession, catalog: String): CatalogPlugin
protected def catalogExists(spark: SparkSession, catalog: String): Boolean
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -74,6 +73,47 @@ trait SparkShim extends Logging {
schemaPattern: String,
tablePattern: String): Seq[TableIdentifier]
/////////////////////////////////////////////////////////////////////////////////////////////////
// Columns //
/////////////////////////////////////////////////////////////////////////////////////////////////
def getColumns(
spark: SparkSession,
catalogName: String,
schemaPattern: String,
tablePattern: String,
columnPattern: String): Seq[Row]
protected def toColumnResult(
catalog: String, db: String, table: String, col: StructField, pos: Int): Row = {
Row(
catalog, // TABLE_CAT
db, // TABLE_SCHEM
table, // TABLE_NAME
col.name, // COLUMN_NAME
SchemaHelper.toJavaSQLType(col.dataType), // DATA_TYPE
col.dataType.sql, // TYPE_NAME
SchemaHelper.getColumnSize(col.dataType).orNull, // COLUMN_SIZE
null, // BUFFER_LENGTH
SchemaHelper.getDecimalDigits(col.dataType).orNull, // DECIMAL_DIGITS
SchemaHelper.getNumPrecRadix(col.dataType).orNull, // NUM_PREC_RADIX
if (col.nullable) 1 else 0, // NULLABLE
col.getComment().getOrElse(""), // REMARKS
null, // COLUMN_DEF
null, // SQL_DATA_TYPE
null, // SQL_DATETIME_SUB
null, // CHAR_OCTET_LENGTH
pos, // ORDINAL_POSITION
"YES", // IS_NULLABLE
null, // SCOPE_CATALOG
null, // SCOPE_SCHEMA
null, // SCOPE_TABLE
null, // SOURCE_DATA_TYPE
"NO" // IS_AUTO_INCREMENT
)
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Miscellaneous //
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -110,19 +150,20 @@ trait SparkShim extends Logging {
tableTypes.exists(typ.equalsIgnoreCase)
}
protected val SESSION_CATALOG: String = "spark_catalog"
}
object SparkShim {
def apply(): SparkShim = {
object SparkCatalogShim {
def apply(): SparkCatalogShim = {
val runtimeSparkVer = org.apache.spark.SPARK_VERSION
val (major, minor) = Utils.majorMinorVersion(runtimeSparkVer)
(major, minor) match {
case (3, _) => new Shim_v3_0
case (2, _) => new Shim_v2_4
case (3, _) => new CatalogShim_v3_0
case (2, _) => new CatalogShim_v2_4
case _ => throw new IllegalArgumentException(s"Not Support spark version $runtimeSparkVer")
}
}
val SESSION_CATALOG: String = "spark_catalog"
val sparkTableTypes = Set("VIEW", "TABLE")
}

View File

@ -85,4 +85,68 @@ object SchemaHelper {
}
tTableSchema
}
def toJavaSQLType(sparkType: DataType): Int = sparkType match {
case NullType => java.sql.Types.NULL
case BooleanType => java.sql.Types.BOOLEAN
case ByteType => java.sql.Types.TINYINT
case ShortType => java.sql.Types.SMALLINT
case IntegerType => java.sql.Types.INTEGER
case LongType => java.sql.Types.BIGINT
case FloatType => java.sql.Types.FLOAT
case DoubleType => java.sql.Types.DOUBLE
case StringType => java.sql.Types.VARCHAR
case _: DecimalType => java.sql.Types.DECIMAL
case DateType => java.sql.Types.DATE
case TimestampType => java.sql.Types.TIMESTAMP
case BinaryType => java.sql.Types.BINARY
case _: ArrayType => java.sql.Types.ARRAY
case _: MapType => java.sql.Types.JAVA_OBJECT
case _: StructType => java.sql.Types.STRUCT
case _ => java.sql.Types.OTHER
}
/**
* For boolean, numeric and datetime types, it returns the default size of its catalyst type
* For struct type, when its elements are fixed-size, the summation of all element sizes will be
* returned.
* For array, map, string, and binaries, the column size is variable, return null as unknown.
*/
def getColumnSize(sparkType: DataType): Option[Int] = sparkType match {
case dt @ (BooleanType | _: NumericType | DateType | TimestampType |
CalendarIntervalType | NullType) =>
Some(dt.defaultSize)
case StructType(fields) =>
val sizeArr = fields.map(f => getColumnSize(f.dataType))
if (sizeArr.contains(None)) {
None
} else {
Some(sizeArr.map(_.get).sum)
}
case _ => None
}
/**
* The number of fractional digits for this type.
* Null is returned for data types where this is not applicable.
* For boolean and integrals, the decimal digits is 0
* For floating types, we follow the IEEE Standard for Floating-Point Arithmetic (IEEE 754)
* For timestamp values, we support microseconds
* For decimals, it returns the scale
*/
def getDecimalDigits(sparkType: DataType): Option[Int] = sparkType match {
case BooleanType | _: IntegerType => Some(0)
case FloatType => Some(7)
case DoubleType => Some(15)
case d: DecimalType => Some(d.scale)
case TimestampType => Some(6)
case _ => None
}
def getNumPrecRadix(typ: DataType): Option[Int] = typ match {
case _: NumericType => Some(10)
case _ => None
}
}

View File

@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
import org.apache.kyuubi.Utils
import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
import org.apache.kyuubi.engine.spark.shim.SparkShim
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
import org.apache.kyuubi.operation.JDBCTests
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
@ -44,7 +44,7 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
withJdbcStatement() { statement =>
val meta = statement.getConnection.getMetaData
val types = meta.getTableTypes
val expected = SparkShim.sparkTableTypes.toIterator
val expected = SparkCatalogShim.sparkTableTypes.toIterator
while (types.next()) {
assert(types.getString(TABLE_TYPE) === expected.next())
}
@ -98,7 +98,7 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
var pos = 0
while (rowSet.next()) {
assert(rowSet.getString(TABLE_CAT) === null)
assert(rowSet.getString(TABLE_CAT) === SparkCatalogShim.SESSION_CATALOG)
assert(rowSet.getString(TABLE_SCHEM) === dftSchema)
assert(rowSet.getString(TABLE_NAME) === tableName)
assert(rowSet.getString(COLUMN_NAME) === schema(pos).name)
@ -142,9 +142,6 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
val rowSet = metaData.getColumns(null, "*", "not_exist", "not_exist")
assert(!rowSet.next())
val e1 = intercept[HiveSQLException](metaData.getColumns(null, null, null, "*"))
assert(e1.getCause.getMessage contains "Dangling meta character '*' near index 0\n*\n^")
}
}
@ -157,7 +154,7 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
val data = statement.getConnection.getMetaData
val rowSet = data.getColumns("", "global_temp", viewName, null)
while (rowSet.next()) {
assert(rowSet.getString(TABLE_CAT) === null)
assert(rowSet.getString(TABLE_CAT) === SparkCatalogShim.SESSION_CATALOG)
assert(rowSet.getString(TABLE_SCHEM) === "global_temp")
assert(rowSet.getString(TABLE_NAME) === viewName)
assert(rowSet.getString(COLUMN_NAME) === "i")
@ -184,20 +181,20 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
val data = statement.getConnection.getMetaData
val rowSet = data.getColumns("", "global_temp", viewName, "n")
while (rowSet.next()) {
assert(rowSet.getString("TABLE_CAT") === null)
assert(rowSet.getString("TABLE_SCHEM") === "global_temp")
assert(rowSet.getString("TABLE_NAME") === viewName)
assert(rowSet.getString("COLUMN_NAME") === "n")
assert(rowSet.getInt("DATA_TYPE") === java.sql.Types.NULL)
assert(rowSet.getString("TYPE_NAME").equalsIgnoreCase(NullType.sql))
assert(rowSet.getInt("COLUMN_SIZE") === 1)
assert(rowSet.getInt("DECIMAL_DIGITS") === 0)
assert(rowSet.getInt("NUM_PREC_RADIX") === 0)
assert(rowSet.getInt("NULLABLE") === 1)
assert(rowSet.getString("REMARKS") === "")
assert(rowSet.getInt("ORDINAL_POSITION") === 0)
assert(rowSet.getString("IS_NULLABLE") === "YES")
assert(rowSet.getString("IS_AUTO_INCREMENT") === "NO")
assert(rowSet.getString(TABLE_CAT) === SparkCatalogShim.SESSION_CATALOG)
assert(rowSet.getString(TABLE_SCHEM) === "global_temp")
assert(rowSet.getString(TABLE_NAME) === viewName)
assert(rowSet.getString(COLUMN_NAME) === "n")
assert(rowSet.getInt(DATA_TYPE) === java.sql.Types.NULL)
assert(rowSet.getString(TYPE_NAME).equalsIgnoreCase(NullType.sql))
assert(rowSet.getInt(COLUMN_SIZE) === 1)
assert(rowSet.getInt(DECIMAL_DIGITS) === 0)
assert(rowSet.getInt(NUM_PREC_RADIX) === 0)
assert(rowSet.getInt(NULLABLE) === 1)
assert(rowSet.getString(REMARKS) === "")
assert(rowSet.getInt(ORDINAL_POSITION) === 0)
assert(rowSet.getString(IS_NULLABLE) === "YES")
assert(rowSet.getString(IS_AUTO_INCREMENT) === "NO")
}
}
}

View File

@ -152,4 +152,55 @@ trait BasicIcebergJDBCTests extends JDBCTestUtils {
}
}
test("get columns operation") {
val dataTypes = Seq("boolean", "int", "bigint",
"float", "double", "decimal(38,20)", "decimal(10,2)",
"string", "array<bigint>", "array<string>", "map<int, bigint>",
"date", "timestamp", "struct<`X`: bigint, `Y`: double>", "binary", "struct<`X`: string>")
val cols = dataTypes.zipWithIndex.map { case (dt, idx) => s"c$idx" -> dt }
val (colNames, _) = cols.unzip
val tableName = "iceberg_get_col_operation"
val ddl =
s"""
|CREATE TABLE IF NOT EXISTS $catalog.$dftSchema.$tableName (
| ${cols.map { case (cn, dt) => cn + " " + dt }.mkString(",\n")}
|)
|USING iceberg""".stripMargin
withJdbcStatement(tableName) { statement =>
statement.execute(ddl)
val metaData = statement.getConnection.getMetaData
Seq("%", null, ".*", "c.*") foreach { columnPattern =>
val rowSet = metaData.getColumns(catalog, dftSchema, tableName, columnPattern)
import java.sql.Types._
val expectedJavaTypes = Seq(BOOLEAN, INTEGER, BIGINT, FLOAT, DOUBLE,
DECIMAL, DECIMAL, VARCHAR, ARRAY, ARRAY, JAVA_OBJECT, DATE, TIMESTAMP, STRUCT, BINARY,
STRUCT)
var pos = 0
while (rowSet.next()) {
assert(rowSet.getString(TABLE_CAT) === catalog)
assert(rowSet.getString(TABLE_SCHEM) === dftSchema)
assert(rowSet.getString(TABLE_NAME) === tableName)
assert(rowSet.getString(COLUMN_NAME) === colNames(pos))
assert(rowSet.getInt(DATA_TYPE) === expectedJavaTypes(pos))
assert(rowSet.getString(TYPE_NAME) equalsIgnoreCase dataTypes(pos))
pos += 1
}
assert(pos === dataTypes.size, "all columns should have been verified")
}
val rowSet = metaData.getColumns(catalog, "*", "not_exist", "not_exist")
assert(!rowSet.next())
}
}
}