add ut RowSetSuite

This commit is contained in:
Kent Yao 2018-03-23 15:06:22 +08:00
parent e4a74b7989
commit 9c8dfda8e2
3 changed files with 129 additions and 3 deletions

View File

@ -220,9 +220,11 @@ class KyuubiOperation(session: KyuubiSession, statement: String) extends Logging
validateDefaultFetchOrientation(order)
assertState(FINISHED)
setHasResultSet(true)
val taken = iter.take(maxRowsL.toInt)
val remained = iter.drop(maxRowsL.toInt)
iter = if (order == FetchOrientation.FETCH_FIRST) taken else remained
val taken = if (order == FetchOrientation.FETCH_FIRST) {
result.toLocalIterator().asScala.take(maxRowsL.toInt)
} else {
iter.take(maxRowsL.toInt)
}
RowSet(getResultSetSchema, taken)
}

View File

@ -23,6 +23,10 @@ import org.apache.hive.service.cli.thrift._
import org.apache.spark.sql.{Row, SparkSQLUtils}
import org.apache.spark.sql.types._
/**
* A result set of Spark's [[Row]]s with its [[StructType]] as its schema, with the ability of
* transform to [[TRowSet]].
*/
case class RowSet(types: StructType, rows: Iterator[Row]) {
def toTRowSet: TRowSet = new TRowSet(0, toTRows(rows).asJava)

View File

@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package yaooqinn.kyuubi.schema
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
class RowSetSuite extends SparkFunSuite {
test("1") {
val maxRows: Int = 5
val schema = new StructType().add("a", "int").add("b", "string")
val rows = Seq(
Row(1, "11"),
Row(2, "22"),
Row(3, "33"),
Row(4, "44"),
Row(5, "55"),
Row(6, "66"),
Row(7, "77"),
Row(8, "88"),
Row(9, "99"),
Row(10, "000"),
Row(11, "111"),
Row(12, "222"),
Row(13, "333"),
Row(14, "444"),
Row(15, "555"),
Row(16, "666"))
// fetch next
val rowIterator = rows.iterator
var taken = rowIterator.take(maxRows)
var tRowSet = RowSet(schema, taken).toTRowSet
assert(tRowSet.getRowsSize === 5)
assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "11")
assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "22")
assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "33")
assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "44")
assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "55")
taken = rowIterator.take(maxRows)
tRowSet = RowSet(schema, taken).toTRowSet
assert(tRowSet.getRowsSize === 5)
assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "66")
assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "77")
assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "88")
assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "99")
assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "000")
taken = rowIterator.take(maxRows)
tRowSet = RowSet(schema, taken).toTRowSet
assert(tRowSet.getRowsSize === 5)
assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "111")
assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "222")
assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "333")
assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "444")
assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "555")
taken = rowIterator.take(maxRows)
tRowSet = RowSet(schema, taken).toTRowSet
assert(tRowSet.getRowsSize === 1)
assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "666")
intercept[IndexOutOfBoundsException](tRowSet.getRows.get(1))
assert(rowIterator.isEmpty)
// fetch first
val rowIterator2 = rows.iterator
val (itr1, itr2) = rowIterator2.take(maxRows).duplicate
val resultList = itr2.toList
taken = itr1
tRowSet = RowSet(schema, taken).toTRowSet
assert(tRowSet.getRowsSize === 5)
assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "11")
assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "22")
assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "33")
assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "44")
assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "55")
taken = resultList.iterator
tRowSet = RowSet(schema, taken).toTRowSet
assert(tRowSet.getRowsSize === 5)
assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "11")
assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "22")
assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "33")
assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "44")
assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "55")
taken = resultList.iterator
tRowSet = RowSet(schema, taken).toTRowSet
assert(tRowSet.getRowsSize === 5)
assert(tRowSet.getRows.get(0).getColVals.get(1).getStringVal.getValue === "11")
assert(tRowSet.getRows.get(1).getColVals.get(1).getStringVal.getValue === "22")
assert(tRowSet.getRows.get(2).getColVals.get(1).getStringVal.getValue === "33")
assert(tRowSet.getRows.get(3).getColVals.get(1).getStringVal.getValue === "44")
assert(tRowSet.getRows.get(4).getColVals.get(1).getStringVal.getValue === "55")
}
}