diff --git a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala
index c2eace289..5f2baf1a6 100644
--- a/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala
+++ b/externals/kyuubi-flink-sql-engine/src/main/scala/org/apache/kyuubi/engine/flink/schema/RowSet.scala
@@ -19,17 +19,22 @@ package org.apache.kyuubi.engine.flink.schema
 
 import java.{lang, util}
 import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
+import java.time.{LocalDate, LocalDateTime}
 import java.util.Collections
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
 import scala.language.implicitConversions
 
 import org.apache.flink.table.catalog.Column
-import org.apache.flink.table.types.logical.{DecimalType, _}
+import org.apache.flink.table.types.logical._
 import org.apache.flink.types.Row
 import org.apache.hive.service.rpc.thrift._
 
 import org.apache.kyuubi.engine.flink.result.ResultSet
+import org.apache.kyuubi.util.RowSetUtils.{dateFormatter, timestampFormatter}
 
 object RowSet {
 
@@ -70,7 +75,8 @@ object RowSet {
       ordinal: Int,
       row: Row,
       resultSet: ResultSet): TColumnValue = {
-    val logicalType = resultSet.getColumns.get(ordinal).getDataType.getLogicalType
+    val column = resultSet.getColumns.get(ordinal)
+    val logicalType = column.getDataType.getLogicalType
 
     logicalType match {
       case _: BooleanType =>
@@ -80,6 +86,12 @@
         }
         TColumnValue.boolVal(boolValue)
       case _: TinyIntType =>
+        val tByteValue = new TByteValue
+        if (row.getField(ordinal) != null) {
+          tByteValue.setValue(row.getField(ordinal).asInstanceOf[Byte])
+        }
+        TColumnValue.byteVal(tByteValue)
+      case _: SmallIntType =>
         val tI16Value = new TI16Value
         if (row.getField(ordinal) != null) {
           tI16Value.setValue(row.getField(ordinal).asInstanceOf[Short])
@@ -88,7 +100,7 @@
       case _: IntType =>
         val tI32Value = new TI32Value
         if (row.getField(ordinal) != null) {
-          tI32Value.setValue(row.getField(ordinal).asInstanceOf[Short])
+          tI32Value.setValue(row.getField(ordinal).asInstanceOf[Int])
         }
         TColumnValue.i32Val(tI32Value)
       case _: BigIntType =>
@@ -112,21 +124,23 @@
       case _: VarCharType =>
         val tStringValue = new TStringValue
         if (row.getField(ordinal) != null) {
-          tStringValue.setValue(row.getField(ordinal).asInstanceOf[String])
+          val stringValue = row.getField(ordinal).asInstanceOf[String]
+          tStringValue.setValue(stringValue)
         }
         TColumnValue.stringVal(tStringValue)
       case _: CharType =>
         val tStringValue = new TStringValue
         if (row.getField(ordinal) != null) {
-          tStringValue.setValue(row.getField(ordinal).asInstanceOf[String])
+          val stringValue = row.getField(ordinal).asInstanceOf[String]
+          tStringValue.setValue(stringValue)
         }
         TColumnValue.stringVal(tStringValue)
-      case _ =>
-        val tStrValue = new TStringValue
+      case t =>
+        val tStringValue = new TStringValue
         if (row.getField(ordinal) != null) {
-          // TODO to be done
+          tStringValue.setValue(toHiveString((row.getField(ordinal), t)))
         }
-        TColumnValue.stringVal(tStrValue)
+        TColumnValue.stringVal(tStringValue)
     }
   }
 
@@ -141,15 +155,29 @@
       case _: BooleanType =>
         val values = getOrSetAsNull[lang.Boolean](rows, ordinal, nulls, true)
         TColumn.boolVal(new TBoolColumn(values, nulls))
       case _: TinyIntType =>
+        val values = getOrSetAsNull[lang.Byte](rows, ordinal, nulls, 0.toByte)
+        TColumn.byteVal(new TByteColumn(values, nulls))
+      case _: SmallIntType =>
         val values = getOrSetAsNull[lang.Short](rows, ordinal, nulls, 0.toShort)
         TColumn.i16Val(new TI16Column(values, nulls))
+      case _: IntType =>
+        val values = getOrSetAsNull[lang.Integer](rows, ordinal, nulls, 0)
+        TColumn.i32Val(new TI32Column(values, nulls))
+      case _: BigIntType =>
+        val values = getOrSetAsNull[lang.Long](rows, ordinal, nulls, 0L)
+        TColumn.i64Val(new TI64Column(values, nulls))
+      case _: FloatType =>
+        val values = getOrSetAsNull[lang.Double](rows, ordinal, nulls, 0.0)
+        TColumn.doubleVal(new TDoubleColumn(values, nulls))
+      case _: DoubleType =>
+        val values = getOrSetAsNull[lang.Double](rows, ordinal, nulls, 0.0)
+        TColumn.doubleVal(new TDoubleColumn(values, nulls))
       case _: VarCharType =>
         val values = getOrSetAsNull[String](rows, ordinal, nulls, "")
         TColumn.stringVal(new TStringColumn(values, nulls))
       case _: CharType =>
         val values = getOrSetAsNull[String](rows, ordinal, nulls, "")
         TColumn.stringVal(new TStringColumn(values, nulls))
-
       case _ =>
         val values = rows.zipWithIndex.toList.map { case (row, i) =>
           nulls.set(i, row.getField(ordinal) == null)
@@ -209,6 +237,12 @@
       Map(
         TCLIServiceConstants.PRECISION -> TTypeQualifierValue.i32Value(d.getPrecision),
         TCLIServiceConstants.SCALE -> TTypeQualifierValue.i32Value(d.getScale)).asJava
+      case v: VarCharType =>
+        Map(TCLIServiceConstants.CHARACTER_MAXIMUM_LENGTH ->
+          TTypeQualifierValue.i32Value(v.getLength)).asJava
+      case ch: CharType =>
+        Map(TCLIServiceConstants.CHARACTER_MAXIMUM_LENGTH ->
+          TTypeQualifierValue.i32Value(ch.getLength)).asJava
       case _ => Collections.emptyMap[String, TTypeQualifierValue]()
     }
     ret.setQualifiers(qualifiers)
@@ -220,15 +254,31 @@
       case _: BooleanType => TTypeId.BOOLEAN_TYPE
       case _: FloatType => TTypeId.FLOAT_TYPE
       case _: DoubleType => TTypeId.DOUBLE_TYPE
-      case _: VarCharType => TTypeId.STRING_TYPE
-      case _: CharType => TTypeId.STRING_TYPE
+      case _: VarCharType => TTypeId.VARCHAR_TYPE
+      case _: CharType => TTypeId.CHAR_TYPE
+      case _: TinyIntType => TTypeId.TINYINT_TYPE
+      case _: SmallIntType => TTypeId.SMALLINT_TYPE
+      case _: IntType => TTypeId.INT_TYPE
+      case _: BigIntType => TTypeId.BIGINT_TYPE
       case _: DecimalType => TTypeId.DECIMAL_TYPE
+      case _: DateType => TTypeId.DATE_TYPE
+      case _: TimestampType => TTypeId.TIMESTAMP_TYPE
+      case _: ArrayType => TTypeId.ARRAY_TYPE
+      case _: MapType => TTypeId.MAP_TYPE
+      case _: RowType => TTypeId.STRUCT_TYPE
+      case _: BinaryType => TTypeId.BINARY_TYPE
+      case t @ (_: ZonedTimestampType | _: LocalZonedTimestampType | _: MultisetType |
+          _: YearMonthIntervalType | _: DayTimeIntervalType) =>
+        throw new IllegalArgumentException(
+          "Flink data type `%s` is not supported currently".format(t.asSummaryString()),
+          null)
       case other =>
         throw new IllegalArgumentException(s"Unrecognized type name: ${other.asSummaryString()}")
     }
 
   /**
    * A simpler impl of Flink's toHiveString
+   * TODO: support Flink's new data type system
    */
   def toHiveString(dataWithType: (Any, LogicalType)): String = {
     dataWithType match {
@@ -236,9 +286,51 @@
         // Only match nulls in nested type values
         "null"
 
+      case (d: Int, _: DateType) =>
+        dateFormatter.format(LocalDate.ofEpochDay(d))
+
+      case (ld: LocalDate, _: DateType) =>
+        dateFormatter.format(ld)
+
+      case (d: Date, _: DateType) =>
+        dateFormatter.format(d.toInstant)
+
+      case (ldt: LocalDateTime, _: TimestampType) =>
+        timestampFormatter.format(ldt)
+
+      case (ts: Timestamp, _: TimestampType) =>
+        timestampFormatter.format(ts.toInstant)
+
       case (decimal: java.math.BigDecimal, _: DecimalType) =>
         decimal.toPlainString
 
+      case (a: Array[_], t: ArrayType) =>
+        a.map(v => toHiveString((v, t.getElementType))).toSeq.mkString(
+          "[",
+          ",",
+          "]")
+
+      case (m: Map[_, _], t: MapType) =>
+        m.map {
+          case (k, v) =>
+            toHiveString((k, t.getKeyType)) + ":" + toHiveString((v, t.getValueType))
+        }
+          .toSeq.mkString("{", ",", "}")
+
+      case (r: Row, t: RowType) =>
+        val lb = ListBuffer[String]()
+        for (i <- 0 until r.getArity) {
+          lb += s"""${t.getTypeAt(i).toString}:${toHiveString((r.getField(i), t.getTypeAt(i)))}"""
+        }
+        lb.toList.mkString("{", ",", "}")
+
+      case (s: String, _ @(_: VarCharType | _: CharType)) =>
+        // Only match string in nested type values
+        "\"" + s + "\""
+
+      case (bin: Array[Byte], _: BinaryType) =>
+        new String(bin, StandardCharsets.UTF_8)
+
       case (other, _) =>
         other.toString
     }
diff --git a/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperationSuite.scala b/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperationSuite.scala
index 8801a974e..049970038 100644
--- a/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperationSuite.scala
+++ b/externals/kyuubi-flink-sql-engine/src/test/scala/org/apache/kyuubi/engine/flink/operation/FlinkOperationSuite.scala
@@ -198,6 +198,134 @@ class FlinkOperationSuite extends WithFlinkSQLEngine with HiveJDBCTestHelper {
     }
   }
 
+  test("execute statement - select varchar/char") {
+    withJdbcStatement() { statement =>
+      val resultSet =
+        statement.executeQuery("select cast('varchar10' as varchar(10)), " +
+          "cast('char16' as char(16))")
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.VARCHAR)
+      assert(metaData.getPrecision(1) === 10)
+      assert(metaData.getColumnType(2) === java.sql.Types.CHAR)
+      assert(metaData.getPrecision(2) === 16)
+      assert(resultSet.next())
+      assert(resultSet.getString(1) === "varchar10")
+      assert(resultSet.getString(2) === "char16          ")
+    }
+  }
+
+  test("execute statement - select tinyint") {
+    withJdbcStatement() { statement =>
+      val resultSet = statement.executeQuery("select cast(1 as tinyint)")
+      assert(resultSet.next())
+      assert(resultSet.getByte(1) === 1)
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.TINYINT)
+    }
+  }
+
+  test("execute statement - select smallint") {
+    withJdbcStatement() { statement =>
+      val resultSet = statement.executeQuery("select cast(1 as smallint)")
+      assert(resultSet.next())
+      assert(resultSet.getShort(1) === 1)
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.SMALLINT)
+    }
+  }
+
+  test("execute statement - select int") {
+    withJdbcStatement() { statement =>
+      val resultSet = statement.executeQuery("select 1")
+      assert(resultSet.next())
+      assert(resultSet.getInt(1) === 1)
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.INTEGER)
+    }
+  }
+
+  test("execute statement - select bigint") {
+    withJdbcStatement() { statement =>
+      val resultSet = statement.executeQuery("select cast(1 as bigint)")
+      assert(resultSet.next())
+      assert(resultSet.getLong(1) === 1)
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.BIGINT)
+    }
+  }
+
+  test("execute statement - select date") {
+    withJdbcStatement() { statement =>
+      val resultSet = statement.executeQuery("select date '2022-01-01'")
+      assert(resultSet.next())
+      assert(resultSet.getDate(1).toLocalDate.toString == "2022-01-01")
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.DATE)
+    }
+  }
+
+  test("execute statement - select timestamp") {
+    withJdbcStatement() { statement =>
+      val resultSet =
+        statement.executeQuery(
+          "select timestamp '2022-01-01 00:00:00', timestamp '2022-01-01 00:00:00.123456789'")
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.TIMESTAMP)
+      assert(metaData.getColumnType(2) === java.sql.Types.TIMESTAMP)
+      // 9 digits for fraction of seconds
+      assert(metaData.getPrecision(1) == 29)
+      assert(metaData.getPrecision(2) == 29)
+      assert(resultSet.next())
+      assert(resultSet.getTimestamp(1).toString == "2022-01-01 00:00:00.0")
+      assert(resultSet.getTimestamp(2).toString == "2022-01-01 00:00:00.123456789")
+    }
+  }
+
+  test("execute statement - select array") {
+    withJdbcStatement() { statement =>
+      val resultSet =
+        statement.executeQuery("select array ['v1', 'v2', 'v3']")
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.ARRAY)
+      assert(resultSet.next())
+      assert(resultSet.getObject(1).toString == "[\"v1\",\"v2\",\"v3\"]")
+    }
+  }
+
+  test("execute statement - select map") {
+    withJdbcStatement() { statement =>
+      val resultSet =
+        statement.executeQuery("select map ['k1', 'v1', 'k2', 'v2']")
+      assert(resultSet.next())
+      assert(resultSet.getString(1) == "{k1=v1, k2=v2}")
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.JAVA_OBJECT)
+    }
+  }
+
+  test("execute statement - select row") {
+    withJdbcStatement() { statement =>
+      val resultSet =
+        statement.executeQuery("select (1, '2', true)")
+      assert(resultSet.next())
+      assert(
+        resultSet.getString(1) == "{INT NOT NULL:1,CHAR(1) NOT NULL:\"2\",BOOLEAN NOT NULL:true}")
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.STRUCT)
+    }
+  }
+
+  test("execute statement - select binary") {
+    withJdbcStatement() { statement =>
+      val resultSet = statement.executeQuery("select encode('kyuubi', 'UTF-8')")
+      assert(resultSet.next())
+      assert(
+        resultSet.getString(1) == "kyuubi")
+      val metaData = resultSet.getMetaData
+      assert(metaData.getColumnType(1) === java.sql.Types.BINARY)
+    }
+  }
+
   test("execute statement - show functions") {
     withJdbcStatement() { statement =>
       val resultSet = statement.executeQuery("show functions")
@@ -270,11 +398,15 @@ class FlinkOperationSuite extends WithFlinkSQLEngine with HiveJDBCTestHelper {
     })
   }
 
-  ignore("execute statement - insert into") {
-    // TODO: ignore temporally due to KYUUBI #1704
+  test("execute statement - insert into") {
     withMultipleConnectionJdbcStatement()({ statement =>
       statement.executeQuery("create table tbl_a (a int) with ('connector' = 'blackhole')")
-      statement.executeUpdate("insert into tbl_a select 1")
+      val resultSet = statement.executeQuery("insert into tbl_a select 1")
+      val metadata = resultSet.getMetaData
+      assert(metadata.getColumnName(1) == "default_catalog.default_database.tbl_a")
+      assert(metadata.getColumnType(1) == java.sql.Types.BIGINT)
+      assert(resultSet.next())
+      assert(resultSet.getLong(1) == -1L)
     })
   }
 }