add more tests

This commit is contained in:
Kent Yao 2020-09-10 12:01:58 +08:00
parent 83d1ab7450
commit 59b7c3a09e
7 changed files with 82 additions and 76 deletions

View File

@ -172,28 +172,26 @@ class GetColumns(
val catalog = spark.sessionState.catalog
val schemaPattern = convertSchemaPattern(schemaName)
val tablePattern = convertIdentifierPattern(tableName, datanucleusFormat = true)
val columnPattern = Option(columnName)
.map(c => Pattern.compile(convertIdentifierPattern(c, datanucleusFormat = false)))
.orNull
val columnPattern =
Pattern.compile(convertIdentifierPattern(columnName, datanucleusFormat = false))
val tables: Seq[Row] = catalog.listDatabases(schemaPattern).flatMap { db =>
val identifiers =
catalog.listTables(db, tablePattern, includeLocalTempViews = false)
catalog.getTablesByName(identifiers).flatMap { t =>
t.schema.zipWithIndex
.filter { f => columnPattern == null || columnPattern.matcher(f._1.name).matches() }
.filter { f => columnPattern.matcher(f._1.name).matches() }
.map { case (f, i) => toRow(t.database, t.identifier.table, f, i)
}
}
}
val gviews = new ArrayBuffer[Row]()
val globalTmpDb = catalog.globalTempViewManager.database
if (Pattern.compile(schemaPattern).matcher(globalTmpDb).matches()) {
catalog.globalTempViewManager.listViewNames(tablePattern).foreach { v =>
catalog.globalTempViewManager.get(v).foreach { plan =>
plan.schema.zipWithIndex
.filter { f => columnPattern == null || columnPattern.matcher(f._1.name).matches() }
.filter { f => columnPattern.matcher(f._1.name).matches() }
.foreach { case (f, i) => gviews += toRow(globalTmpDb, v, f, i) }
}
}
@ -203,11 +201,13 @@ class GetColumns(
.map(v => (v, catalog.getTempView(v.table).get))
.flatMap { case (v, plan) =>
plan.schema.zipWithIndex
.filter(f => columnPattern == null || columnPattern.matcher(f._1.name).matches())
.filter(f => columnPattern.matcher(f._1.name).matches())
.map { case (f, i) => toRow(null, v.table, f, i) }
}
iter = (tables ++ gviews ++ views).toList.iterator
} catch onError()
} catch {
onError()
}
}
}

View File

@ -63,6 +63,8 @@ class GetFunctions(
}
}
iter = a.toList.iterator
} catch onError()
} catch {
onError()
}
}
}

View File

@ -44,8 +44,7 @@ class GetSchemas(spark: SparkSession, session: Session, catalogName: String, sch
val schemaPattern = convertSchemaPattern(schema)
val databases = spark.sessionState.catalog.listDatabases(schemaPattern)
val globalTmpViewDb = spark.sessionState.catalog.globalTempViewManager.database
if (schema == null ||
Pattern.compile(convertSchemaPattern(schema, false))
if (Pattern.compile(convertSchemaPattern(schema, false))
.matcher(globalTmpViewDb).matches()) {
iter = (databases :+ globalTmpViewDb).map(Row(_, "")).toList.iterator
} else {

View File

@ -100,7 +100,9 @@ class GetTables(
Seq.empty[Row]
}
iter = (tables ++ views).toList.iterator
} catch onError()
} catch {
onError()
}
}
}

View File

@ -50,8 +50,8 @@ class GetTypeInfo(spark: SparkSession, session: Session)
" null)")
.add("MINIMUM_SCALE", "smallint", nullable = false, "Minimum scale supported")
.add("MAXIMUM_SCALE", "smallint", nullable = false, "Maximum scale supported")
.add("SQL_DATA_TYPE", "int", nullable = true, "Unused")
.add("SQL_DATETIME_SUB", "int", nullable = true, "Unused")
.add(SQL_DATA_TYPE, "int", nullable = true, "Unused")
.add(SQL_DATETIME_SUB, "int", nullable = true, "Unused")
.add(NUM_PREC_RADIX, "int", nullable = false, "Usually 2 or 10")
}
@ -61,24 +61,24 @@ class GetTypeInfo(spark: SparkSession, session: Session)
}
private def toRow(name: String, javaType: Int, precision: Integer = null): Row = {
Row(name, // TYPE_NAME
javaType, // DATA_TYPE
precision, // PRECISION
null, // LITERAL_PREFIX
null, // LITERAL_SUFFIX
null, // CREATE_PARAMS
1.toShort, // NULLABLE
javaType == VARCHAR, // CASE_SENSITIVE
Row(name, // TYPE_NAME
javaType, // DATA_TYPE
precision, // PRECISION
null, // LITERAL_PREFIX
null, // LITERAL_SUFFIX
null, // CREATE_PARAMS
1.toShort, // NULLABLE
javaType == VARCHAR, // CASE_SENSITIVE
if (javaType < 1111) 3.toShort else 0.toShort, // SEARCHABLE
!isNumericType(javaType), // UNSIGNED_ATTRIBUTE
false, // FIXED_PREC_SCALE
false, // AUTO_INCREMENT
null, // LOCAL_TYPE_NAME
0.toShort, // MINIMUM_SCALE
0.toShort, // MAXIMUM_SCALE
null, // SQL_DATA_TYPE
null, // SQL_DATETIME_SUB
if (isNumericType(javaType)) 10 else null // NUM_PREC_RADIX
!isNumericType(javaType), // UNSIGNED_ATTRIBUTE
false, // FIXED_PREC_SCALE
false, // AUTO_INCREMENT
null, // LOCAL_TYPE_NAME
0.toShort, // MINIMUM_SCALE
0.toShort, // MAXIMUM_SCALE
null, // SQL_DATA_TYPE
null, // SQL_DATETIME_SUB
if (isNumericType(javaType)) 10 else null // NUM_PREC_RADIX
)
}
@ -92,7 +92,7 @@ class GetTypeInfo(spark: SparkSession, session: Session)
toRow("BIGINT", BIGINT, 19),
toRow("FLOAT", FLOAT, 7),
toRow("DOUBLE", DOUBLE, 15),
toRow("VARCHAR", VARCHAR),
toRow("STRING", VARCHAR),
toRow("BINARY", BINARY),
toRow("DECIMAL", DECIMAL, 38),
toRow("DATE", DATE),

View File

@ -78,7 +78,7 @@ abstract class SparkOperation(spark: SparkSession, opType: OperationType, sessio
*/
protected def convertIdentifierPattern(pattern: String, datanucleusFormat: Boolean): String = {
if (pattern == null) {
convertPattern("%", datanucleusFormat = true)
convertPattern("%", datanucleusFormat)
} else {
convertPattern(pattern, datanucleusFormat)
}

View File

@ -330,57 +330,60 @@ class SparkOperationSuite extends WithSparkSQLEngine {
statement.execute(ddl)
val metaData = statement.getConnection.getMetaData
val rowSet = metaData.getColumns("", dftSchema, tableName, null)
import java.sql.Types._
val expectedJavaTypes = Seq(BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE,
DECIMAL, DECIMAL, VARCHAR, ARRAY, ARRAY, JAVA_OBJECT, DATE, TIMESTAMP, STRUCT, BINARY,
STRUCT)
Seq("%", null, ".*", "c.*") foreach { pattern =>
val rowSet = metaData.getColumns("", dftSchema, tableName, pattern)
var pos = 0
import java.sql.Types._
val expectedJavaTypes = Seq(BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE,
DECIMAL, DECIMAL, VARCHAR, ARRAY, ARRAY, JAVA_OBJECT, DATE, TIMESTAMP, STRUCT, BINARY,
STRUCT)
while (rowSet.next()) {
assert(rowSet.getString(TABLE_CAT) === null)
assert(rowSet.getString(TABLE_SCHEM) === dftSchema)
assert(rowSet.getString(TABLE_NAME) === tableName)
assert(rowSet.getString(COLUMN_NAME) === schema(pos).name)
assert(rowSet.getInt(DATA_TYPE) === expectedJavaTypes(pos))
assert(rowSet.getString(TYPE_NAME) === schema(pos).dataType.sql)
var pos = 0
val colSize = rowSet.getInt(COLUMN_SIZE)
schema(pos).dataType match {
case StringType | BinaryType | _: ArrayType | _: MapType => assert(colSize === 0)
case StructType(fields) if fields.size == 1 => assert(colSize === 0)
case o => assert(colSize === o.defaultSize)
while (rowSet.next()) {
assert(rowSet.getString(TABLE_CAT) === null)
assert(rowSet.getString(TABLE_SCHEM) === dftSchema)
assert(rowSet.getString(TABLE_NAME) === tableName)
assert(rowSet.getString(COLUMN_NAME) === schema(pos).name)
assert(rowSet.getInt(DATA_TYPE) === expectedJavaTypes(pos))
assert(rowSet.getString(TYPE_NAME) === schema(pos).dataType.sql)
val colSize = rowSet.getInt(COLUMN_SIZE)
schema(pos).dataType match {
case StringType | BinaryType | _: ArrayType | _: MapType => assert(colSize === 0)
case StructType(fields) if fields.length == 1 => assert(colSize === 0)
case o => assert(colSize === o.defaultSize)
}
assert(rowSet.getInt(BUFFER_LENGTH) === 0) // not used
val decimalDigits = rowSet.getInt(DECIMAL_DIGITS)
schema(pos).dataType match {
case BooleanType | _: IntegerType => assert(decimalDigits === 0)
case d: DecimalType => assert(decimalDigits === d.scale)
case FloatType => assert(decimalDigits === 7)
case DoubleType => assert(decimalDigits === 15)
case TimestampType => assert(decimalDigits === 6)
case _ => assert(decimalDigits === 0) // nulls
}
val radix = rowSet.getInt(NUM_PREC_RADIX)
schema(pos).dataType match {
case _: NumericType => assert(radix === 10)
case _ => assert(radix === 0) // nulls
}
assert(rowSet.getInt(NULLABLE) === 1)
assert(rowSet.getString(REMARKS) === pos.toString)
assert(rowSet.getInt(ORDINAL_POSITION) === pos)
assert(rowSet.getString(IS_NULLABLE) === "YES")
assert(rowSet.getString(IS_AUTO_INCREMENT) === "NO")
pos += 1
}
assert(rowSet.getInt(BUFFER_LENGTH) === 0) // not used
val decimalDigits = rowSet.getInt(DECIMAL_DIGITS)
schema(pos).dataType match {
case BooleanType | _: IntegerType => assert(decimalDigits === 0)
case d: DecimalType => assert(decimalDigits === d.scale)
case FloatType => assert(decimalDigits === 7)
case DoubleType => assert(decimalDigits === 15)
case TimestampType => assert(decimalDigits === 6)
case _ => assert(decimalDigits === 0) // nulls
}
val radix = rowSet.getInt(NUM_PREC_RADIX)
schema(pos).dataType match {
case _: NumericType => assert(radix === 10)
case _ => assert(radix === 0) // nulls
}
assert(rowSet.getInt(NULLABLE) === 1)
assert(rowSet.getString(REMARKS) === pos.toString)
assert(rowSet.getInt(ORDINAL_POSITION) === pos)
assert(rowSet.getString(IS_NULLABLE) === "YES")
assert(rowSet.getString(IS_AUTO_INCREMENT) === "NO")
pos += 1
assert(pos === 18, "all columns should have been verified")
}
assert(pos === 18, "all columns should have been verified")
val e = intercept[HiveSQLException](metaData.getColumns(null, "*", null, null))
assert(e.getCause.getMessage === "org.apache.kyuubi.KyuubiSQLException:" +
"Error operating GET_COLUMNS: Dangling meta character '*' near index 0\n*\n^")