Allow users to use double instead of decimal for generated tables.
This commit is contained in:
parent
88aadb45a4
commit
58188c6711
@ -14,9 +14,9 @@ Before running any query, a dataset needs to be setup by creating a `Benchmark`
|
||||
import com.databricks.spark.sql.perf.tpcds.Tables
|
||||
// Tables in TPC-DS benchmark used by experiments.
|
||||
// dsdgenDir is the location of dsdgen tool installed in your machines.
|
||||
val tables = Tables(sqlContext, dsdgenDir, scaleFactor)
|
||||
val tables = new Tables(sqlContext, dsdgenDir, scaleFactor)
|
||||
// Generate data.
|
||||
tables.genData(location, format, overwrite, partitionTables)
|
||||
tables.genData(location, format, overwrite, partitionTables, useDoubleForDecimal)
|
||||
// Create metastore tables in a specified database for your data.
|
||||
// Once tables are created, the current database will be switched to the specified database.
|
||||
tables.createExternalTables(location, format, databaseName, overwrite)
|
||||
|
||||
@ -24,7 +24,7 @@ import org.apache.spark.Logging
|
||||
import org.apache.spark.sql.{SaveMode, SQLContext}
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.types.{StringType, StructField, StructType}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable with Logging {
|
||||
import sqlContext.implicits._
|
||||
@ -91,6 +91,18 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
|
||||
convertedData
|
||||
}
|
||||
|
||||
def useDoubleForDecimal(): Table = {
|
||||
val newFields = fields.map { field =>
|
||||
val newDataType = field.dataType match {
|
||||
case decimal: DecimalType => DoubleType
|
||||
case other => other
|
||||
}
|
||||
field.copy(dataType = newDataType)
|
||||
}
|
||||
|
||||
Table(name, partitionColumns, newFields:_*)
|
||||
}
|
||||
|
||||
def genData(location: String, format: String, overwrite: Boolean): Unit = {
|
||||
val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore
|
||||
|
||||
@ -129,13 +141,25 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
|
||||
}
|
||||
}
|
||||
|
||||
def genData(location: String, format: String, overwrite: Boolean, partitionTables: Boolean): Unit = {
|
||||
def genData(
|
||||
location: String,
|
||||
format: String,
|
||||
overwrite: Boolean,
|
||||
partitionTables: Boolean,
|
||||
useDoubleForDecimal: Boolean): Unit = {
|
||||
val tablesToBeGenerated = if (partitionTables) {
|
||||
tables
|
||||
} else {
|
||||
tables.map(_.nonPartitioned)
|
||||
}
|
||||
tablesToBeGenerated.foreach { table =>
|
||||
|
||||
val withSpecifiedDataType = if (useDoubleForDecimal) {
|
||||
tablesToBeGenerated.map(_.useDoubleForDecimal())
|
||||
} else {
|
||||
tablesToBeGenerated
|
||||
}
|
||||
|
||||
withSpecifiedDataType.foreach { table =>
|
||||
val tableLocation = s"$location/${table.name}"
|
||||
table.genData(tableLocation, format, overwrite)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user