Tweaks and improvements (#106)

Data generation:
* Add an option to convert Dates to Strings, and expose it in the Tables constructor.
* Add partition discovery to createExternalTables.
* Add an analyzeTables function that gathers table (and optionally column) statistics.

Benchmark execution:
* Perform collect() on the DataFrame (instead of its RDD), so that the query execution is recorded by the Spark SQL UI.
This commit is contained in:
Juliusz Sompolski 2017-06-13 11:42:14 +02:00 committed by GitHub
parent 75f3876e59
commit bff6b34f62
2 changed files with 48 additions and 19 deletions

View File

@ -117,13 +117,13 @@ class Query(
// The executionTime for the entire query includes the time of type conversion from catalyst
// to scala.
// The executionTime for the entire query includes the time of type conversion
// from catalyst to scala.
// Note: queryExecution.{logical, analyzed, optimizedPlan, executedPlan} has been already
// lazily evaluated above, so below we will count only execution time.
var result: Option[Long] = None
val executionTime = measureTimeMs {
executionMode match {
case ExecutionMode.CollectResults => dataFrame.rdd.collect()
case ExecutionMode.ForeachResults => dataFrame.rdd.foreach { row => Unit }
case ExecutionMode.CollectResults => dataFrame.collect()
case ExecutionMode.ForeachResults => dataFrame.foreach { row => Unit }
case ExecutionMode.WriteParquet(location) =>
dataFrame.write.parquet(s"$location/$name.parquet")
case ExecutionMode.HashResults =>

View File

@ -24,7 +24,10 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext, SaveMode}
class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable {
class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int,
useDoubleForDecimal: Boolean = false, useStringForDate: Boolean = false)
extends Serializable {
import sqlContext.implicits._
private val log = LoggerFactory.getLogger(getClass)
@ -104,10 +107,11 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
}
}
def useDoubleForDecimal(): Table = {
def convertTypes(): Table = {
val newFields = fields.map { field =>
val newDataType = field.dataType match {
case decimal: DecimalType => DoubleType
case decimal: DecimalType if useDoubleForDecimal => DoubleType
case date: DateType if useStringForDate => StringType
case other => other
}
field.copy(dataType = newDataType)
@ -172,7 +176,9 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
sqlContext.dropTempTable(tempTableName)
}
def createExternalTable(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = {
def createExternalTable(location: String, format: String, databaseName: String,
overwrite: Boolean, discoverPartitions: Boolean = true): Unit = {
val qualifiedTableName = databaseName + "." + name
val tableExists = sqlContext.tableNames(databaseName).contains(name)
if (overwrite) {
@ -183,6 +189,11 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
log.info(s"Creating external table $name in database $databaseName using data stored in $location.")
sqlContext.createExternalTable(qualifiedTableName, location, format)
}
if (partitionColumns.nonEmpty && discoverPartitions) {
println(s"Discovering partitions for table $name.")
log.info(s"Discovering partitions for table $name.")
sqlContext.sql(s"ALTER TABLE $databaseName.$name RECOVER PARTITIONS")
}
}
def createTemporaryTable(location: String, format: String): Unit = {
@ -190,6 +201,18 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
log.info(s"Creating temporary table $name using data stored in $location.")
sqlContext.read.format(format).load(location).registerTempTable(name)
}
def analyzeTable(databaseName: String, analyzeColumns: Boolean = false): Unit = {
println(s"Analyzing table $name.")
log.info(s"Analyzing table $name.")
sqlContext.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS")
if (analyzeColumns) {
val allColumns = fields.map(_.name).mkString(", ")
println(s"Analyzing table $name columns $allColumns.")
log.info(s"Analyzing table $name columns $allColumns.")
sqlContext.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS FOR COLUMNS $allColumns")
}
}
}
def genData(
@ -197,7 +220,6 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
format: String,
overwrite: Boolean,
partitionTables: Boolean,
useDoubleForDecimal: Boolean,
clusterByPartitionColumns: Boolean,
filterOutNullPartitionValues: Boolean,
tableFilter: String = "",
@ -215,20 +237,16 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
}
}
val withSpecifiedDataType = if (useDoubleForDecimal) {
tablesToBeGenerated.map(_.useDoubleForDecimal())
} else {
tablesToBeGenerated
}
withSpecifiedDataType.foreach { table =>
tablesToBeGenerated.foreach { table =>
val tableLocation = s"$location/${table.name}"
table.genData(tableLocation, format, overwrite, clusterByPartitionColumns,
filterOutNullPartitionValues, numPartitions)
}
}
def createExternalTables(location: String, format: String, databaseName: String, overwrite: Boolean, tableFilter: String = ""): Unit = {
def createExternalTables(location: String, format: String, databaseName: String,
overwrite: Boolean, discoverPartitions: Boolean, tableFilter: String = ""): Unit = {
val filtered = if (tableFilter.isEmpty) {
tables
} else {
@ -238,7 +256,7 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
sqlContext.sql(s"CREATE DATABASE IF NOT EXISTS $databaseName")
filtered.foreach { table =>
val tableLocation = s"$location/${table.name}"
table.createExternalTable(tableLocation, format, databaseName, overwrite)
table.createExternalTable(tableLocation, format, databaseName, overwrite, discoverPartitions)
}
sqlContext.sql(s"USE $databaseName")
println(s"The current database has been set to $databaseName.")
@ -257,6 +275,17 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
}
}
def analyzeTables(databaseName: String, analyzeColumns: Boolean = false, tableFilter: String = ""): Unit = {
val filtered = if (tableFilter.isEmpty) {
tables
} else {
tables.filter(_.name == tableFilter)
}
filtered.foreach { table =>
table.analyzeTable(databaseName, analyzeColumns)
}
}
val tables = Seq(
Table("catalog_sales",
partitionColumns = "cs_sold_date_sk" :: Nil,
@ -731,5 +760,5 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend
'web_country .string,
'web_gmt_offset .decimal(5,2),
'web_tax_percentage .decimal(5,2))
)
).map(_.convertTypes())
}