add more tests
This commit is contained in:
parent
83d1ab7450
commit
59b7c3a09e
@ -172,28 +172,26 @@ class GetColumns(
|
||||
val catalog = spark.sessionState.catalog
|
||||
val schemaPattern = convertSchemaPattern(schemaName)
|
||||
val tablePattern = convertIdentifierPattern(tableName, datanucleusFormat = true)
|
||||
val columnPattern = Option(columnName)
|
||||
.map(c => Pattern.compile(convertIdentifierPattern(c, datanucleusFormat = false)))
|
||||
.orNull
|
||||
val columnPattern =
|
||||
Pattern.compile(convertIdentifierPattern(columnName, datanucleusFormat = false))
|
||||
val tables: Seq[Row] = catalog.listDatabases(schemaPattern).flatMap { db =>
|
||||
val identifiers =
|
||||
catalog.listTables(db, tablePattern, includeLocalTempViews = false)
|
||||
catalog.getTablesByName(identifiers).flatMap { t =>
|
||||
t.schema.zipWithIndex
|
||||
.filter { f => columnPattern == null || columnPattern.matcher(f._1.name).matches() }
|
||||
.filter { f => columnPattern.matcher(f._1.name).matches() }
|
||||
.map { case (f, i) => toRow(t.database, t.identifier.table, f, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
val gviews = new ArrayBuffer[Row]()
|
||||
val globalTmpDb = catalog.globalTempViewManager.database
|
||||
if (Pattern.compile(schemaPattern).matcher(globalTmpDb).matches()) {
|
||||
catalog.globalTempViewManager.listViewNames(tablePattern).foreach { v =>
|
||||
catalog.globalTempViewManager.get(v).foreach { plan =>
|
||||
plan.schema.zipWithIndex
|
||||
.filter { f => columnPattern == null || columnPattern.matcher(f._1.name).matches() }
|
||||
.filter { f => columnPattern.matcher(f._1.name).matches() }
|
||||
.foreach { case (f, i) => gviews += toRow(globalTmpDb, v, f, i) }
|
||||
}
|
||||
}
|
||||
@ -203,11 +201,13 @@ class GetColumns(
|
||||
.map(v => (v, catalog.getTempView(v.table).get))
|
||||
.flatMap { case (v, plan) =>
|
||||
plan.schema.zipWithIndex
|
||||
.filter(f => columnPattern == null || columnPattern.matcher(f._1.name).matches())
|
||||
.filter(f => columnPattern.matcher(f._1.name).matches())
|
||||
.map { case (f, i) => toRow(null, v.table, f, i) }
|
||||
}
|
||||
|
||||
iter = (tables ++ gviews ++ views).toList.iterator
|
||||
} catch onError()
|
||||
} catch {
|
||||
onError()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -63,6 +63,8 @@ class GetFunctions(
|
||||
}
|
||||
}
|
||||
iter = a.toList.iterator
|
||||
} catch onError()
|
||||
} catch {
|
||||
onError()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,8 +44,7 @@ class GetSchemas(spark: SparkSession, session: Session, catalogName: String, sch
|
||||
val schemaPattern = convertSchemaPattern(schema)
|
||||
val databases = spark.sessionState.catalog.listDatabases(schemaPattern)
|
||||
val globalTmpViewDb = spark.sessionState.catalog.globalTempViewManager.database
|
||||
if (schema == null ||
|
||||
Pattern.compile(convertSchemaPattern(schema, false))
|
||||
if (Pattern.compile(convertSchemaPattern(schema, false))
|
||||
.matcher(globalTmpViewDb).matches()) {
|
||||
iter = (databases :+ globalTmpViewDb).map(Row(_, "")).toList.iterator
|
||||
} else {
|
||||
|
||||
@ -100,7 +100,9 @@ class GetTables(
|
||||
Seq.empty[Row]
|
||||
}
|
||||
iter = (tables ++ views).toList.iterator
|
||||
} catch onError()
|
||||
} catch {
|
||||
onError()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -50,8 +50,8 @@ class GetTypeInfo(spark: SparkSession, session: Session)
|
||||
" null)")
|
||||
.add("MINIMUM_SCALE", "smallint", nullable = false, "Minimum scale supported")
|
||||
.add("MAXIMUM_SCALE", "smallint", nullable = false, "Maximum scale supported")
|
||||
.add("SQL_DATA_TYPE", "int", nullable = true, "Unused")
|
||||
.add("SQL_DATETIME_SUB", "int", nullable = true, "Unused")
|
||||
.add(SQL_DATA_TYPE, "int", nullable = true, "Unused")
|
||||
.add(SQL_DATETIME_SUB, "int", nullable = true, "Unused")
|
||||
.add(NUM_PREC_RADIX, "int", nullable = false, "Usually 2 or 10")
|
||||
}
|
||||
|
||||
@ -61,24 +61,24 @@ class GetTypeInfo(spark: SparkSession, session: Session)
|
||||
}
|
||||
|
||||
private def toRow(name: String, javaType: Int, precision: Integer = null): Row = {
|
||||
Row(name, // TYPE_NAME
|
||||
javaType, // DATA_TYPE
|
||||
precision, // PRECISION
|
||||
null, // LITERAL_PREFIX
|
||||
null, // LITERAL_SUFFIX
|
||||
null, // CREATE_PARAMS
|
||||
1.toShort, // NULLABLE
|
||||
javaType == VARCHAR, // CASE_SENSITIVE
|
||||
Row(name, // TYPE_NAME
|
||||
javaType, // DATA_TYPE
|
||||
precision, // PRECISION
|
||||
null, // LITERAL_PREFIX
|
||||
null, // LITERAL_SUFFIX
|
||||
null, // CREATE_PARAMS
|
||||
1.toShort, // NULLABLE
|
||||
javaType == VARCHAR, // CASE_SENSITIVE
|
||||
if (javaType < 1111) 3.toShort else 0.toShort, // SEARCHABLE
|
||||
!isNumericType(javaType), // UNSIGNED_ATTRIBUTE
|
||||
false, // FIXED_PREC_SCALE
|
||||
false, // AUTO_INCREMENT
|
||||
null, // LOCAL_TYPE_NAME
|
||||
0.toShort, // MINIMUM_SCALE
|
||||
0.toShort, // MAXIMUM_SCALE
|
||||
null, // SQL_DATA_TYPE
|
||||
null, // SQL_DATETIME_SUB
|
||||
if (isNumericType(javaType)) 10 else null // NUM_PREC_RADIX
|
||||
!isNumericType(javaType), // UNSIGNED_ATTRIBUTE
|
||||
false, // FIXED_PREC_SCALE
|
||||
false, // AUTO_INCREMENT
|
||||
null, // LOCAL_TYPE_NAME
|
||||
0.toShort, // MINIMUM_SCALE
|
||||
0.toShort, // MAXIMUM_SCALE
|
||||
null, // SQL_DATA_TYPE
|
||||
null, // SQL_DATETIME_SUB
|
||||
if (isNumericType(javaType)) 10 else null // NUM_PREC_RADIX
|
||||
)
|
||||
}
|
||||
|
||||
@ -92,7 +92,7 @@ class GetTypeInfo(spark: SparkSession, session: Session)
|
||||
toRow("BIGINT", BIGINT, 19),
|
||||
toRow("FLOAT", FLOAT, 7),
|
||||
toRow("DOUBLE", DOUBLE, 15),
|
||||
toRow("VARCHAR", VARCHAR),
|
||||
toRow("STRING", VARCHAR),
|
||||
toRow("BINARY", BINARY),
|
||||
toRow("DECIMAL", DECIMAL, 38),
|
||||
toRow("DATE", DATE),
|
||||
|
||||
@ -78,7 +78,7 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
|
||||
*/
|
||||
protected def convertIdentifierPattern(pattern: String, datanucleusFormat: Boolean): String = {
|
||||
if (pattern == null) {
|
||||
convertPattern("%", datanucleusFormat = true)
|
||||
convertPattern("%", datanucleusFormat)
|
||||
} else {
|
||||
convertPattern(pattern, datanucleusFormat)
|
||||
}
|
||||
|
||||
@ -330,57 +330,60 @@ class SparkOperationSuite extends WithSparkSQLEngine {
|
||||
statement.execute(ddl)
|
||||
|
||||
val metaData = statement.getConnection.getMetaData
|
||||
val rowSet = metaData.getColumns("", dftSchema, tableName, null)
|
||||
|
||||
import java.sql.Types._
|
||||
val expectedJavaTypes = Seq(BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE,
|
||||
DECIMAL, DECIMAL, VARCHAR, ARRAY, ARRAY, JAVA_OBJECT, DATE, TIMESTAMP, STRUCT, BINARY,
|
||||
STRUCT)
|
||||
Seq("%", null, ".*", "c.*") foreach { pattern =>
|
||||
val rowSet = metaData.getColumns("", dftSchema, tableName, pattern)
|
||||
|
||||
var pos = 0
|
||||
import java.sql.Types._
|
||||
val expectedJavaTypes = Seq(BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE,
|
||||
DECIMAL, DECIMAL, VARCHAR, ARRAY, ARRAY, JAVA_OBJECT, DATE, TIMESTAMP, STRUCT, BINARY,
|
||||
STRUCT)
|
||||
|
||||
while (rowSet.next()) {
|
||||
assert(rowSet.getString(TABLE_CAT) === null)
|
||||
assert(rowSet.getString(TABLE_SCHEM) === dftSchema)
|
||||
assert(rowSet.getString(TABLE_NAME) === tableName)
|
||||
assert(rowSet.getString(COLUMN_NAME) === schema(pos).name)
|
||||
assert(rowSet.getInt(DATA_TYPE) === expectedJavaTypes(pos))
|
||||
assert(rowSet.getString(TYPE_NAME) === schema(pos).dataType.sql)
|
||||
var pos = 0
|
||||
|
||||
val colSize = rowSet.getInt(COLUMN_SIZE)
|
||||
schema(pos).dataType match {
|
||||
case StringType | BinaryType | _: ArrayType | _: MapType => assert(colSize === 0)
|
||||
case StructType(fields) if fields.size == 1 => assert(colSize === 0)
|
||||
case o => assert(colSize === o.defaultSize)
|
||||
while (rowSet.next()) {
|
||||
assert(rowSet.getString(TABLE_CAT) === null)
|
||||
assert(rowSet.getString(TABLE_SCHEM) === dftSchema)
|
||||
assert(rowSet.getString(TABLE_NAME) === tableName)
|
||||
assert(rowSet.getString(COLUMN_NAME) === schema(pos).name)
|
||||
assert(rowSet.getInt(DATA_TYPE) === expectedJavaTypes(pos))
|
||||
assert(rowSet.getString(TYPE_NAME) === schema(pos).dataType.sql)
|
||||
|
||||
val colSize = rowSet.getInt(COLUMN_SIZE)
|
||||
schema(pos).dataType match {
|
||||
case StringType | BinaryType | _: ArrayType | _: MapType => assert(colSize === 0)
|
||||
case StructType(fields) if fields.length == 1 => assert(colSize === 0)
|
||||
case o => assert(colSize === o.defaultSize)
|
||||
}
|
||||
|
||||
assert(rowSet.getInt(BUFFER_LENGTH) === 0) // not used
|
||||
val decimalDigits = rowSet.getInt(DECIMAL_DIGITS)
|
||||
schema(pos).dataType match {
|
||||
case BooleanType | _: IntegerType => assert(decimalDigits === 0)
|
||||
case d: DecimalType => assert(decimalDigits === d.scale)
|
||||
case FloatType => assert(decimalDigits === 7)
|
||||
case DoubleType => assert(decimalDigits === 15)
|
||||
case TimestampType => assert(decimalDigits === 6)
|
||||
case _ => assert(decimalDigits === 0) // nulls
|
||||
}
|
||||
|
||||
val radix = rowSet.getInt(NUM_PREC_RADIX)
|
||||
schema(pos).dataType match {
|
||||
case _: NumericType => assert(radix === 10)
|
||||
case _ => assert(radix === 0) // nulls
|
||||
}
|
||||
|
||||
assert(rowSet.getInt(NULLABLE) === 1)
|
||||
assert(rowSet.getString(REMARKS) === pos.toString)
|
||||
assert(rowSet.getInt(ORDINAL_POSITION) === pos)
|
||||
assert(rowSet.getString(IS_NULLABLE) === "YES")
|
||||
assert(rowSet.getString(IS_AUTO_INCREMENT) === "NO")
|
||||
pos += 1
|
||||
}
|
||||
|
||||
assert(rowSet.getInt(BUFFER_LENGTH) === 0) // not used
|
||||
val decimalDigits = rowSet.getInt(DECIMAL_DIGITS)
|
||||
schema(pos).dataType match {
|
||||
case BooleanType | _: IntegerType => assert(decimalDigits === 0)
|
||||
case d: DecimalType => assert(decimalDigits === d.scale)
|
||||
case FloatType => assert(decimalDigits === 7)
|
||||
case DoubleType => assert(decimalDigits === 15)
|
||||
case TimestampType => assert(decimalDigits === 6)
|
||||
case _ => assert(decimalDigits === 0) // nulls
|
||||
}
|
||||
|
||||
val radix = rowSet.getInt(NUM_PREC_RADIX)
|
||||
schema(pos).dataType match {
|
||||
case _: NumericType => assert(radix === 10)
|
||||
case _ => assert(radix === 0) // nulls
|
||||
}
|
||||
|
||||
assert(rowSet.getInt(NULLABLE) === 1)
|
||||
assert(rowSet.getString(REMARKS) === pos.toString)
|
||||
assert(rowSet.getInt(ORDINAL_POSITION) === pos)
|
||||
assert(rowSet.getString(IS_NULLABLE) === "YES")
|
||||
assert(rowSet.getString(IS_AUTO_INCREMENT) === "NO")
|
||||
pos += 1
|
||||
assert(pos === 18, "all columns should have been verified")
|
||||
}
|
||||
|
||||
assert(pos === 18, "all columns should have been verified")
|
||||
|
||||
val e = intercept[HiveSQLException](metaData.getColumns(null, "*", null, null))
|
||||
assert(e.getCause.getMessage === "org.apache.kyuubi.KyuubiSQLException:" +
|
||||
"Error operating GET_COLUMNS: Dangling meta character '*' near index 0\n*\n^")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user