diff --git a/README.md b/README.md
index 94308f2..520049c 100644
--- a/README.md
+++ b/README.md
@@ -11,10 +11,19 @@ The rest of document will use TPC-DS benchmark as an example. We will add conten
 Before running any query, a dataset needs to be setup by creating a `Benchmark` object.
 ```
-import org.apache.spark.sql.parquet.Tables
+import com.databricks.spark.sql.perf.tpcds.Tables
 // Tables in TPC-DS benchmark used by experiments.
+// dsdgenDir is the location of the dsdgen tool installed on your machines.
-val tables = Tables(sqlContext)
+val tables = new Tables(sqlContext, dsdgenDir, scaleFactor)
+// Generate data.
+tables.genData(location, format, overwrite, partitionTables, useDoubleForDecimal, clusterByPartitionColumns, filterOutNullPartitionValues)
+// Create metastore tables in a specified database for your data.
+// Once tables are created, the current database will be switched to the specified database.
+tables.createExternalTables(location, format, databaseName, overwrite)
+// Or, if you want to create temporary tables:
+tables.createTemporaryTables(location, format)
 // Setup TPC-DS experiment
+import com.databricks.spark.sql.perf.tpcds.TPCDS
 val tpcds = new TPCDS (sqlContext = sqlContext)
 ```
@@ -31,11 +40,9 @@ For every experiment run (i.e. every call of `runExperiment`), Spark SQL Perf w
 While the experiment is running you can use `experiment.html` to list the status. Once the experiment is complete, the results will be saved to the table sqlPerformance in json.
 ```
-// Get experiments results.
-import com.databricks.spark.sql.perf.Results
-val results = Results(resultsLocation = , sqlContext = sqlContext)
-// Get the DataFrame representing all results stored in the dir specified by resultsLocation.
-val allResults = results.allResults
-// Use DataFrame API to get results of a single run.
-allResults.filter("timestamp = 1429132621024")
+// Get all experiment results.
+tpcds.createResultsTable()
+sqlContext.table("sqlPerformance")
+// Get the result of a particular run by specifying the timestamp of that run.
+sqlContext.table("sqlPerformance").filter("timestamp = 1429132621024")
 ```
\ No newline at end of file
diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala
index 011221e..ee470f8 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala
@@ -16,28 +16,17 @@
 
 package com.databricks.spark.sql.perf.tpcds
 
-import java.text.SimpleDateFormat
-import java.util.Date
-
-import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
+import java.io.File
 
 import scala.sys.process._
 
+import org.apache.spark.Logging
+import org.apache.spark.sql.{SaveMode, SQLContext}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
-
-import com.databricks.spark.sql.perf._
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext, RecordWriter, Job}
-import org.apache.spark.SerializableWritable
-import org.apache.spark.sql.{Column, ColumnName, SQLContext}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{StringType, StructField, StructType}
-import parquet.hadoop.ParquetOutputFormat
-import parquet.hadoop.util.ContextUtil
-
-class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable {
+class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable with Logging {
   import sqlContext.implicits._
 
   def sparkContext = sqlContext.sparkContext
@@ -47,6 +36,10 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
     val schema = StructType(fields)
     val partitions = if (partitionColumns.isEmpty) 1 else 100
 
+    def nonPartitioned: Table = {
+      Table(name, Nil, fields: _*)
+    }
+
     def df = {
       val generatedData = {
         sparkContext.parallelize(1 to partitions, partitions).flatMap { i =>
@@ -58,10 +51,11 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
             sys.error(s"Could not find dsdgen at $dsdgen or /$dsdgen. Run install")
           }
 
+          // Note: RNGSEED is the seed used by the data generator. For now, it is fixed to 100.
           val parallel = if (partitions > 1) s"-parallel $partitions -child $i" else ""
           val commands = Seq(
             "bash", "-c",
-            s"cd $localToolsDir && ./dsdgen -table $name -filter Y -scale $scaleFactor $parallel")
+            s"cd $localToolsDir && ./dsdgen -table $name -filter Y -scale $scaleFactor -RNGSEED 100 $parallel")
           println(commands)
           commands.lines
         }
@@ -70,11 +64,16 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
       generatedData.setName(s"$name, sf=$scaleFactor, strings")
 
       val rows = generatedData.mapPartitions { iter =>
-        val currentRow = new GenericMutableRow(schema.fields.size)
         iter.map { l =>
-          (0 until schema.fields.length).foreach(currentRow.setNullAt)
-          l.split("\\|", -1).zipWithIndex.dropRight(1).foreach { case (f, i) => currentRow(i) = f }
-          currentRow: Row
+          val values = l.split("\\|", -1).dropRight(1).map { v =>
+            if (v.equals("")) {
+              // If the value is an empty string, treat it as null.
+              null
+            } else {
+              v
+            }
+          }
+          Row.fromSeq(values)
         }
       }
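The rewritten row parsing above drops the mutable-row loop in favor of a null-preserving split. A standalone sketch of that behavior, with a made-up sample record:

```
// dsdgen terminates each record with '|', so split("\\|", -1) yields a trailing
// empty token that dropRight(1) removes; remaining empty fields become nulls
// before df casts the string columns to the table schema.
val line = "AAAAAAAABAAAAAAA|1997-03-13||2451123|"  // hypothetical record
val values = line.split("\\|", -1).dropRight(1).map(v => if (v == "") null else v)
// values: Array("AAAAAAAABAAAAAAA", "1997-03-13", null, "2451123")
```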
@@ -85,14 +84,143 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
 
       val convertedData = {
         val columns = schema.fields.map { f =>
-          val columnName = new ColumnName(f.name)
-          columnName.cast(f.dataType).as(f.name)
+          col(f.name).cast(f.dataType).as(f.name)
         }
         stringData.select(columns: _*)
       }
 
       convertedData
     }
+
+    def useDoubleForDecimal(): Table = {
+      val newFields = fields.map { field =>
+        val newDataType = field.dataType match {
+          case _: DecimalType => DoubleType
+          case other => other
+        }
+        field.copy(dataType = newDataType)
+      }
+
+      Table(name, partitionColumns, newFields: _*)
+    }
+
+    def genData(
+        location: String,
+        format: String,
+        overwrite: Boolean,
+        clusterByPartitionColumns: Boolean,
+        filterOutNullPartitionValues: Boolean): Unit = {
+      val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore
+
+      val data = df
+      val tempTableName = s"${name}_text"
+      data.registerTempTable(tempTableName)
+
+      val writer = if (partitionColumns.nonEmpty) {
+        if (clusterByPartitionColumns) {
+          val columnString = data.schema.fields.map { field =>
+            field.name
+          }.mkString(",")
+          val partitionColumnString = partitionColumns.mkString(",")
+          val predicates = if (filterOutNullPartitionValues) {
+            partitionColumns.map(col => s"$col IS NOT NULL").mkString("WHERE ", " AND ", "")
+          } else {
+            ""
+          }
+
+          val query =
+            s"""
+               |SELECT
+               |  $columnString
+               |FROM
+               |  $tempTableName
+               |$predicates
+               |DISTRIBUTE BY
+               |  $partitionColumnString
+            """.stripMargin
+          val grouped = sqlContext.sql(query)
+          println(s"Pre-clustering data by partition columns with query: $query.")
+          logInfo(s"Pre-clustering data by partition columns with query: $query.")
+          grouped.write
+        } else {
+          data.write
+        }
+      } else {
+        // If the table is not partitioned, coalesce the data to a single file.
+        data.coalesce(1).write
+      }
+      writer.format(format).mode(mode)
+      if (partitionColumns.nonEmpty) {
+        writer.partitionBy(partitionColumns: _*)
+      }
+      println(s"Generating table $name at $location with save mode $mode.")
+      logInfo(s"Generating table $name at $location with save mode $mode.")
+      writer.save(location)
+      sqlContext.dropTempTable(tempTableName)
+    }
+
+    def createExternalTable(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = {
+      val qualifiedTableName = databaseName + "." + name
+      val tableExists = sqlContext.tableNames(databaseName).contains(name)
+      if (overwrite) {
+        sqlContext.sql(s"DROP TABLE IF EXISTS $databaseName.$name")
+      }
+      if (!tableExists || overwrite) {
+        println(s"Creating external table $name in database $databaseName using data stored in $location.")
+        logInfo(s"Creating external table $name in database $databaseName using data stored in $location.")
+        sqlContext.createExternalTable(qualifiedTableName, location, format)
+      }
+    }
+
+    def createTemporaryTable(location: String, format: String): Unit = {
+      println(s"Creating temporary table $name using data stored in $location.")
+      logInfo(s"Creating temporary table $name using data stored in $location.")
+      sqlContext.read.format(format).load(location).registerTempTable(name)
+    }
+  }
+
+  def genData(
+      location: String,
+      format: String,
+      overwrite: Boolean,
+      partitionTables: Boolean,
+      useDoubleForDecimal: Boolean,
+      clusterByPartitionColumns: Boolean,
+      filterOutNullPartitionValues: Boolean): Unit = {
+    val tablesToBeGenerated = if (partitionTables) {
+      tables
+    } else {
+      tables.map(_.nonPartitioned)
+    }
+
+    val withSpecifiedDataType = if (useDoubleForDecimal) {
+      tablesToBeGenerated.map(_.useDoubleForDecimal())
+    } else {
+      tablesToBeGenerated
+    }
+
+    withSpecifiedDataType.foreach { table =>
+      val tableLocation = s"$location/${table.name}"
+      table.genData(tableLocation, format, overwrite, clusterByPartitionColumns, filterOutNullPartitionValues)
+    }
+  }
+
+  def createExternalTables(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = {
+    sqlContext.sql(s"CREATE DATABASE IF NOT EXISTS $databaseName")
+
+    tables.foreach { table =>
+      val tableLocation = s"$location/${table.name}"
+      table.createExternalTable(tableLocation, format, databaseName, overwrite)
+    }
+
+    sqlContext.sql(s"USE $databaseName")
+    println(s"The current database has been set to $databaseName.")
+    logInfo(s"The current database has been set to $databaseName.")
+  }
+
+  def createTemporaryTables(location: String, format: String): Unit = {
+    tables.foreach { table =>
+      val tableLocation = s"$location/${table.name}"
+      table.createTemporaryTable(tableLocation, format)
+    }
   }
 
   val tables = Seq(