From 9c8dfda8e2ff761987c6d592c4e931ff6bea379e Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 23 Mar 2018 15:06:22 +0800 Subject: [PATCH] add ut RowSetSuite --- .../kyuubi/operation/KyuubiOperation.scala | 8 +- .../scala/yaooqinn/kyuubi/schema/RowSet.scala | 4 + .../yaooqinn/kyuubi/schema/RowSetSuite.scala | 120 ++++++++++++++++++ 3 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 src/test/scala/yaooqinn/kyuubi/schema/RowSetSuite.scala diff --git a/src/main/scala/yaooqinn/kyuubi/operation/KyuubiOperation.scala b/src/main/scala/yaooqinn/kyuubi/operation/KyuubiOperation.scala index 51867eb3d..f19319000 100644 --- a/src/main/scala/yaooqinn/kyuubi/operation/KyuubiOperation.scala +++ b/src/main/scala/yaooqinn/kyuubi/operation/KyuubiOperation.scala @@ -220,9 +220,11 @@ class KyuubiOperation(session: KyuubiSession, statement: String) extends Logging validateDefaultFetchOrientation(order) assertState(FINISHED) setHasResultSet(true) - val taken = iter.take(maxRowsL.toInt) - val remained = iter.drop(maxRowsL.toInt) - iter = if (order == FetchOrientation.FETCH_FIRST) taken else remained + val taken = if (order == FetchOrientation.FETCH_FIRST) { + result.toLocalIterator().asScala.take(maxRowsL.toInt) + } else { + iter.take(maxRowsL.toInt) + } RowSet(getResultSetSchema, taken) } diff --git a/src/main/scala/yaooqinn/kyuubi/schema/RowSet.scala b/src/main/scala/yaooqinn/kyuubi/schema/RowSet.scala index b9e9a588f..88b4174a8 100644 --- a/src/main/scala/yaooqinn/kyuubi/schema/RowSet.scala +++ b/src/main/scala/yaooqinn/kyuubi/schema/RowSet.scala @@ -23,6 +23,10 @@ import org.apache.hive.service.cli.thrift._ import org.apache.spark.sql.{Row, SparkSQLUtils} import org.apache.spark.sql.types._ +/** + * A result set of Spark's [[Row]]s with its [[StructType]] as its schema, with the ability of + * transform to [[TRowSet]]. + */ case class RowSet(types: StructType, rows: Iterator[Row]) { def toTRowSet: TRowSet = new TRowSet(0, toTRows(rows).asJava) diff --git a/src/test/scala/yaooqinn/kyuubi/schema/RowSetSuite.scala b/src/test/scala/yaooqinn/kyuubi/schema/RowSetSuite.scala new file mode 100644 index 000000000..eff148511 --- /dev/null +++ b/src/test/scala/yaooqinn/kyuubi/schema/RowSetSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package yaooqinn.kyuubi.schema + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.StructType + +class RowSetSuite extends SparkFunSuite { + + test("1") { + val maxRows: Int = 5 + + val schema = new StructType().add("a", "int").add("b", "string") + + val rows = Seq( + Row(1, "11"), + Row(2, "22"), + Row(3, "33"), + Row(4, "44"), + Row(5, "55"), + Row(6, "66"), + Row(7, "77"), + Row(8, "88"), + Row(9, "99"), + Row(10, "000"), + Row(11, "111"), + Row(12, "222"), + Row(13, "333"), + Row(14, "444"), + Row(15, "555"), + Row(16, "666")) + + // fetch next + val rowIterator = rows.iterator + var taken = rowIterator.take(maxRows) + var tRowSet = RowSet(schema, taken).toTRowSet + assert(tRowSet.getRowsSize === 5) + assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "11") + assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "22") + assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "33") + assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "44") + assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "55") + + taken = rowIterator.take(maxRows) + tRowSet = RowSet(schema, taken).toTRowSet + assert(tRowSet.getRowsSize === 5) + assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "66") + assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "77") + assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "88") + assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "99") + assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "000") + + taken = rowIterator.take(maxRows) + tRowSet = RowSet(schema, taken).toTRowSet + assert(tRowSet.getRowsSize === 5) + assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "111") + assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "222") + assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "333") + assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "444") + assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "555") + + taken = rowIterator.take(maxRows) + tRowSet = RowSet(schema, taken).toTRowSet + assert(tRowSet.getRowsSize === 1) + assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "666") + intercept[IndexOutOfBoundsException](tRowSet.getRows.get(1)) + + assert(rowIterator.isEmpty) + + // fetch first + val rowIterator2 = rows.iterator + + val (itr1, itr2) = rowIterator2.take(maxRows).duplicate + val resultList = itr2.toList + + taken = itr1 + tRowSet = RowSet(schema, taken).toTRowSet + assert(tRowSet.getRowsSize === 5) + assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "11") + assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "22") + assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "33") + assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "44") + assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "55") + + taken = resultList.iterator + tRowSet = RowSet(schema, taken).toTRowSet + assert(tRowSet.getRowsSize === 5) + assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "11") + assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "22") + assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "33") + assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "44") + assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "55") + + taken = resultList.iterator + tRowSet = RowSet(schema, taken).toTRowSet + assert(tRowSet.getRowsSize === 5) + assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "11") + assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "22") + assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "33") + assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "44") + assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "55") + } + +}