Support Spark StructType to TTableSchema
This commit is contained in:
parent
934a9d3158
commit
eacaa608e1
@ -0,0 +1,87 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.kyuubi.schema
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.hive.service.rpc.thrift._
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
 * Conversions from Spark SQL schema types to the HiveServer2 Thrift schema structures.
 */
object SchemaHelper {

  /**
   * Resolves the Thrift type id that corresponds to a Spark SQL data type.
   *
   * @param typ the Spark SQL data type to translate
   * @return the [[TTypeId]] mapped to `typ`
   * @throws IllegalArgumentException if `typ` has no mapping defined here
   */
  def toTTypeId(typ: DataType): TTypeId = typ match {
    // Atomic types.
    case NullType => TTypeId.NULL_TYPE
    case BooleanType => TTypeId.BOOLEAN_TYPE
    case ByteType => TTypeId.TINYINT_TYPE
    case ShortType => TTypeId.SMALLINT_TYPE
    case IntegerType => TTypeId.INT_TYPE
    case LongType => TTypeId.BIGINT_TYPE
    case FloatType => TTypeId.FLOAT_TYPE
    case DoubleType => TTypeId.DOUBLE_TYPE
    case StringType => TTypeId.STRING_TYPE
    case _: DecimalType => TTypeId.DECIMAL_TYPE
    case DateType => TTypeId.DATE_TYPE
    case TimestampType => TTypeId.TIMESTAMP_TYPE
    case BinaryType => TTypeId.BINARY_TYPE
    // Calendar intervals are reported to clients as strings.
    case CalendarIntervalType => TTypeId.STRING_TYPE
    // Nested types.
    case _: ArrayType => TTypeId.ARRAY_TYPE
    case _: MapType => TTypeId.MAP_TYPE
    case _: StructType => TTypeId.STRUCT_TYPE
    // TODO: UserDefinedType is private for now; once it is accessible, add
    //   case udt: UserDefinedType => TTypeId.USER_DEFINED_TYPE
    case other =>
      throw new IllegalArgumentException(s"Unrecognized type name: ${other.catalogString}")
  }

  /**
   * Builds the Thrift type qualifiers for a Spark SQL data type.
   *
   * Only decimal types carry qualifiers (precision and scale); for every other type the
   * returned object is left without a qualifier map.
   *
   * @param typ the Spark SQL data type to describe
   * @return a new [[TTypeQualifiers]] instance
   */
  def toTTypeQualifiers(typ: DataType): TTypeQualifiers = {
    val typeQualifiers = new TTypeQualifiers()
    typ match {
      case decimal: DecimalType =>
        val precision = TTypeQualifierValue.i32Value(decimal.precision)
        val scale = TTypeQualifierValue.i32Value(decimal.scale)
        typeQualifiers.setQualifiers(
          Map(
            TCLIServiceConstants.PRECISION -> precision,
            TCLIServiceConstants.SCALE -> scale).asJava)
      case _ => // nothing to qualify
    }
    typeQualifiers
  }

  /**
   * Wraps a Spark SQL data type into a Thrift type descriptor.
   *
   * The descriptor always holds a single primitive type entry, including for array, map
   * and struct types.
   *
   * @param typ the Spark SQL data type to describe
   * @return a new [[TTypeDesc]] with one entry
   * @throws IllegalArgumentException if `typ` has no Thrift type id mapping
   */
  def toTTypeDesc(typ: DataType): TTypeDesc = {
    val primitive = new TPrimitiveTypeEntry(toTTypeId(typ))
    primitive.setTypeQualifiers(toTTypeQualifiers(typ))

    val descriptor = new TTypeDesc()
    descriptor.addToTypes(TTypeEntry.primitiveEntry(primitive))
    descriptor
  }

  /**
   * Converts one Spark SQL field into a Thrift column descriptor.
   *
   * @param field the field to convert; a missing comment becomes the empty string
   * @param pos the position stored on the column, passed through unchanged
   * @return a new [[TColumnDesc]] carrying name, type, comment and position
   */
  def toTColumnDesc(field: StructField, pos: Int): TColumnDesc = {
    val column = new TColumnDesc()
    column.setColumnName(field.name)
    column.setTypeDesc(toTTypeDesc(field.dataType))
    column.setComment(field.getComment().getOrElse(""))
    column.setPosition(pos)
    column
  }

  /**
   * Converts a Spark SQL schema into a Thrift table schema.
   *
   * Columns are added in field order and each one is positioned by its zero-based index
   * within `schema`.
   *
   * @param schema the Spark SQL schema to convert
   * @return a new [[TTableSchema]] with one column per field
   */
  def toTTableSchema(schema: StructType): TTableSchema = {
    val tableSchema = new TTableSchema()
    for ((field, ordinal) <- schema.zipWithIndex) {
      tableSchema.addToColumns(toTColumnDesc(field, ordinal))
    }
    tableSchema
  }
}
|
||||
@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.kyuubi.schema
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import org.apache.hive.service.rpc.thrift.{TCLIServiceConstants, TTypeId}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
import org.apache.kyuubi.KyuubiFunSuite
|
||||
|
||||
/**
 * Tests for the [[SchemaHelper]] conversions, driven by one schema that holds a field for
 * each data type covered below.
 */
class SchemaHelperSuite extends KyuubiFunSuite {

  import SchemaHelper._

  // Nested struct used as the type of the last column of `outerSchema`.
  val innerSchema: StructType = new StructType()
    .add("a", StringType, nullable = true, "")
    .add("b", IntegerType, nullable = true, "")

  // Fixture schema; only the first two columns carry a comment, and `c9` is the decimal.
  val outerSchema: StructType = new StructType()
    .add("c0", NullType, nullable = true, "this is comment")
    .add("c1", BooleanType, nullable = true, "this is comment too")
    .add("c2", ByteType)
    .add("c3", ShortType)
    .add("c4", IntegerType)
    .add("c5", LongType)
    .add("c6", FloatType)
    .add("c7", DoubleType)
    .add("c8", StringType)
    .add("c9", DecimalType(10, 8))
    .add("c10", DateType)
    .add("c11", TimestampType)
    .add("c12", BinaryType)
    .add("c13", ArrayType(LongType))
    .add("c14", MapType(IntegerType, FloatType))
    .add("c15", innerSchema)

  test("toTTypeId") {
    // Expected type ids, listed in the same order as the columns of `outerSchema`.
    val expectedTypeIds = Seq(
      TTypeId.NULL_TYPE,
      TTypeId.BOOLEAN_TYPE,
      TTypeId.TINYINT_TYPE,
      TTypeId.SMALLINT_TYPE,
      TTypeId.INT_TYPE,
      TTypeId.BIGINT_TYPE,
      TTypeId.FLOAT_TYPE,
      TTypeId.DOUBLE_TYPE,
      TTypeId.STRING_TYPE,
      TTypeId.DECIMAL_TYPE,
      TTypeId.DATE_TYPE,
      TTypeId.TIMESTAMP_TYPE,
      TTypeId.BINARY_TYPE,
      TTypeId.ARRAY_TYPE,
      TTypeId.MAP_TYPE,
      TTypeId.STRUCT_TYPE)

    for ((field, expected) <- outerSchema.zip(expectedTypeIds)) {
      assert(toTTypeId(field.dataType) === expected)
    }
  }

  test("toTTypeQualifiers") {
    // The decimal column exposes exactly precision and scale.
    val decimalQualifiers = toTTypeQualifiers(outerSchema(9).dataType).getQualifiers
    assert(decimalQualifiers.size() === 2)
    assert(decimalQualifiers.get(TCLIServiceConstants.PRECISION).getI32Value === 10)
    assert(decimalQualifiers.get(TCLIServiceConstants.SCALE).getI32Value === 8)

    // Every other column has no qualifier map at all.
    outerSchema
      .filterNot(_.dataType == DecimalType(10, 8))
      .foreach { field =>
        assert(toTTypeQualifiers(field.dataType).getQualifiers === null)
      }
  }

  test("toTTableSchema") {
    val tableSchema = toTTableSchema(outerSchema)
    assert(tableSchema.getColumnsSize === outerSchema.size)

    for ((column, index) <- tableSchema.getColumns.asScala.zipWithIndex) {
      val expectedField = outerSchema(index)
      assert(column.getColumnName === expectedField.name)
      assert(column.getComment === expectedField.getComment().getOrElse(""))
      assert(column.getPosition === index)

      val primitiveEntry = column.getTypeDesc.getTypes.get(0).getPrimitiveEntry
      val qualifiers = primitiveEntry.getTypeQualifiers.getQualifiers
      // Index 9 is the decimal column `c9`, the only one expected to carry qualifiers.
      if (index != 9) {
        assert(qualifiers == null)
      } else {
        assert(qualifiers.get(TCLIServiceConstants.PRECISION).getI32Value === 10)
        assert(qualifiers.get(TCLIServiceConstants.SCALE).getI32Value === 8)
      }
    }
  }
}
|
||||
Loading…
Reference in New Issue
Block a user