diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala index aee846d..a2646f1 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala @@ -17,28 +17,16 @@ package com.databricks.spark.sql.perf.tpcds import java.io.File -import java.text.SimpleDateFormat -import java.util.Date - -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import scala.sys.process._ - -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} - -import com.databricks.spark.sql.perf._ -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext, RecordWriter, Job} -import org.apache.spark.SerializableWritable -import org.apache.spark.sql.{SaveMode, Column, ColumnName, SQLContext} -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.Logging +import org.apache.spark.sql.{SaveMode, SQLContext} +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StringType, StructField, StructType} -import parquet.hadoop.ParquetOutputFormat -import parquet.hadoop.util.ContextUtil -class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable { +class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable with Logging { import sqlContext.implicits._ def sparkContext = sqlContext.sparkContext @@ -71,11 +59,16 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend generatedData.setName(s"$name, sf=$scaleFactor, strings") val rows = generatedData.mapPartitions { iter => - val currentRow = new GenericMutableRow(schema.fields.size) iter.map { l => - (0 until schema.fields.length).foreach(currentRow.setNullAt) - l.split("\\|", -1).zipWithIndex.dropRight(1).foreach { case (f, i) => currentRow(i) = f} - currentRow: Row + val values = l.split("\\|", -1).dropRight(1).map { v => + if (v.equals("")) { + // If the string value is an empty string, we turn it to a null + null + } else { + v + } + } + Row.fromSeq(values) } } @@ -86,8 +79,7 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend val convertedData = { val columns = schema.fields.map { f => - val columnName = new ColumnName(f.name) - columnName.cast(f.dataType).as(f.name) + col(f.name).cast(f.dataType).as(f.name) } stringData.select(columns: _*) } @@ -104,29 +96,36 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend } else { withFormatAndMode.partitionBy(partitionColumns : _*) } - + logInfo(s"Generating table $name in database to $location with save mode $mode.") withPartitionColumns.save(location) } def createExternalTables(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = { val qualifiedTableName = databaseName + "." + name + val tableExists = sqlContext.tableNames(databaseName).contains(name) if (overwrite) { sqlContext.sql(s"DROP TABLE IF EXISTS $databaseName.$name") } - sqlContext.createExternalTable(qualifiedTableName, location, format) + if (!tableExists || overwrite) { + logInfo(s"Creating external table $name in database $databaseName.") + sqlContext.createExternalTable(qualifiedTableName, location, format) + } } } def genData(location: String, format: String, overwrite: Boolean): Unit = { tables.foreach { table => - val tableLocation = location + File.separator + format + File.separator + table.name + val tableLocation = + location + File.separator + format + File.separator + "sf" + scaleFactor + File.separator + table.name table.genData(tableLocation, format, overwrite) } } def createExternalTables(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = { + sqlContext.sql(s"CREATE DATABASE IF NOT EXISTS $databaseName") tables.foreach { table => - val tableLocation = location + File.separator + format + File.separator + table.name + val tableLocation = + location + File.separator + format + File.separator + "sf" + scaleFactor + File.separator + table.name table.createExternalTables(tableLocation, format, databaseName, overwrite) } }