diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala
index ee470f8..cdbba05 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala
@@ -40,7 +40,11 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
       Table(name, Nil, fields : _*)
     }
 
-    def df = {
+    /**
+     * If convertToSchema is true, the data from the generator will be parsed into columns and
+     * converted to `schema`. Otherwise, it just outputs the raw data (as a single STRING column).
+     */
+    def df(convertToSchema: Boolean) = {
       val generatedData = {
         sparkContext.parallelize(1 to partitions, partitions).flatMap { i =>
           val localToolsDir = if (new java.io.File(dsdgen).exists) {
@@ -65,31 +69,39 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
 
       val rows = generatedData.mapPartitions { iter =>
         iter.map { l =>
-          val values = l.split("\\|", -1).dropRight(1).map { v =>
-            if (v.equals("")) {
-              // If the string value is an empty string, we turn it to a null
-              null
-            } else {
-              v
+          if (convertToSchema) {
+            val values = l.split("\\|", -1).dropRight(1).map { v =>
+              if (v.equals("")) {
+                // If the string value is an empty string, we turn it to a null
+                null
+              } else {
+                v
+              }
             }
+            Row.fromSeq(values)
+          } else {
+            Row.fromSeq(Seq(l))
           }
-          Row.fromSeq(values)
         }
       }
 
-      val stringData =
-        sqlContext.createDataFrame(
-          rows,
-          StructType(schema.fields.map(f => StructField(f.name, StringType))))
+      if (convertToSchema) {
+        val stringData =
+          sqlContext.createDataFrame(
+            rows,
+            StructType(schema.fields.map(f => StructField(f.name, StringType))))
 
-      val convertedData = {
-        val columns = schema.fields.map { f =>
-          col(f.name).cast(f.dataType).as(f.name)
+        val convertedData = {
+          val columns = schema.fields.map { f =>
+            col(f.name).cast(f.dataType).as(f.name)
+          }
+          stringData.select(columns: _*)
         }
-        stringData.select(columns: _*)
-      }
 
-      convertedData
+        convertedData
+      } else {
+        sqlContext.createDataFrame(rows, StructType(Seq(StructField("value", StringType))))
+      }
     }
 
     def useDoubleForDecimal(): Table = {
@@ -112,7 +124,7 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
         filterOutNullPartitionValues: Boolean): Unit = {
       val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore
 
-      val data = df
+      val data = df(format != "text")
       val tempTableName = s"${name}_text"
       data.registerTempTable(tempTableName)
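
For reviewers, a minimal usage sketch (not part of the patch) of the two modes the new flag enables. The `Table` instance `t` below is an illustrative assumption, not code from this diff:

    // Hypothetical Table instance for illustration only; real instances are
    // defined elsewhere in Tables.scala.
    val t: Table = ???

    // convertToSchema = true (any format other than "text"): each dsdgen line
    // such as "1|foo|" is split on '|', empty fields become null, and every
    // field is cast to the type declared in `schema`.
    val parsed = t.df(convertToSchema = true)

    // convertToSchema = false (format == "text"): each raw generator line is
    // kept whole as one row of a single STRING column named "value".
    val raw = t.df(convertToSchema = false)

This mirrors the call-site change in genData, where df(format != "text") skips parsing only for the raw text output path.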