/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.schema

import scala.collection.JavaConverters._

import org.apache.hive.service.rpc.thrift._
import org.apache.spark.sql.types._

/**
 * Helpers for translating Spark SQL schema descriptions ([[StructType]],
 * [[StructField]], [[DataType]]) into the Hive thrift protocol's
 * representations ([[TTableSchema]], [[TColumnDesc]], [[TTypeDesc]], ...).
 */
object SchemaHelper {

  /**
   * Thrift type ids for Spark's non-parameterized (singleton) data types.
   * The interval type has no thrift counterpart, so it is reported as a
   * plain string type.
   */
  private val simpleTypeMapping: Map[DataType, TTypeId] = Map(
    NullType -> TTypeId.NULL_TYPE,
    BooleanType -> TTypeId.BOOLEAN_TYPE,
    ByteType -> TTypeId.TINYINT_TYPE,
    ShortType -> TTypeId.SMALLINT_TYPE,
    IntegerType -> TTypeId.INT_TYPE,
    LongType -> TTypeId.BIGINT_TYPE,
    FloatType -> TTypeId.FLOAT_TYPE,
    DoubleType -> TTypeId.DOUBLE_TYPE,
    StringType -> TTypeId.STRING_TYPE,
    DateType -> TTypeId.DATE_TYPE,
    TimestampType -> TTypeId.TIMESTAMP_TYPE,
    BinaryType -> TTypeId.BINARY_TYPE,
    CalendarIntervalType -> TTypeId.STRING_TYPE)

  /**
   * Maps a Spark [[DataType]] to the corresponding thrift [[TTypeId]].
   *
   * @throws IllegalArgumentException for data types with no thrift mapping
   */
  def toTTypeId(typ: DataType): TTypeId = simpleTypeMapping.getOrElse(typ, typ match {
    case _: DecimalType => TTypeId.DECIMAL_TYPE
    case _: ArrayType => TTypeId.ARRAY_TYPE
    case _: MapType => TTypeId.MAP_TYPE
    case _: StructType => TTypeId.STRUCT_TYPE
    // TODO: it is private now, case udt: UserDefinedType => TTypeId.USER_DEFINED_TYPE
    case other =>
      throw new IllegalArgumentException(s"Unrecognized type name: ${other.catalogString}")
  })

  /**
   * Builds the thrift type qualifiers for a Spark [[DataType]]. Only decimal
   * types carry qualifiers (precision and scale); for every other type the
   * returned [[TTypeQualifiers]] is left empty.
   */
  def toTTypeQualifiers(typ: DataType): TTypeQualifiers = {
    val qualifiers = new TTypeQualifiers()
    typ match {
      case decimal: DecimalType =>
        qualifiers.setQualifiers(
          Map(
            TCLIServiceConstants.PRECISION -> TTypeQualifierValue.i32Value(decimal.precision),
            TCLIServiceConstants.SCALE -> TTypeQualifierValue.i32Value(decimal.scale)).asJava)
      case _ => // non-decimal types have no qualifiers
    }
    qualifiers
  }

  /**
   * Wraps a Spark [[DataType]] into a thrift [[TTypeDesc]] holding a single
   * primitive type entry with its qualifiers.
   */
  def toTTypeDesc(typ: DataType): TTypeDesc = {
    val primitiveEntry = new TPrimitiveTypeEntry(toTTypeId(typ))
    primitiveEntry.setTypeQualifiers(toTTypeQualifiers(typ))
    val typeDesc = new TTypeDesc()
    typeDesc.addToTypes(TTypeEntry.primitiveEntry(primitiveEntry))
    typeDesc
  }

  /**
   * Converts one [[StructField]] into a thrift [[TColumnDesc]].
   *
   * @param field the Spark column definition
   * @param pos   zero-based position of the column within its schema
   */
  def toTColumnDesc(field: StructField, pos: Int): TColumnDesc = {
    val columnDesc = new TColumnDesc()
    columnDesc.setColumnName(field.name)
    columnDesc.setTypeDesc(toTTypeDesc(field.dataType))
    // absent comments are transported as the empty string
    columnDesc.setComment(field.getComment().getOrElse(""))
    columnDesc.setPosition(pos)
    columnDesc
  }

  /** Converts a whole Spark [[StructType]] into a thrift [[TTableSchema]]. */
  def toTTableSchema(schema: StructType): TTableSchema = {
    val tTableSchema = new TTableSchema()
    for ((field, idx) <- schema.zipWithIndex) {
      tTableSchema.addToColumns(toTColumnDesc(field, idx))
    }
    tTableSchema
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.kyuubi.schema

import scala.collection.JavaConverters._

import org.apache.hive.service.rpc.thrift.{TCLIServiceConstants, TTypeId}
import org.apache.spark.sql.types._

import org.apache.kyuubi.KyuubiFunSuite

/** Unit tests covering the Spark schema to thrift schema conversions. */
class SchemaHelperSuite extends KyuubiFunSuite {

  import SchemaHelper._

  // Nested struct used as the element of column c15 below.
  val innerSchema: StructType = new StructType()
    .add("a", StringType, nullable = true, "")
    .add("b", IntegerType, nullable = true, "")

  // One column per supported Spark data type; positions are relied on by the
  // assertions in the tests, so keep the ordering stable.
  val outerSchema: StructType = new StructType()
    .add("c0", NullType, true, "this is comment")
    .add("c1", BooleanType, true, "this is comment too")
    .add("c2", ByteType)
    .add("c3", ShortType)
    .add("c4", IntegerType)
    .add("c5", LongType)
    .add("c6", FloatType)
    .add("c7", DoubleType)
    .add("c8", StringType)
    .add("c9", DecimalType(10, 8))
    .add("c10", DateType)
    .add("c11", TimestampType)
    .add("c12", BinaryType)
    .add("c13", ArrayType(LongType))
    .add("c14", MapType(IntegerType, FloatType))
    .add("c15", innerSchema)

  test("toTTypeId") {
    // Expected thrift type id for each column of outerSchema, in order.
    val expectedTypeIds = Seq(
      TTypeId.NULL_TYPE,
      TTypeId.BOOLEAN_TYPE,
      TTypeId.TINYINT_TYPE,
      TTypeId.SMALLINT_TYPE,
      TTypeId.INT_TYPE,
      TTypeId.BIGINT_TYPE,
      TTypeId.FLOAT_TYPE,
      TTypeId.DOUBLE_TYPE,
      TTypeId.STRING_TYPE,
      TTypeId.DECIMAL_TYPE,
      TTypeId.DATE_TYPE,
      TTypeId.TIMESTAMP_TYPE,
      TTypeId.BINARY_TYPE,
      TTypeId.ARRAY_TYPE,
      TTypeId.MAP_TYPE,
      TTypeId.STRUCT_TYPE)
    outerSchema.zip(expectedTypeIds).foreach { case (field, tTypeId) =>
      assert(toTTypeId(field.dataType) === tTypeId)
    }
  }

  test("toTTypeQualifiers") {
    // The decimal column (c9) is the only one expected to carry qualifiers.
    val decimalQualifiers = toTTypeQualifiers(outerSchema(9).dataType).getQualifiers
    assert(decimalQualifiers.size() === 2)
    assert(decimalQualifiers.get(TCLIServiceConstants.PRECISION).getI32Value === 10)
    assert(decimalQualifiers.get(TCLIServiceConstants.SCALE).getI32Value === 8)

    // All non-decimal columns must yield empty qualifiers.
    outerSchema.filterNot(_.dataType == DecimalType(10, 8)).foreach { f =>
      assert(toTTypeQualifiers(f.dataType).getQualifiers === null)
    }
  }

  test("toTTableSchema") {
    val tTableSchema = toTTableSchema(outerSchema)
    assert(tTableSchema.getColumnsSize === outerSchema.size)

    tTableSchema.getColumns.asScala.zipWithIndex.foreach { case (col, pos) =>
      val field = outerSchema(pos)
      assert(col.getColumnName === field.name)
      assert(col.getComment === field.getComment().getOrElse(""))
      assert(col.getPosition === pos)
      val qualifiers =
        col.getTypeDesc.getTypes.get(0).getPrimitiveEntry.getTypeQualifiers.getQualifiers
      if (pos == 9) {
        // the decimal column carries precision/scale qualifiers
        assert(qualifiers.get(TCLIServiceConstants.PRECISION).getI32Value === 10)
        assert(qualifiers.get(TCLIServiceConstants.SCALE).getI32Value === 8)
      } else {
        assert(qualifiers == null)
      }
    }
  }
}