diff --git a/README.md b/README.md
index 4f7d887..a2f41e2 100644
--- a/README.md
+++ b/README.md
@@ -14,9 +14,9 @@ Before running any query, a dataset needs to be setup by creating a `Benchmark`
 import com.databricks.spark.sql.perf.tpcds.Tables
 // Tables in TPC-DS benchmark used by experiments.
 // dsdgenDir is the location of dsdgen tool installed in your machines.
-val tables = Tables(sqlContext, dsdgenDir, scaleFactor)
+val tables = new Tables(sqlContext, dsdgenDir, scaleFactor)
 // Generate data.
-tables.genData(location, format, overwrite, partitionTables)
+tables.genData(location, format, overwrite, partitionTables, useDoubleForDecimal)
 // Create metastore tables in a specified database for your data.
 // Once tables are created, the current database will be switched to the specified database.
 tables.createExternalTables(location, format, databaseName, overwrite)
diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala
index 7bdb978..51b687b 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala
@@ -24,7 +24,7 @@ import org.apache.spark.Logging
 import org.apache.spark.sql.{SaveMode, SQLContext}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 
 class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable with Logging {
   import sqlContext.implicits._
@@ -91,6 +91,18 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
       convertedData
     }
 
+    def useDoubleForDecimal(): Table = {
+      val newFields = fields.map { field =>
+        val newDataType = field.dataType match {
+          case decimal: DecimalType => DoubleType
+          case other => other
+        }
+        field.copy(dataType = newDataType)
+      }
+
+      Table(name, partitionColumns, newFields:_*)
+    }
+
     def genData(location: String, format: String, overwrite: Boolean): Unit = {
       val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore
 
@@ -129,13 +141,25 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
     }
   }
 
-  def genData(location: String, format: String, overwrite: Boolean, partitionTables: Boolean): Unit = {
+  def genData(
+      location: String,
+      format: String,
+      overwrite: Boolean,
+      partitionTables: Boolean,
+      useDoubleForDecimal: Boolean): Unit = {
     val tablesToBeGenerated = if (partitionTables) {
       tables
     } else {
       tables.map(_.nonPartitioned)
     }
-    tablesToBeGenerated.foreach { table =>
+
+    val withSpecifiedDataType = if (useDoubleForDecimal) {
+      tablesToBeGenerated.map(_.useDoubleForDecimal())
+    } else {
+      tablesToBeGenerated
+    }
+
+    withSpecifiedDataType.foreach { table =>
       val tableLocation = s"$location/${table.name}"
       table.genData(tableLocation, format, overwrite)
     }