Merge pull request #9 from yhuai/genData

Add data generation support for TPC-DS
This commit is contained in:
Yin Huai 2015-09-04 15:48:10 -07:00
commit 88fa2f5af2
2 changed files with 169 additions and 34 deletions

View File

@ -11,10 +11,19 @@ The rest of document will use TPC-DS benchmark as an example. We will add conten
Before running any query, a dataset needs to be setup by creating a `Benchmark` object.
```
import org.apache.spark.sql.parquet.Tables
import com.databricks.spark.sql.perf.tpcds.Tables
// Tables in TPC-DS benchmark used by experiments.
val tables = Tables(sqlContext)
// dsdgenDir is the location of dsdgen tool installed in your machines.
val tables = new Tables(sqlContext, dsdgenDir, scaleFactor)
// Generate data.
tables.genData(location, format, overwrite, partitionTables, useDoubleForDecimal, clusterByPartitionColumns, filterOutNullPartitionValues)
// Create metastore tables in a specified database for your data.
// Once tables are created, the current database will be switched to the specified database.
tables.createExternalTables(location, format, databaseName, overwrite)
// Or, if you want to create temporary tables
tables.createTemporaryTables(location, format)
// Setup TPC-DS experiment
import com.databricks.spark.sql.perf.tpcds.TPCDS
val tpcds = new TPCDS (sqlContext = sqlContext)
```
@ -31,11 +40,9 @@ For every experiment run (i.e.\ every call of `runExperiment`), Spark SQL Perf w
While the experiment is running you can use `experiment.html` to list the status. Once the experiment is complete, the results will be saved to the table sqlPerformance in json.
```
// Get experiments results.
import com.databricks.spark.sql.perf.Results
val results = Results(resultsLocation = <the root location of performance results>, sqlContext = sqlContext)
// Get the DataFrame representing all results stored in the dir specified by resultsLocation.
val allResults = results.allResults
// Use DataFrame API to get results of a single run.
allResults.filter("timestamp = 1429132621024")
// Get all experiments results.
tpcds.createResultsTable()
sqlContext.table("sqlPerformance")
// Get the result of a particular run by specifying the timestamp of that run.
sqlContext.table("sqlPerformance").filter("timestamp = 1429132621024")
```

View File

@ -16,28 +16,17 @@
package com.databricks.spark.sql.perf.tpcds
import java.text.SimpleDateFormat
import java.util.Date
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import java.io.File
import scala.sys.process._
import org.apache.spark.Logging
import org.apache.spark.sql.{SaveMode, SQLContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
import com.databricks.spark.sql.perf._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext, RecordWriter, Job}
import org.apache.spark.SerializableWritable
import org.apache.spark.sql.{Column, ColumnName, SQLContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import parquet.hadoop.ParquetOutputFormat
import parquet.hadoop.util.ContextUtil
class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable {
class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable with Logging {
import sqlContext.implicits._
def sparkContext = sqlContext.sparkContext
@ -47,6 +36,10 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
val schema = StructType(fields)
val partitions = if (partitionColumns.isEmpty) 1 else 100
/** A copy of this table definition with all partitioning columns removed. */
def nonPartitioned: Table =
  Table(name, Nil, fields: _*)
def df = {
val generatedData = {
sparkContext.parallelize(1 to partitions, partitions).flatMap { i =>
@ -58,10 +51,11 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
sys.error(s"Could not find dsdgen at $dsdgen or /$dsdgen. Run install")
}
// Note: RNGSEED is the RNG seed used by the data generator. Right now, it is fixed to 100.
val parallel = if (partitions > 1) s"-parallel $partitions -child $i" else ""
val commands = Seq(
"bash", "-c",
s"cd $localToolsDir && ./dsdgen -table $name -filter Y -scale $scaleFactor $parallel")
s"cd $localToolsDir && ./dsdgen -table $name -filter Y -scale $scaleFactor -RNGSEED 100 $parallel")
println(commands)
commands.lines
}
@ -70,11 +64,16 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
generatedData.setName(s"$name, sf=$scaleFactor, strings")
val rows = generatedData.mapPartitions { iter =>
val currentRow = new GenericMutableRow(schema.fields.size)
iter.map { l =>
(0 until schema.fields.length).foreach(currentRow.setNullAt)
l.split("\\|", -1).zipWithIndex.dropRight(1).foreach { case (f, i) => currentRow(i) = f}
currentRow: Row
val values = l.split("\\|", -1).dropRight(1).map { v =>
if (v.equals("")) {
// If the string value is an empty string, we turn it to a null
null
} else {
v
}
}
Row.fromSeq(values)
}
}
@ -85,14 +84,143 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
val convertedData = {
val columns = schema.fields.map { f =>
val columnName = new ColumnName(f.name)
columnName.cast(f.dataType).as(f.name)
col(f.name).cast(f.dataType).as(f.name)
}
stringData.select(columns: _*)
}
convertedData
}
/**
 * A copy of this table whose DecimalType columns are widened to DoubleType;
 * all other columns and the partitioning columns are kept as-is.
 */
def useDoubleForDecimal(): Table = {
  val converted = fields.map { f =>
    f.dataType match {
      case _: DecimalType => f.copy(dataType = DoubleType)
      case _ => f
    }
  }
  Table(name, partitionColumns, converted: _*)
}
/**
 * Generates this table's data and writes it to `location` in the given `format`.
 *
 * @param location output path for this table's data.
 * @param format output data source format (e.g. "parquet").
 * @param overwrite when true, existing data at `location` is overwritten
 *                  (SaveMode.Overwrite); otherwise the write is skipped (SaveMode.Ignore).
 * @param clusterByPartitionColumns when true and the table is partitioned, rows are
 *        redistributed by the partitioning columns (DISTRIBUTE BY) before writing so that
 *        each partition's data is clustered together.
 * @param filterOutNullPartitionValues when true, rows with a NULL in any partitioning
 *        column are dropped before writing (only applies when clustering is enabled).
 */
def genData(
    location: String,
    format: String,
    overwrite: Boolean,
    clusterByPartitionColumns: Boolean,
    filterOutNullPartitionValues: Boolean): Unit = {
  val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore

  // Register the raw generated data as a temp table so it can be re-shuffled via SQL.
  val data = df
  val tempTableName = s"${name}_text"
  data.registerTempTable(tempTableName)

  val writer = if (partitionColumns.nonEmpty) {
    if (clusterByPartitionColumns) {
      val columnString = data.schema.fields.map(_.name).mkString(",")
      val partitionColumnString = partitionColumns.mkString(",")
      // Optionally drop rows whose partitioning columns are NULL so that no
      // "null" partition directories are produced.
      val predicates = if (filterOutNullPartitionValues) {
        partitionColumns.map(col => s"$col IS NOT NULL").mkString("WHERE ", " AND ", "")
      } else {
        ""
      }
      val query =
        s"""
           |SELECT
           |  $columnString
           |FROM
           |  $tempTableName
           |$predicates
           |DISTRIBUTE BY
           |  $partitionColumnString
         """.stripMargin
      val grouped = sqlContext.sql(query)
      // Fixed: the same message was emitted twice (println followed by logInfo); log once.
      logInfo(s"Pre-clustering with partitioning columns with query $query.")
      grouped.write
    } else {
      data.write
    }
  } else {
    // If the table is not partitioned, coalesce the data to a single file.
    data.coalesce(1).write
  }

  writer.format(format).mode(mode)
  if (partitionColumns.nonEmpty) {
    writer.partitionBy(partitionColumns: _*)
  }
  // Fixed: duplicate println removed; log the progress message once.
  logInfo(s"Generating table $name in database to $location with save mode $mode.")
  writer.save(location)
  sqlContext.dropTempTable(tempTableName)
}
/**
 * Creates an external metastore table named `databaseName`.`name` over the data stored
 * at `location`.
 *
 * @param location path of the already-generated data backing the table.
 * @param format data source format of the stored data.
 * @param databaseName database in which to create the table.
 * @param overwrite when true, an existing table with the same name is dropped first;
 *                  when false, creation is skipped if the table already exists.
 */
def createExternalTable(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = {
  val qualifiedTableName = databaseName + "." + name
  val tableExists = sqlContext.tableNames(databaseName).contains(name)
  if (overwrite) {
    // Reuse the qualified name computed above instead of rebuilding it inline.
    sqlContext.sql(s"DROP TABLE IF EXISTS $qualifiedTableName")
  }
  if (!tableExists || overwrite) {
    // Fixed: the same message was emitted twice (println followed by logInfo); log once.
    logInfo(s"Creating external table $name in database $databaseName using data stored in $location.")
    sqlContext.createExternalTable(qualifiedTableName, location, format)
  }
}
/**
 * Registers the data stored at `location` as a temporary table named after this table.
 *
 * @param location path of the already-generated data backing the table.
 * @param format data source format of the stored data.
 */
def createTemporaryTable(location: String, format: String): Unit = {
  // Fixed: the same message was emitted twice (println followed by logInfo); log once.
  logInfo(s"Creating temporary table $name using data stored in $location.")
  sqlContext.read.format(format).load(location).registerTempTable(name)
}
}
/**
 * Generates data for every TPC-DS table, writing each table under
 * `$location/<tableName>`.
 *
 * @param location root output path; each table gets its own subdirectory.
 * @param format output data source format (e.g. "parquet").
 * @param overwrite when true, existing data is overwritten; otherwise writes are skipped.
 * @param partitionTables when true, tables keep their partitioning columns;
 *                        when false, all tables are written unpartitioned.
 * @param useDoubleForDecimal when true, decimal columns are widened to doubles.
 * @param clusterByPartitionColumns when true, rows are pre-clustered by partitioning
 *                                  columns before writing.
 * @param filterOutNullPartitionValues when true, rows with NULL partitioning values
 *                                     are dropped.
 */
def genData(
    location: String,
    format: String,
    overwrite: Boolean,
    partitionTables: Boolean,
    useDoubleForDecimal: Boolean,
    clusterByPartitionColumns: Boolean,
    filterOutNullPartitionValues: Boolean): Unit = {
  // First optionally strip partitioning, then optionally widen decimals to doubles.
  val basePlans = if (partitionTables) tables else tables.map(_.nonPartitioned)
  val finalPlans =
    if (useDoubleForDecimal) basePlans.map(_.useDoubleForDecimal()) else basePlans
  for (table <- finalPlans) {
    table.genData(
      s"$location/${table.name}",
      format,
      overwrite,
      clusterByPartitionColumns,
      filterOutNullPartitionValues)
  }
}
/**
 * Creates an external metastore table for every TPC-DS table (data expected under
 * `$location/<tableName>`), creating `databaseName` if needed, then switches the
 * current database to `databaseName`.
 *
 * @param location root path of the already-generated data.
 * @param format data source format of the stored data.
 * @param databaseName database in which to create the tables.
 * @param overwrite when true, existing tables with the same names are dropped first.
 */
def createExternalTables(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = {
  sqlContext.sql(s"CREATE DATABASE IF NOT EXISTS $databaseName")
  tables.foreach { table =>
    val tableLocation = s"$location/${table.name}"
    table.createExternalTable(tableLocation, format, databaseName, overwrite)
  }
  sqlContext.sql(s"USE $databaseName")
  // Fixed: the same message was emitted twice (println followed by logInfo); log once.
  logInfo(s"The current database has been set to $databaseName.")
}
/**
 * Registers every TPC-DS table's generated data (expected under
 * `$location/<tableName>`) as a temporary table.
 *
 * @param location root path of the already-generated data.
 * @param format data source format of the stored data.
 */
def createTemporaryTables(location: String, format: String): Unit = {
  for (table <- tables) {
    table.createTemporaryTable(s"$location/${table.name}", format)
  }
}
val tables = Seq(