spark-sql-perf/src/main/scala/com/databricks/spark/sql/perf/Tables.scala
Bogdan Raducanu 4e7a2363b9 Support for TPC-H benchmark
Refactored TPC-DS code to be able to reuse it for TPC-H.
Added TPC-H query texts adapted for Spark.
2017-08-09 12:26:32 +02:00

279 lines
9.3 KiB
Scala

/*
* Copyright 2015 Databricks Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.databricks.spark.sql.perf
import scala.sys.process._
import org.slf4j.LoggerFactory
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext, SaveMode}
/**
 * Source of the raw, textual data for the tables of a benchmark.
 *
 * `Tables.df` treats every element of the produced RDD as one generated row, which it either
 * keeps verbatim or splits on '|' into columns.
 */
trait DataGenerator extends Serializable {

  /**
   * Produces the generated rows of one table as text.
   *
   * @param sparkContext context used to build the resulting RDD
   * @param name         name of the table whose data should be generated
   * @param partitions   number of partitions requested by the caller
   * @param scaleFactor  benchmark scale factor, passed through as a string
   * @return one string per generated row
   */
  def generate(sparkContext: SparkContext, name: String, partitions: Int,
      scaleFactor: String): RDD[String]
}
/**
 * Base class describing the set of tables of a benchmark (for example TPC-DS or TPC-H).
 *
 * Subclasses provide the [[DataGenerator]] producing the raw rows and the list of [[Table]]
 * definitions. This class offers helpers to generate the data to a location, expose it as
 * external or temporary tables, and compute table statistics.
 *
 * @param sqlContext          context used to build DataFrames and to run SQL statements
 * @param scaleFactor         scale factor passed verbatim to the data generator
 * @param useDoubleForDecimal if true, `Table.convertTypes` maps DecimalType columns to DoubleType
 * @param useStringForDate    if true, `Table.convertTypes` maps DateType columns to StringType
 */
abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
    useDoubleForDecimal: Boolean = false, useStringForDate: Boolean = false)
  extends Serializable {

  /** Generator that produces the raw text rows for each table. */
  def dataGenerator: DataGenerator

  /** All tables of the benchmark, in their (possibly partitioned) form. */
  def tables: Seq[Table]

  private val log = LoggerFactory.getLogger(getClass)

  def sparkContext: SparkContext = sqlContext.sparkContext

  /** Writes `message` to stdout and to the class logger at INFO level. */
  private def printAndLog(message: String): Unit = {
    println(message)
    log.info(message)
  }

  /** Returns the tables named `tableFilter`, or all `candidates` when the filter is empty. */
  private def filterTables(candidates: Seq[Table], tableFilter: String): Seq[Table] =
    if (tableFilter.isEmpty) candidates else candidates.filter(_.name == tableFilter)

  /**
   * Definition of a single benchmark table.
   *
   * @param name             table name; also passed to the generator and used as sub-directory
   * @param partitionColumns columns the written data is partitioned by (may be empty)
   * @param fields           columns of the table, in the positional order of the generated rows
   */
  case class Table(name: String, partitionColumns: Seq[String], fields: StructField*) {
    val schema = StructType(fields)

    /** Returns a copy of this table with no partition columns. */
    def nonPartitioned: Table = Table(name, Nil, fields: _*)

    /**
     * Generates this table's data as a DataFrame.
     *
     * If convertToSchema is true, the data from generator will be parsed into columns and
     * converted to `schema`: each line is split on '|', empty values become null and every
     * column is cast to its declared type. Otherwise, it just outputs the raw data (as a single
     * STRING column named "value").
     *
     * @param convertToSchema whether to parse generated lines into typed columns
     * @param numPartition    number of partitions requested from the data generator
     */
    def df(convertToSchema: Boolean, numPartition: Int) = {
      val generatedData = dataGenerator.generate(sparkContext, name, numPartition, scaleFactor)
      // The branch is decided once on the driver instead of once per row. The closures only
      // reference their own arguments, so they do not drag this object graph into the tasks.
      val rows: RDD[Row] = if (convertToSchema) {
        generatedData.map { line =>
          // Limit -1 keeps trailing empty fields; dropRight(1) discards whatever follows the
          // last '|' (presumably the generator terminates each row with '|' — TODO confirm).
          val values = line.split("\\|", -1).dropRight(1).map { v =>
            // If the string value is an empty string, we turn it to a null
            if (v.isEmpty) null else v
          }
          Row.fromSeq(values)
        }
      } else {
        generatedData.map(line => Row.fromSeq(Seq(line)))
      }

      if (convertToSchema) {
        // Load everything as strings first, then cast each column to its declared type.
        val stringSchema = StructType(schema.fields.map(f => StructField(f.name, StringType)))
        val stringData = sqlContext.createDataFrame(rows, stringSchema)
        val typedColumns = schema.fields.map(f => col(f.name).cast(f.dataType).as(f.name))
        stringData.select(typedColumns: _*)
      } else {
        sqlContext.createDataFrame(rows, StructType(Seq(StructField("value", StringType))))
      }
    }

    /**
     * Returns a copy of this table whose column types honour `useDoubleForDecimal` and
     * `useStringForDate`; all other types are left unchanged.
     */
    def convertTypes(): Table = {
      val newFields = fields.map { field =>
        val newDataType = field.dataType match {
          case _: DecimalType if useDoubleForDecimal => DoubleType
          case _: DateType if useStringForDate => StringType
          case other => other
        }
        field.copy(dataType = newDataType)
      }
      Table(name, partitionColumns, newFields: _*)
    }

    /**
     * Generates this table's data and writes it to `location`.
     *
     * @param location                     output path
     * @param format                       data source format; "text" writes the raw lines
     * @param overwrite                    Overwrite existing data if true, otherwise Ignore
     * @param clusterByPartitionColumns    if partitioned, DISTRIBUTE BY the partition columns first
     * @param filterOutNullPartitionValues when clustering, drop rows with NULL partition values
     * @param numPartitions                number of partitions requested from the generator
     */
    def genData(
        location: String,
        format: String,
        overwrite: Boolean,
        clusterByPartitionColumns: Boolean,
        filterOutNullPartitionValues: Boolean,
        numPartitions: Int): Unit = {
      val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Ignore
      val data = df(format != "text", numPartitions)
      val tempTableName = s"${name}_text"
      data.registerTempTable(tempTableName)
      // Drop the temporary table even when clustering or writing fails, so it does not leak.
      try {
        val source = if (partitionColumns.nonEmpty && clusterByPartitionColumns) {
          val columnString = data.schema.fields.map(_.name).mkString(",")
          val partitionColumnString = partitionColumns.mkString(",")
          val predicates = if (filterOutNullPartitionValues) {
            partitionColumns.map(c => s"$c IS NOT NULL").mkString("WHERE ", " AND ", "")
          } else {
            ""
          }
          val query =
            s"""
               |SELECT
               |  $columnString
               |FROM
               |  $tempTableName
               |$predicates
               |DISTRIBUTE BY
               |  $partitionColumnString
            """.stripMargin
          val grouped = sqlContext.sql(query)
          printAndLog(s"Pre-clustering with partitioning columns with query $query.")
          grouped
        } else {
          data
        }

        // Use the values returned by the writer builder instead of relying on it mutating itself.
        val baseWriter = source.write.format(format).mode(mode)
        val writer =
          if (partitionColumns.nonEmpty) baseWriter.partitionBy(partitionColumns: _*)
          else baseWriter
        printAndLog(s"Generating table $name in database to $location with save mode $mode.")
        writer.save(location)
      } finally {
        sqlContext.dropTempTable(tempTableName)
      }
    }

    /**
     * Registers the data at `location` as external table `databaseName.name`.
     *
     * The table is (re)created when it does not exist yet or when `overwrite` is true. When the
     * table has partition columns and `discoverPartitions` is set, RECOVER PARTITIONS is run.
     */
    def createExternalTable(location: String, format: String, databaseName: String,
        overwrite: Boolean, discoverPartitions: Boolean = true): Unit = {
      val qualifiedTableName = databaseName + "." + name
      // Checked before the optional DROP, so it reflects the state on entry.
      val tableExists = sqlContext.tableNames(databaseName).contains(name)
      if (overwrite) {
        sqlContext.sql(s"DROP TABLE IF EXISTS $qualifiedTableName")
      }
      if (!tableExists || overwrite) {
        printAndLog(s"Creating external table $name in database $databaseName using data stored in $location.")
        sqlContext.createExternalTable(qualifiedTableName, location, format)
      }
      if (partitionColumns.nonEmpty && discoverPartitions) {
        printAndLog(s"Discovering partitions for table $name.")
        sqlContext.sql(s"ALTER TABLE $qualifiedTableName RECOVER PARTITIONS")
      }
    }

    /** Loads the data at `location` and registers it as a temporary table named `name`. */
    def createTemporaryTable(location: String, format: String): Unit = {
      printAndLog(s"Creating temporary table $name using data stored in $location.")
      sqlContext.read.format(format).load(location).registerTempTable(name)
    }

    /**
     * Runs ANALYZE TABLE ... COMPUTE STATISTICS and, if `analyzeColumns` is set, also computes
     * statistics for every column of the table.
     */
    def analyzeTable(databaseName: String, analyzeColumns: Boolean = false): Unit = {
      printAndLog(s"Analyzing table $name.")
      sqlContext.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS")
      if (analyzeColumns) {
        val allColumns = fields.map(_.name).mkString(", ")
        printAndLog(s"Analyzing table $name columns $allColumns.")
        sqlContext.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS FOR COLUMNS $allColumns")
      }
    }
  }

  /**
   * Generates the data of all tables (or only the one named `tableFilter`) under
   * `location/<table name>`.
   *
   * @throws RuntimeException if `tableFilter` is non-empty and matches no table
   */
  def genData(
      location: String,
      format: String,
      overwrite: Boolean,
      partitionTables: Boolean,
      clusterByPartitionColumns: Boolean,
      filterOutNullPartitionValues: Boolean,
      tableFilter: String = "",
      numPartitions: Int = 100): Unit = {
    val candidates = if (partitionTables) tables else tables.map(_.nonPartitioned)
    val tablesToBeGenerated = filterTables(candidates, tableFilter)
    if (tableFilter.nonEmpty && tablesToBeGenerated.isEmpty) {
      throw new RuntimeException("Bad table name filter: " + tableFilter)
    }
    tablesToBeGenerated.foreach { table =>
      val tableLocation = s"$location/${table.name}"
      table.genData(tableLocation, format, overwrite, clusterByPartitionColumns,
        filterOutNullPartitionValues, numPartitions)
    }
  }

  /**
   * Creates `databaseName` if needed, registers the selected tables as external tables reading
   * from `location/<table name>`, then switches the current database to `databaseName`.
   */
  def createExternalTables(location: String, format: String, databaseName: String,
      overwrite: Boolean, discoverPartitions: Boolean, tableFilter: String = ""): Unit = {
    val filtered = filterTables(tables, tableFilter)
    sqlContext.sql(s"CREATE DATABASE IF NOT EXISTS $databaseName")
    filtered.foreach { table =>
      val tableLocation = s"$location/${table.name}"
      table.createExternalTable(tableLocation, format, databaseName, overwrite, discoverPartitions)
    }
    sqlContext.sql(s"USE $databaseName")
    printAndLog(s"The current database has been set to $databaseName.")
  }

  /** Registers the selected tables as temporary tables reading from `location/<table name>`. */
  def createTemporaryTables(location: String, format: String, tableFilter: String = ""): Unit = {
    filterTables(tables, tableFilter).foreach { table =>
      val tableLocation = s"$location/${table.name}"
      table.createTemporaryTable(tableLocation, format)
    }
  }

  /** Computes statistics (optionally per column) for the selected tables of `databaseName`. */
  def analyzeTables(databaseName: String, analyzeColumns: Boolean = false, tableFilter: String = ""): Unit = {
    filterTables(tables, tableFilter).foreach { table =>
      table.analyzeTable(databaseName, analyzeColumns)
    }
  }
}