[KYUUBI #310] GetColumns supports DSv2 and keeps its backward compatibility
 [](https://github.com/yaooqinn/kyuubi/pull/358)    [❨?❩](https://pullrequestbadge.com/?utm_medium=github&utm_source=yaooqinn&utm_campaign=badge_info)<!-- PR-BADGE: PLEASE DO NOT REMOVE THIS COMMENT --> <!-- Thanks for sending a pull request! Here are some tips for you: 1. If this is your first time, please read our contributor guidelines: https://kyuubi.readthedocs.io/en/latest/community/contributions.html 2. If the PR is related to an issue in https://github.com/yaooqinn/kyuubi/issues, add '[KYUUBI #XXXX]' in your PR title, e.g., '[KYUUBI #XXXX] Your PR title ...'. 3. If the PR is unfinished, add '[WIP]' in your PR title, e.g., '[WIP][KYUUBI #XXXX] Your PR title ...'. --> close #310 ### _Why are the changes needed?_ <!-- Please clarify why the changes are needed. For instance, 1. If you add a feature, you can talk about the use case of it. 2. If you fix a bug, you can clarify why it is a bug. --> ### _How was this patch tested?_ - [x] Add some test cases that check the changes thoroughly including negative and positive cases if possible - [ ] Add screenshots for manual tests if appropriate - [x] [Run test](https://kyuubi.readthedocs.io/en/latest/tools/testing.html#running-tests) locally before make a pull request Closes #358 from yaooqinn/310. 8cb30a4 [Kent Yao] sql wildcards to java regex 34d2c3a [Kent Yao] Merge branch 'master' into 310 d332be5 [Kent Yao] [KYUUBI #310] GetColumns supports DSv2 and keeps its backward compatibility Authored-by: Kent Yao <yao@apache.org> Signed-off-by: Kent Yao <yao@apache.org>
This commit is contained in:
parent
54c306b967
commit
ee50890b00
@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkShim
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
|
||||
import org.apache.kyuubi.operation.OperationType
|
||||
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant.TABLE_CAT
|
||||
import org.apache.kyuubi.session.Session
|
||||
@ -35,7 +35,7 @@ class GetCatalogs(spark: SparkSession, session: Session)
|
||||
|
||||
override protected def runInternal(): Unit = {
|
||||
try {
|
||||
iter = SparkShim().getCatalogs(spark).toIterator
|
||||
iter = SparkCatalogShim().getCatalogs(spark).toIterator
|
||||
} catch onError()
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,14 +17,10 @@
|
||||
|
||||
package org.apache.kyuubi.engine.spark.operation
|
||||
|
||||
import java.util.regex.Pattern
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.commons.lang3.StringUtils
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, NullType, NumericType, ShortType, StringType, StructField, StructType, TimestampType}
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
|
||||
import org.apache.kyuubi.operation.OperationType
|
||||
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
|
||||
import org.apache.kyuubi.session.Session
|
||||
@ -46,95 +42,6 @@ class GetColumns(
|
||||
s" columnPattern: $columnName]"
|
||||
}
|
||||
|
||||
private def toJavaSQLType(typ: DataType): Int = typ match {
|
||||
case NullType => java.sql.Types.NULL
|
||||
case BooleanType => java.sql.Types.BOOLEAN
|
||||
case ByteType => java.sql.Types.TINYINT
|
||||
case ShortType => java.sql.Types.SMALLINT
|
||||
case IntegerType => java.sql.Types.INTEGER
|
||||
case LongType => java.sql.Types.BIGINT
|
||||
case FloatType => java.sql.Types.FLOAT
|
||||
case DoubleType => java.sql.Types.DOUBLE
|
||||
case StringType => java.sql.Types.VARCHAR
|
||||
case _: DecimalType => java.sql.Types.DECIMAL
|
||||
case DateType => java.sql.Types.DATE
|
||||
case TimestampType => java.sql.Types.TIMESTAMP
|
||||
case BinaryType => java.sql.Types.BINARY
|
||||
case _: ArrayType => java.sql.Types.ARRAY
|
||||
case _: MapType => java.sql.Types.JAVA_OBJECT
|
||||
case _: StructType => java.sql.Types.STRUCT
|
||||
case _ => java.sql.Types.OTHER
|
||||
}
|
||||
|
||||
/**
|
||||
* For boolean, numeric and datetime types, it returns the default size of its catalyst type
|
||||
* For struct type, when its elements are fixed-size, the summation of all element sizes will be
|
||||
* returned.
|
||||
* For array, map, string, and binaries, the column size is variable, return null as unknown.
|
||||
*/
|
||||
private def getColumnSize(typ: DataType): Option[Int] = typ match {
|
||||
case dt @ (BooleanType | _: NumericType | DateType | TimestampType |
|
||||
CalendarIntervalType | NullType) =>
|
||||
Some(dt.defaultSize)
|
||||
case StructType(fields) =>
|
||||
val sizeArr = fields.map(f => getColumnSize(f.dataType))
|
||||
if (sizeArr.contains(None)) {
|
||||
None
|
||||
} else {
|
||||
Some(sizeArr.map(_.get).sum)
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
|
||||
/**
|
||||
* The number of fractional digits for this type.
|
||||
* Null is returned for data types where this is not applicable.
|
||||
* For boolean and integrals, the decimal digits is 0
|
||||
* For floating types, we follow the IEEE Standard for Floating-Point Arithmetic (IEEE 754)
|
||||
* For timestamp values, we support microseconds
|
||||
* For decimals, it returns the scale
|
||||
*/
|
||||
private def getDecimalDigits(typ: DataType): Option[Int] = typ match {
|
||||
case BooleanType | _: IntegerType => Some(0)
|
||||
case FloatType => Some(7)
|
||||
case DoubleType => Some(15)
|
||||
case d: DecimalType => Some(d.scale)
|
||||
case TimestampType => Some(6)
|
||||
case _ => None
|
||||
}
|
||||
|
||||
private def getNumPrecRadix(typ: DataType): Option[Int] = typ match {
|
||||
case _: NumericType => Some(10)
|
||||
case _ => None
|
||||
}
|
||||
|
||||
private def toRow(db: String, table: String, col: StructField, pos: Int): Row = {
|
||||
Row(
|
||||
null, // TABLE_CAT
|
||||
db, // TABLE_SCHEM
|
||||
table, // TABLE_NAME
|
||||
col.name, // COLUMN_NAME
|
||||
toJavaSQLType(col.dataType), // DATA_TYPE
|
||||
col.dataType.sql, // TYPE_NAME
|
||||
getColumnSize(col.dataType).orNull, // COLUMN_SIZE
|
||||
null, // BUFFER_LENGTH
|
||||
getDecimalDigits(col.dataType).orNull, // DECIMAL_DIGITS
|
||||
getNumPrecRadix(col.dataType).orNull, // NUM_PREC_RADIX
|
||||
if (col.nullable) 1 else 0, // NULLABLE
|
||||
col.getComment().getOrElse(""), // REMARKS
|
||||
null, // COLUMN_DEF
|
||||
null, // SQL_DATA_TYPE
|
||||
null, // SQL_DATETIME_SUB
|
||||
null, // CHAR_OCTET_LENGTH
|
||||
pos, // ORDINAL_POSITION
|
||||
"YES", // IS_NULLABLE
|
||||
null, // SCOPE_CATALOG
|
||||
null, // SCOPE_SCHEMA
|
||||
null, // SCOPE_TABLE
|
||||
null, // SOURCE_DATA_TYPE
|
||||
"NO" // IS_AUTO_INCREMENT
|
||||
)
|
||||
}
|
||||
override protected def resultSchema: StructType = {
|
||||
new StructType()
|
||||
.add(TABLE_CAT, "string", nullable = true, "Catalog name. NULL if not applicable")
|
||||
@ -178,45 +85,12 @@ class GetColumns(
|
||||
|
||||
override protected def runInternal(): Unit = {
|
||||
try {
|
||||
val catalog = spark.sessionState.catalog
|
||||
val schemaPattern = convertSchemaPattern(schemaName)
|
||||
val tablePattern = convertIdentifierPattern(tableName, datanucleusFormat = true)
|
||||
val columnPattern =
|
||||
Pattern.compile(convertIdentifierPattern(columnName, datanucleusFormat = false))
|
||||
val tables: Seq[Row] = catalog.listDatabases(schemaPattern).flatMap { db =>
|
||||
val identifiers =
|
||||
catalog.listTables(db, tablePattern, includeLocalTempViews = false)
|
||||
catalog.getTablesByName(identifiers).flatMap { t =>
|
||||
t.schema.zipWithIndex
|
||||
.filter { f => columnPattern.matcher(f._1.name).matches() }
|
||||
.map { case (f, i) => toRow(t.database, t.identifier.table, f, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val gviews = new ArrayBuffer[Row]()
|
||||
val globalTmpDb = catalog.globalTempViewManager.database
|
||||
if (StringUtils.isEmpty(schemaName) || schemaName == "*"
|
||||
|| Pattern.compile(convertSchemaPattern(schemaName, false))
|
||||
.matcher(globalTmpDb).matches()) {
|
||||
catalog.globalTempViewManager.listViewNames(tablePattern).foreach { v =>
|
||||
catalog.globalTempViewManager.get(v).foreach { plan =>
|
||||
plan.schema.zipWithIndex
|
||||
.filter { f => columnPattern.matcher(f._1.name).matches() }
|
||||
.foreach { case (f, i) => gviews += toRow(globalTmpDb, v, f, i) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val views: Seq[Row] = catalog.listLocalTempViews(tablePattern)
|
||||
.map(v => (v, catalog.getTempView(v.table).get))
|
||||
.flatMap { case (v, plan) =>
|
||||
plan.schema.zipWithIndex
|
||||
.filter(f => columnPattern.matcher(f._1.name).matches())
|
||||
.map { case (f, i) => toRow(null, v.table, f, i) }
|
||||
}
|
||||
|
||||
iter = (tables ++ gviews ++ views).toList.iterator
|
||||
val schemaPattern = toJavaRegex(schemaName)
|
||||
val tablePattern = toJavaRegex(tableName)
|
||||
val columnPattern = toJavaRegex(columnName)
|
||||
iter = SparkCatalogShim()
|
||||
.getColumns(spark, catalogName, schemaPattern, tablePattern, columnPattern)
|
||||
.toList.iterator
|
||||
} catch {
|
||||
onError()
|
||||
}
|
||||
|
||||
@ -55,8 +55,8 @@ class GetFunctions(
|
||||
|
||||
override protected def runInternal(): Unit = {
|
||||
try {
|
||||
val schemaPattern = convertSchemaPattern(schemaName)
|
||||
val functionPattern = convertIdentifierPattern(functionName, datanucleusFormat = false)
|
||||
val schemaPattern = toJavaRegex(schemaName)
|
||||
val functionPattern = toJavaRegex(functionName)
|
||||
val catalog = spark.sessionState.catalog
|
||||
val a: Seq[Row] = catalog.listDatabases(schemaPattern).flatMap { db =>
|
||||
catalog.listFunctions(db, functionPattern).map { case (f, _) =>
|
||||
|
||||
@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkShim
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
|
||||
import org.apache.kyuubi.operation.OperationType
|
||||
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
|
||||
import org.apache.kyuubi.session.Session
|
||||
@ -40,8 +40,8 @@ class GetSchemas(spark: SparkSession, session: Session, catalogName: String, sch
|
||||
|
||||
override protected def runInternal(): Unit = {
|
||||
try {
|
||||
val schemaPattern = convertSchemaPattern(schema, datanucleusFormat = false)
|
||||
val rows = SparkShim().getSchemas(spark, catalogName, schemaPattern)
|
||||
val schemaPattern = toJavaRegex(schema)
|
||||
val rows = SparkCatalogShim().getSchemas(spark, catalogName, schemaPattern)
|
||||
iter = rows.toList.toIterator
|
||||
} catch onError()
|
||||
}
|
||||
|
||||
@ -20,7 +20,7 @@ package org.apache.kyuubi.engine.spark.operation
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkShim
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
|
||||
import org.apache.kyuubi.operation.OperationType
|
||||
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
|
||||
import org.apache.kyuubi.session.Session
|
||||
@ -33,6 +33,6 @@ class GetTableTypes(spark: SparkSession, session: Session)
|
||||
}
|
||||
|
||||
override protected def runInternal(): Unit = {
|
||||
iter = SparkShim.sparkTableTypes.map(Row(_)).toList.iterator
|
||||
iter = SparkCatalogShim.sparkTableTypes.map(Row(_)).toList.iterator
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,10 +18,9 @@
|
||||
package org.apache.kyuubi.engine.spark.operation
|
||||
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.catalyst.catalog.CatalogTableType
|
||||
import org.apache.spark.sql.types.StructType
|
||||
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkShim
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
|
||||
import org.apache.kyuubi.operation.OperationType
|
||||
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
|
||||
import org.apache.kyuubi.session.Session
|
||||
@ -61,9 +60,9 @@ class GetTables(
|
||||
|
||||
override protected def runInternal(): Unit = {
|
||||
try {
|
||||
val schemaPattern = convertSchemaPattern(schema, datanucleusFormat = false)
|
||||
val tablePattern = convertIdentifierPattern(tableName, datanucleusFormat = true)
|
||||
val sparkShim = SparkShim()
|
||||
val schemaPattern = toJavaRegex(schema)
|
||||
val tablePattern = toJavaRegex(tableName)
|
||||
val sparkShim = SparkCatalogShim()
|
||||
val catalogTablesAndViews =
|
||||
sparkShim.getCatalogTablesOrViews(spark, catalog, schemaPattern, tablePattern, tableTypes)
|
||||
|
||||
|
||||
@ -17,6 +17,8 @@
|
||||
|
||||
package org.apache.kyuubi.engine.spark.operation
|
||||
|
||||
import java.util.regex.Pattern
|
||||
|
||||
import org.apache.commons.lang3.StringUtils
|
||||
import org.apache.hive.service.rpc.thrift.{TRowSet, TTableSchema}
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
@ -49,38 +51,33 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
|
||||
}
|
||||
}
|
||||
|
||||
private def convertPattern(pattern: String, datanucleusFormat: Boolean): String = {
|
||||
val wStr = if (datanucleusFormat) "*" else ".*"
|
||||
pattern
|
||||
.replaceAll("([^\\\\])%", "$1" + wStr)
|
||||
.replaceAll("\\\\%", "%")
|
||||
.replaceAll("^%", wStr)
|
||||
.replaceAll("([^\\\\])_", "$1.")
|
||||
.replaceAll("\\\\_", "_")
|
||||
.replaceAll("^_", ".")
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert wildcards and escape sequence of schema pattern from JDBC format to datanucleous/regex
|
||||
* The schema pattern treats empty string also as wildcard
|
||||
* convert SQL 'like' pattern to a Java regular expression.
|
||||
*
|
||||
* Underscores (_) are converted to '.' and percent signs (%) are converted to '.*'.
|
||||
*
|
||||
* @param input the SQL pattern to convert
|
||||
* @return the equivalent Java regular expression of the pattern
|
||||
*/
|
||||
protected def convertSchemaPattern(pattern: String, datanucleusFormat: Boolean = true): String = {
|
||||
if (StringUtils.isEmpty(pattern) || pattern == "*") {
|
||||
convertPattern("%", datanucleusFormat)
|
||||
def toJavaRegex(input: String): String = {
|
||||
val res = if (StringUtils.isEmpty(input) || input == "*") {
|
||||
"%"
|
||||
} else {
|
||||
convertPattern(pattern, datanucleusFormat)
|
||||
input
|
||||
}
|
||||
}
|
||||
val in = res.toIterator
|
||||
val out = new StringBuilder()
|
||||
|
||||
/**
|
||||
* Convert wildcards and escape sequence from JDBC format to datanucleous/regex
|
||||
*/
|
||||
protected def convertIdentifierPattern(pattern: String, datanucleusFormat: Boolean): String = {
|
||||
if (pattern == null) {
|
||||
convertPattern("%", datanucleusFormat)
|
||||
} else {
|
||||
convertPattern(pattern, datanucleusFormat)
|
||||
while (in.hasNext) {
|
||||
in.next match {
|
||||
case c if c == '\\' && in.hasNext => Pattern.quote(Character.toString(in.next()))
|
||||
case c if c == '\\' && !in.hasNext => Pattern.quote(Character.toString(c))
|
||||
case '_' => out ++= "."
|
||||
case '%' => out ++= ".*"
|
||||
case c => out ++= Character.toString(c)
|
||||
}
|
||||
}
|
||||
out.result()
|
||||
}
|
||||
|
||||
protected def onError(cancel: Boolean = false): PartialFunction[Throwable, Unit] = {
|
||||
|
||||
@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
|
||||
import org.apache.spark.sql.SparkSession
|
||||
|
||||
import org.apache.kyuubi.KyuubiSQLException
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkShim
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
|
||||
import org.apache.kyuubi.operation.{Operation, OperationManager}
|
||||
import org.apache.kyuubi.session.{Session, SessionHandle}
|
||||
|
||||
@ -91,7 +91,7 @@ class SparkSQLOperationManager private (name: String) extends OperationManager(n
|
||||
tableTypes: java.util.List[String]): Operation = {
|
||||
val spark = getSparkSession(session.handle)
|
||||
val tTypes = if (tableTypes == null || tableTypes.isEmpty) {
|
||||
SparkShim.sparkTableTypes
|
||||
SparkCatalogShim.sparkTableTypes
|
||||
} else {
|
||||
tableTypes.asScala.toSet
|
||||
}
|
||||
|
||||
@ -17,15 +17,16 @@
|
||||
|
||||
package org.apache.kyuubi.engine.spark.shim
|
||||
|
||||
import java.util.regex.Pattern
|
||||
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.connector.catalog.CatalogPlugin
|
||||
|
||||
class Shim_v2_4 extends SparkShim {
|
||||
class CatalogShim_v2_4 extends SparkCatalogShim {
|
||||
|
||||
override def getCatalogs(spark: SparkSession): Seq[Row] = Seq(Row(""))
|
||||
|
||||
override protected def getCatalog(spark: SparkSession, catalog: String): CatalogPlugin = null
|
||||
override def getCatalogs(spark: SparkSession): Seq[Row] = {
|
||||
Seq(Row(SparkCatalogShim.SESSION_CATALOG))
|
||||
}
|
||||
|
||||
override protected def catalogExists(spark: SparkSession, catalog: String): Boolean = false
|
||||
|
||||
@ -86,4 +87,74 @@ class Shim_v2_4 extends SparkShim {
|
||||
spark.sessionState.catalog.listLocalTempViews(tablePattern)
|
||||
}
|
||||
}
|
||||
|
||||
override def getColumns(
|
||||
spark: SparkSession,
|
||||
catalogName: String,
|
||||
schemaPattern: String,
|
||||
tablePattern: String,
|
||||
columnPattern: String): Seq[Row] = {
|
||||
|
||||
val cp = columnPattern.r.pattern
|
||||
val byCatalog = getColumnsByCatalog(spark, catalogName, schemaPattern, tablePattern, cp)
|
||||
val byGlobalTmpDB = getColumnsByGlobalTempViewManager(spark, schemaPattern, tablePattern, cp)
|
||||
val byLocalTmp = getColumnsByLocalTempViews(spark, tablePattern, cp)
|
||||
|
||||
byCatalog ++ byGlobalTmpDB ++ byLocalTmp
|
||||
}
|
||||
|
||||
protected def getColumnsByCatalog(
|
||||
spark: SparkSession,
|
||||
catalogName: String,
|
||||
schemaPattern: String,
|
||||
tablePattern: String,
|
||||
columnPattern: Pattern): Seq[Row] = {
|
||||
val catalog = spark.sessionState.catalog
|
||||
|
||||
val databases = catalog.listDatabases(schemaPattern)
|
||||
|
||||
databases.flatMap { db =>
|
||||
val identifiers = catalog.listTables(db, tablePattern, includeLocalTempViews = true)
|
||||
catalog.getTablesByName(identifiers).flatMap { t =>
|
||||
t.schema.zipWithIndex.filter(f => columnPattern.matcher(f._1.name).matches())
|
||||
.map { case (f, i) => toColumnResult(catalogName, t.database, t.identifier.table, f, i) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected def getColumnsByGlobalTempViewManager(
|
||||
spark: SparkSession,
|
||||
schemaPattern: String,
|
||||
tablePattern: String,
|
||||
columnPattern: Pattern): Seq[Row] = {
|
||||
val catalog = spark.sessionState.catalog
|
||||
|
||||
getGlobalTempViewManager(spark, schemaPattern).flatMap { globalTmpDb =>
|
||||
catalog.globalTempViewManager.listViewNames(tablePattern).flatMap { v =>
|
||||
catalog.globalTempViewManager.get(v).map { plan =>
|
||||
plan.schema.zipWithIndex.filter(f => columnPattern.matcher(f._1.name).matches())
|
||||
.map { case (f, i) =>
|
||||
toColumnResult(SparkCatalogShim.SESSION_CATALOG, globalTmpDb, v, f, i)
|
||||
}
|
||||
}
|
||||
}.flatten
|
||||
}
|
||||
}
|
||||
|
||||
protected def getColumnsByLocalTempViews(
|
||||
spark: SparkSession,
|
||||
tablePattern: String,
|
||||
columnPattern: Pattern): Seq[Row] = {
|
||||
val catalog = spark.sessionState.catalog
|
||||
|
||||
catalog.listLocalTempViews(tablePattern)
|
||||
.map(v => (v, catalog.getTempView(v.table).get))
|
||||
.flatMap { case (v, plan) =>
|
||||
plan.schema.zipWithIndex
|
||||
.filter(f => columnPattern.matcher(f._1.name).matches())
|
||||
.map { case (f, i) =>
|
||||
toColumnResult(SparkCatalogShim.SESSION_CATALOG, null, v.table, f, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -17,10 +17,14 @@
|
||||
|
||||
package org.apache.kyuubi.engine.spark.shim
|
||||
|
||||
import java.util.regex.Pattern
|
||||
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
import org.apache.spark.sql.connector.catalog.{CatalogExtension, CatalogPlugin, SupportsNamespaces, TableCatalog}
|
||||
|
||||
class Shim_v3_0 extends Shim_v2_4 {
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim.SESSION_CATALOG
|
||||
|
||||
class CatalogShim_v3_0 extends CatalogShim_v2_4 {
|
||||
|
||||
override def getCatalogs(spark: SparkSession): Seq[Row] = {
|
||||
|
||||
@ -37,9 +41,9 @@ class Shim_v3_0 extends Shim_v2_4 {
|
||||
(catalogs.keys ++: defaults).distinct.map(Row(_))
|
||||
}
|
||||
|
||||
override def getCatalog(spark: SparkSession, catalogName: String): CatalogPlugin = {
|
||||
private def getCatalog(spark: SparkSession, catalogName: String): CatalogPlugin = {
|
||||
val catalogManager = spark.sessionState.catalogManager
|
||||
if (catalogName == null) {
|
||||
if (catalogName == null || catalogName.isEmpty) {
|
||||
catalogManager.currentCatalog
|
||||
} else {
|
||||
catalogManager.catalog(catalogName)
|
||||
@ -124,17 +128,17 @@ class Shim_v3_0 extends Shim_v2_4 {
|
||||
tablePattern: String,
|
||||
tableTypes: Set[String]): Seq[Row] = {
|
||||
val catalog = getCatalog(spark, catalogName)
|
||||
val schemas = listNamespacesWithPattern(catalog, schemaPattern)
|
||||
val namespaces = listNamespacesWithPattern(catalog, schemaPattern)
|
||||
catalog match {
|
||||
case catalog if catalog.name() == SESSION_CATALOG =>
|
||||
case builtin if builtin.name() == SESSION_CATALOG =>
|
||||
super.getCatalogTablesOrViews(
|
||||
spark, SESSION_CATALOG, schemaPattern, tablePattern, tableTypes)
|
||||
case ce: CatalogExtension =>
|
||||
super.getCatalogTablesOrViews(spark, ce.name(), schemaPattern, tablePattern, tableTypes)
|
||||
case tc: TableCatalog =>
|
||||
schemas.flatMap { ns =>
|
||||
tc.listTables(ns)
|
||||
}.map { ident =>
|
||||
val tp = tablePattern.r.pattern
|
||||
val identifiers = namespaces.flatMap { ns =>
|
||||
tc.listTables(ns).filter(i => tp.matcher(quoteIfNeeded(i.name())).matches())
|
||||
}
|
||||
identifiers.map { ident =>
|
||||
val table = tc.loadTable(ident)
|
||||
// TODO: restore view type for session catalog
|
||||
val comment = table.properties().getOrDefault(TableCatalog.PROP_COMMENT, "")
|
||||
@ -145,4 +149,34 @@ class Shim_v3_0 extends Shim_v2_4 {
|
||||
case _ => Seq.empty[Row]
|
||||
}
|
||||
}
|
||||
|
||||
override protected def getColumnsByCatalog(
|
||||
spark: SparkSession,
|
||||
catalogName: String,
|
||||
schemaPattern: String,
|
||||
tablePattern: String,
|
||||
columnPattern: Pattern): Seq[Row] = {
|
||||
val catalog = getCatalog(spark, catalogName)
|
||||
|
||||
catalog match {
|
||||
case builtin if builtin.name() == SESSION_CATALOG =>
|
||||
super.getColumnsByCatalog(
|
||||
spark, SESSION_CATALOG, schemaPattern, tablePattern, columnPattern)
|
||||
|
||||
case tc: TableCatalog =>
|
||||
val namespaces = listNamespacesWithPattern(catalog, schemaPattern)
|
||||
val tp = tablePattern.r.pattern
|
||||
val identifiers = namespaces.flatMap { ns =>
|
||||
tc.listTables(ns).filter(i => tp.matcher(quoteIfNeeded(i.name())).matches())
|
||||
}
|
||||
identifiers.flatMap { ident =>
|
||||
val table = tc.loadTable(ident)
|
||||
val namespace = ident.namespace().map(quoteIfNeeded).mkString(".")
|
||||
val tableName = quoteIfNeeded(ident.name())
|
||||
|
||||
table.schema.zipWithIndex.filter(f => columnPattern.matcher(f._1.name).matches())
|
||||
.map { case (f, i) => toColumnResult(tc.name(), namespace, tableName, f, i) }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -19,14 +19,15 @@ package org.apache.kyuubi.engine.spark.shim
|
||||
|
||||
import org.apache.spark.sql.{Row, SparkSession}
|
||||
import org.apache.spark.sql.catalyst.TableIdentifier
|
||||
import org.apache.spark.sql.connector.catalog.CatalogPlugin
|
||||
import org.apache.spark.sql.types.StructField
|
||||
|
||||
import org.apache.kyuubi.{Logging, Utils}
|
||||
import org.apache.kyuubi.schema.SchemaHelper
|
||||
|
||||
/**
|
||||
* A shim that defines the interface interact with Spark's catalogs
|
||||
*/
|
||||
trait SparkShim extends Logging {
|
||||
trait SparkCatalogShim extends Logging {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Catalog //
|
||||
@ -37,8 +38,6 @@ trait SparkShim extends Logging {
|
||||
*/
|
||||
def getCatalogs(spark: SparkSession): Seq[Row]
|
||||
|
||||
protected def getCatalog(spark: SparkSession, catalog: String): CatalogPlugin
|
||||
|
||||
protected def catalogExists(spark: SparkSession, catalog: String): Boolean
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -74,6 +73,47 @@ trait SparkShim extends Logging {
|
||||
schemaPattern: String,
|
||||
tablePattern: String): Seq[TableIdentifier]
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Columns //
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
def getColumns(
|
||||
spark: SparkSession,
|
||||
catalogName: String,
|
||||
schemaPattern: String,
|
||||
tablePattern: String,
|
||||
columnPattern: String): Seq[Row]
|
||||
|
||||
protected def toColumnResult(
|
||||
catalog: String, db: String, table: String, col: StructField, pos: Int): Row = {
|
||||
Row(
|
||||
catalog, // TABLE_CAT
|
||||
db, // TABLE_SCHEM
|
||||
table, // TABLE_NAME
|
||||
col.name, // COLUMN_NAME
|
||||
SchemaHelper.toJavaSQLType(col.dataType), // DATA_TYPE
|
||||
col.dataType.sql, // TYPE_NAME
|
||||
SchemaHelper.getColumnSize(col.dataType).orNull, // COLUMN_SIZE
|
||||
null, // BUFFER_LENGTH
|
||||
SchemaHelper.getDecimalDigits(col.dataType).orNull, // DECIMAL_DIGITS
|
||||
SchemaHelper.getNumPrecRadix(col.dataType).orNull, // NUM_PREC_RADIX
|
||||
if (col.nullable) 1 else 0, // NULLABLE
|
||||
col.getComment().getOrElse(""), // REMARKS
|
||||
null, // COLUMN_DEF
|
||||
null, // SQL_DATA_TYPE
|
||||
null, // SQL_DATETIME_SUB
|
||||
null, // CHAR_OCTET_LENGTH
|
||||
pos, // ORDINAL_POSITION
|
||||
"YES", // IS_NULLABLE
|
||||
null, // SCOPE_CATALOG
|
||||
null, // SCOPE_SCHEMA
|
||||
null, // SCOPE_TABLE
|
||||
null, // SOURCE_DATA_TYPE
|
||||
"NO" // IS_AUTO_INCREMENT
|
||||
)
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Miscellaneous //
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -110,19 +150,20 @@ trait SparkShim extends Logging {
|
||||
tableTypes.exists(typ.equalsIgnoreCase)
|
||||
}
|
||||
|
||||
protected val SESSION_CATALOG: String = "spark_catalog"
|
||||
}
|
||||
|
||||
object SparkShim {
|
||||
def apply(): SparkShim = {
|
||||
object SparkCatalogShim {
|
||||
def apply(): SparkCatalogShim = {
|
||||
val runtimeSparkVer = org.apache.spark.SPARK_VERSION
|
||||
val (major, minor) = Utils.majorMinorVersion(runtimeSparkVer)
|
||||
(major, minor) match {
|
||||
case (3, _) => new Shim_v3_0
|
||||
case (2, _) => new Shim_v2_4
|
||||
case (3, _) => new CatalogShim_v3_0
|
||||
case (2, _) => new CatalogShim_v2_4
|
||||
case _ => throw new IllegalArgumentException(s"Not Support spark version $runtimeSparkVer")
|
||||
}
|
||||
}
|
||||
|
||||
val SESSION_CATALOG: String = "spark_catalog"
|
||||
|
||||
val sparkTableTypes = Set("VIEW", "TABLE")
|
||||
}
|
||||
@ -85,4 +85,68 @@ object SchemaHelper {
|
||||
}
|
||||
tTableSchema
|
||||
}
|
||||
|
||||
def toJavaSQLType(sparkType: DataType): Int = sparkType match {
|
||||
case NullType => java.sql.Types.NULL
|
||||
case BooleanType => java.sql.Types.BOOLEAN
|
||||
case ByteType => java.sql.Types.TINYINT
|
||||
case ShortType => java.sql.Types.SMALLINT
|
||||
case IntegerType => java.sql.Types.INTEGER
|
||||
case LongType => java.sql.Types.BIGINT
|
||||
case FloatType => java.sql.Types.FLOAT
|
||||
case DoubleType => java.sql.Types.DOUBLE
|
||||
case StringType => java.sql.Types.VARCHAR
|
||||
case _: DecimalType => java.sql.Types.DECIMAL
|
||||
case DateType => java.sql.Types.DATE
|
||||
case TimestampType => java.sql.Types.TIMESTAMP
|
||||
case BinaryType => java.sql.Types.BINARY
|
||||
case _: ArrayType => java.sql.Types.ARRAY
|
||||
case _: MapType => java.sql.Types.JAVA_OBJECT
|
||||
case _: StructType => java.sql.Types.STRUCT
|
||||
case _ => java.sql.Types.OTHER
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* For boolean, numeric and datetime types, it returns the default size of its catalyst type
|
||||
* For struct type, when its elements are fixed-size, the summation of all element sizes will be
|
||||
* returned.
|
||||
* For array, map, string, and binaries, the column size is variable, return null as unknown.
|
||||
*/
|
||||
def getColumnSize(sparkType: DataType): Option[Int] = sparkType match {
|
||||
case dt @ (BooleanType | _: NumericType | DateType | TimestampType |
|
||||
CalendarIntervalType | NullType) =>
|
||||
Some(dt.defaultSize)
|
||||
case StructType(fields) =>
|
||||
val sizeArr = fields.map(f => getColumnSize(f.dataType))
|
||||
if (sizeArr.contains(None)) {
|
||||
None
|
||||
} else {
|
||||
Some(sizeArr.map(_.get).sum)
|
||||
}
|
||||
case _ => None
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* The number of fractional digits for this type.
|
||||
* Null is returned for data types where this is not applicable.
|
||||
* For boolean and integrals, the decimal digits is 0
|
||||
* For floating types, we follow the IEEE Standard for Floating-Point Arithmetic (IEEE 754)
|
||||
* For timestamp values, we support microseconds
|
||||
* For decimals, it returns the scale
|
||||
*/
|
||||
def getDecimalDigits(sparkType: DataType): Option[Int] = sparkType match {
|
||||
case BooleanType | _: IntegerType => Some(0)
|
||||
case FloatType => Some(7)
|
||||
case DoubleType => Some(15)
|
||||
case d: DecimalType => Some(d.scale)
|
||||
case TimestampType => Some(6)
|
||||
case _ => None
|
||||
}
|
||||
|
||||
def getNumPrecRadix(typ: DataType): Option[Int] = typ match {
|
||||
case _: NumericType => Some(10)
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
|
||||
@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
|
||||
|
||||
import org.apache.kyuubi.Utils
|
||||
import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkShim
|
||||
import org.apache.kyuubi.engine.spark.shim.SparkCatalogShim
|
||||
import org.apache.kyuubi.operation.JDBCTests
|
||||
import org.apache.kyuubi.operation.meta.ResultSetSchemaConstant._
|
||||
|
||||
@ -44,7 +44,7 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
|
||||
withJdbcStatement() { statement =>
|
||||
val meta = statement.getConnection.getMetaData
|
||||
val types = meta.getTableTypes
|
||||
val expected = SparkShim.sparkTableTypes.toIterator
|
||||
val expected = SparkCatalogShim.sparkTableTypes.toIterator
|
||||
while (types.next()) {
|
||||
assert(types.getString(TABLE_TYPE) === expected.next())
|
||||
}
|
||||
@ -98,7 +98,7 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
|
||||
var pos = 0
|
||||
|
||||
while (rowSet.next()) {
|
||||
assert(rowSet.getString(TABLE_CAT) === null)
|
||||
assert(rowSet.getString(TABLE_CAT) === SparkCatalogShim.SESSION_CATALOG)
|
||||
assert(rowSet.getString(TABLE_SCHEM) === dftSchema)
|
||||
assert(rowSet.getString(TABLE_NAME) === tableName)
|
||||
assert(rowSet.getString(COLUMN_NAME) === schema(pos).name)
|
||||
@ -142,9 +142,6 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
|
||||
|
||||
val rowSet = metaData.getColumns(null, "*", "not_exist", "not_exist")
|
||||
assert(!rowSet.next())
|
||||
|
||||
val e1 = intercept[HiveSQLException](metaData.getColumns(null, null, null, "*"))
|
||||
assert(e1.getCause.getMessage contains "Dangling meta character '*' near index 0\n*\n^")
|
||||
}
|
||||
}
|
||||
|
||||
@ -157,7 +154,7 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
|
||||
val data = statement.getConnection.getMetaData
|
||||
val rowSet = data.getColumns("", "global_temp", viewName, null)
|
||||
while (rowSet.next()) {
|
||||
assert(rowSet.getString(TABLE_CAT) === null)
|
||||
assert(rowSet.getString(TABLE_CAT) === SparkCatalogShim.SESSION_CATALOG)
|
||||
assert(rowSet.getString(TABLE_SCHEM) === "global_temp")
|
||||
assert(rowSet.getString(TABLE_NAME) === viewName)
|
||||
assert(rowSet.getString(COLUMN_NAME) === "i")
|
||||
@ -184,20 +181,20 @@ class SparkOperationSuite extends WithSparkSQLEngine with JDBCTests {
|
||||
val data = statement.getConnection.getMetaData
|
||||
val rowSet = data.getColumns("", "global_temp", viewName, "n")
|
||||
while (rowSet.next()) {
|
||||
assert(rowSet.getString("TABLE_CAT") === null)
|
||||
assert(rowSet.getString("TABLE_SCHEM") === "global_temp")
|
||||
assert(rowSet.getString("TABLE_NAME") === viewName)
|
||||
assert(rowSet.getString("COLUMN_NAME") === "n")
|
||||
assert(rowSet.getInt("DATA_TYPE") === java.sql.Types.NULL)
|
||||
assert(rowSet.getString("TYPE_NAME").equalsIgnoreCase(NullType.sql))
|
||||
assert(rowSet.getInt("COLUMN_SIZE") === 1)
|
||||
assert(rowSet.getInt("DECIMAL_DIGITS") === 0)
|
||||
assert(rowSet.getInt("NUM_PREC_RADIX") === 0)
|
||||
assert(rowSet.getInt("NULLABLE") === 1)
|
||||
assert(rowSet.getString("REMARKS") === "")
|
||||
assert(rowSet.getInt("ORDINAL_POSITION") === 0)
|
||||
assert(rowSet.getString("IS_NULLABLE") === "YES")
|
||||
assert(rowSet.getString("IS_AUTO_INCREMENT") === "NO")
|
||||
assert(rowSet.getString(TABLE_CAT) === SparkCatalogShim.SESSION_CATALOG)
|
||||
assert(rowSet.getString(TABLE_SCHEM) === "global_temp")
|
||||
assert(rowSet.getString(TABLE_NAME) === viewName)
|
||||
assert(rowSet.getString(COLUMN_NAME) === "n")
|
||||
assert(rowSet.getInt(DATA_TYPE) === java.sql.Types.NULL)
|
||||
assert(rowSet.getString(TYPE_NAME).equalsIgnoreCase(NullType.sql))
|
||||
assert(rowSet.getInt(COLUMN_SIZE) === 1)
|
||||
assert(rowSet.getInt(DECIMAL_DIGITS) === 0)
|
||||
assert(rowSet.getInt(NUM_PREC_RADIX) === 0)
|
||||
assert(rowSet.getInt(NULLABLE) === 1)
|
||||
assert(rowSet.getString(REMARKS) === "")
|
||||
assert(rowSet.getInt(ORDINAL_POSITION) === 0)
|
||||
assert(rowSet.getString(IS_NULLABLE) === "YES")
|
||||
assert(rowSet.getString(IS_AUTO_INCREMENT) === "NO")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -152,4 +152,55 @@ trait BasicIcebergJDBCTests extends JDBCTestUtils {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
  // Verifies that DatabaseMetaData.getColumns works against a DSv2 (Iceberg) catalog:
  // every declared column must come back with the expected catalog/schema/table name,
  // JDBC type code, and type name, for several column-name patterns.
  test("get columns operation") {
    // One column per Spark SQL type under test; the declared order here must stay in
    // sync with expectedJavaTypes below (same index -> same column).
    val dataTypes = Seq("boolean", "int", "bigint",
      "float", "double", "decimal(38,20)", "decimal(10,2)",
      "string", "array<bigint>", "array<string>", "map<int, bigint>",
      "date", "timestamp", "struct<`X`: bigint, `Y`: double>", "binary", "struct<`X`: string>")
    // Column names are generated as c0, c1, ... so they all match the "c.*" pattern.
    val cols = dataTypes.zipWithIndex.map { case (dt, idx) => s"c$idx" -> dt }
    val (colNames, _) = cols.unzip

    val tableName = "iceberg_get_col_operation"

    val ddl =
      s"""
         |CREATE TABLE IF NOT EXISTS $catalog.$dftSchema.$tableName (
         |  ${cols.map { case (cn, dt) => cn + " " + dt }.mkString(",\n")}
         |)
         |USING iceberg""".stripMargin

    withJdbcStatement(tableName) { statement =>
      statement.execute(ddl)

      val metaData = statement.getConnection.getMetaData

      // Exercise SQL wildcards ("%", null = match-all) as well as Java-regex style
      // patterns (".*", "c.*"); each pattern must match every generated column.
      Seq("%", null, ".*", "c.*") foreach { columnPattern =>
        val rowSet = metaData.getColumns(catalog, dftSchema, tableName, columnPattern)

        import java.sql.Types._
        // Expected java.sql.Types code for each entry of dataTypes, in the same order.
        // NOTE(review): maps are reported as JAVA_OBJECT while arrays/structs keep
        // their dedicated JDBC codes.
        val expectedJavaTypes = Seq(BOOLEAN, INTEGER, BIGINT, FLOAT, DOUBLE,
          DECIMAL, DECIMAL, VARCHAR, ARRAY, ARRAY, JAVA_OBJECT, DATE, TIMESTAMP, STRUCT, BINARY,
          STRUCT)

        // Index of the column the cursor is expected to be on; relies on getColumns
        // returning columns in declaration order.
        var pos = 0

        while (rowSet.next()) {
          assert(rowSet.getString(TABLE_CAT) === catalog)
          assert(rowSet.getString(TABLE_SCHEM) === dftSchema)
          assert(rowSet.getString(TABLE_NAME) === tableName)
          assert(rowSet.getString(COLUMN_NAME) === colNames(pos))
          assert(rowSet.getInt(DATA_TYPE) === expectedJavaTypes(pos))
          // Type names are compared case-insensitively since the server may
          // normalize casing differently from the DDL above.
          assert(rowSet.getString(TYPE_NAME) equalsIgnoreCase dataTypes(pos))
          pos += 1
        }

        // Guards against getColumns silently returning a truncated result set.
        assert(pos === dataTypes.size, "all columns should have been verified")
      }

      // A non-matching schema/table/column combination must yield an empty result
      // rather than an error.
      val rowSet = metaData.getColumns(catalog, "*", "not_exist", "not_exist")
      assert(!rowSet.next())
    }
  }
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user