Allow users to use double instead of decimal for generated tables.

This commit is contained in:
Yin Huai 2015-08-24 12:17:25 -07:00
parent 88aadb45a4
commit 58188c6711
2 changed files with 29 additions and 5 deletions

View File

@ -14,9 +14,9 @@ Before running any query, a dataset needs to be setup by creating a `Benchmark`
import com.databricks.spark.sql.perf.tpcds.Tables
// Tables in TPC-DS benchmark used by experiments.
// dsdgenDir is the location of the dsdgen tool installed on your machines.
val tables = Tables(sqlContext, dsdgenDir, scaleFactor)
val tables = new Tables(sqlContext, dsdgenDir, scaleFactor)
// Generate data.
tables.genData(location, format, overwrite, partitionTables)
tables.genData(location, format, overwrite, partitionTables, useDoubleForDecimal)
// Create metastore tables in a specified database for your data.
// Once tables are created, the current database will be switched to the specified database.
tables.createExternalTables(location, format, databaseName, overwrite)

View File

@ -24,7 +24,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.{SaveMode, SQLContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.types._
class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable with Logging {
import sqlContext.implicits._
@ -91,6 +91,18 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
convertedData
}
/**
 * Returns a copy of this table definition in which every top-level field
 * declared as a `DecimalType` is re-declared as `DoubleType`.
 *
 * Fields of any other type keep their declared type. The table name and the
 * partition columns are passed through unchanged.
 *
 * @return a new `Table` built from the same name and partition columns and
 *         the converted fields
 */
def useDoubleForDecimal(): Table = {
  val doubleFields = for (column <- fields) yield {
    // Only the declared type is swapped; the remaining attributes of the
    // field are carried over by `copy`.
    val widenedType = column.dataType match {
      case _: DecimalType => DoubleType
      case unchanged => unchanged
    }
    column.copy(dataType = widenedType)
  }
  Table(name, partitionColumns, doubleFields: _*)
}
def genData(location: String, format: String, overwrite: Boolean): Unit = {
val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore
@ -129,13 +141,25 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
}
}
def genData(location: String, format: String, overwrite: Boolean, partitionTables: Boolean): Unit = {
def genData(
location: String,
format: String,
overwrite: Boolean,
partitionTables: Boolean,
useDoubleForDecimal: Boolean): Unit = {
val tablesToBeGenerated = if (partitionTables) {
tables
} else {
tables.map(_.nonPartitioned)
}
tablesToBeGenerated.foreach { table =>
val withSpecifiedDataType = if (useDoubleForDecimal) {
tablesToBeGenerated.map(_.useDoubleForDecimal())
} else {
tablesToBeGenerated
}
withSpecifiedDataType.foreach { table =>
val tableLocation = s"$location/${table.name}"
table.genData(tableLocation, format, overwrite)
}