fix GetColumns

This commit is contained in:
Kent Yao 2020-09-09 10:06:15 +08:00
parent 46d6dda3f3
commit 0a43b82400
2 changed files with 47 additions and 7 deletions

View File

@ -19,6 +19,8 @@ package org.apache.kyuubi.engine.spark.operation
import java.util.regex.Pattern
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, NullType, NumericType, ShortType, StringType, StructField, StructType, TimestampType}
@ -172,12 +174,8 @@ class GetColumns(
val columnPattern = Option(columnName)
.map(c => Pattern.compile(convertIdentifierPattern(c, datanucleusFormat = false)))
.orNull
var databases: Seq[String] = catalog.listDatabases(schemaPattern)
val globalTmpDb = catalog.globalTempViewManager.database
if (Pattern.compile(schemaPattern).matcher(globalTmpDb).matches()) {
databases = databases ++ Seq(globalTmpDb)
}
val tableAndGlobalViews: Seq[Row] = databases.flatMap { db =>
val databases: Seq[String] = catalog.listDatabases(schemaPattern)
val tables: Seq[Row] = catalog.listDatabases(schemaPattern).flatMap { db =>
val identifiers =
catalog.listTables(db, tablePattern, includeLocalTempViews = false)
catalog.getTablesByName(identifiers).flatMap { t =>
@ -188,6 +186,19 @@ class GetColumns(
}
}
val gviews = new ArrayBuffer[Row]()
val globalTmpDb = catalog.globalTempViewManager.database
if (Pattern.compile(schemaPattern).matcher(globalTmpDb).matches()) {
catalog.globalTempViewManager.listViewNames(tablePattern).foreach { v =>
catalog.globalTempViewManager.get(v).foreach { plan =>
plan.schema.zipWithIndex
.filter { f => columnPattern == null || columnPattern.matcher(f._1.name).matches() }
.foreach { case (f, i) => gviews += toRow(globalTmpDb, v, f, i) }
}
}
}
val views: Seq[Row] = catalog.listLocalTempViews(tablePattern)
.map(v => (v, catalog.getTempView(v.table).get))
.flatMap { case (v, plan) =>
@ -196,7 +207,7 @@ class GetColumns(
.map { case (f, i) => toRow(null, v.table, f, i) }
}
iter = (tableAndGlobalViews ++ views).toList.iterator
iter = (tables ++ gviews ++ views).toList.iterator
} catch onError()
}
}

View File

@ -368,6 +368,35 @@ class SparkOperationSuite extends WithSparkSQLEngine {
}
}
test("get columns operation should handle interval column properly") {
val viewName = "view_interval"
val ddl = s"CREATE GLOBAL TEMP VIEW $viewName as select interval 1 day as i"
withJdbcStatement(viewName) { statement =>
statement.execute(ddl)
val data = statement.getConnection.getMetaData
val rowSet = data.getColumns("", "global_temp", viewName, null)
while (rowSet.next()) {
assert(rowSet.getString(TABLE_CAT) === null)
assert(rowSet.getString(TABLE_SCHEM) === "global_temp")
assert(rowSet.getString(TABLE_NAME) === viewName)
assert(rowSet.getString(COLUMN_NAME) === "i")
assert(rowSet.getInt(DATA_TYPE) === java.sql.Types.OTHER)
assert(rowSet.getString(TYPE_NAME).equalsIgnoreCase(CalendarIntervalType.sql))
assert(rowSet.getInt(COLUMN_SIZE) === CalendarIntervalType.defaultSize)
assert(rowSet.getInt(DECIMAL_DIGITS) === 0)
assert(rowSet.getInt(NUM_PREC_RADIX) === 0)
assert(rowSet.getInt(NULLABLE) === 0)
assert(rowSet.getString(REMARKS) === "")
assert(rowSet.getInt(ORDINAL_POSITION) === 0)
assert(rowSet.getString(IS_NULLABLE) === "YES")
assert(rowSet.getString(IS_AUTO_INCREMENT) === "NO")
}
}
}
test("get functions") {
withJdbcStatement() { statement =>
val metaData = statement.getConnection.getMetaData