From bff6b34f62bf59946b26c478bb9871550566ccda Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Tue, 13 Jun 2017 11:42:14 +0200 Subject: [PATCH] Tweaks and improvements (#106) Data generation: * Add an option to change Dates to Strings, and specify it in Tables object creator. * Add discovering partitions to createExternalTables * Add analyzeTables function that gathers statistics. Benchmark execution: * Perform collect() on Dataframe, so that it is recorded by SQL SparkUI. --- .../com/databricks/spark/sql/perf/Query.scala | 8 +-- .../spark/sql/perf/tpcds/Tables.scala | 59 ++++++++++++++----- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/Query.scala b/src/main/scala/com/databricks/spark/sql/perf/Query.scala index 3868efd..7a8f75e 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Query.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Query.scala @@ -117,13 +117,13 @@ class Query( // The executionTime for the entire query includes the time of type conversion from catalyst // to scala. - // The executionTime for the entire query includes the time of type conversion - // from catalyst to scala. + // Note: queryExecution.{logical, analyzed, optimizedPlan, executedPlan} has been already + // lazily evaluated above, so below we will count only execution time. var result: Option[Long] = None val executionTime = measureTimeMs { executionMode match { - case ExecutionMode.CollectResults => dataFrame.rdd.collect() - case ExecutionMode.ForeachResults => dataFrame.rdd.foreach { row => Unit } + case ExecutionMode.CollectResults => dataFrame.collect() + case ExecutionMode.ForeachResults => dataFrame.foreach { row => Unit } case ExecutionMode.WriteParquet(location) => dataFrame.write.parquet(s"$location/$name.parquet") case ExecutionMode.HashResults => diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala index d6b8065..4d44f99 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala @@ -24,7 +24,10 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{Row, SQLContext, SaveMode} -class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extends Serializable { +class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int, + useDoubleForDecimal: Boolean = false, useStringForDate: Boolean = false) + extends Serializable { + import sqlContext.implicits._ private val log = LoggerFactory.getLogger(getClass) @@ -104,10 +107,11 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend } } - def useDoubleForDecimal(): Table = { + def convertTypes(): Table = { val newFields = fields.map { field => val newDataType = field.dataType match { - case decimal: DecimalType => DoubleType + case decimal: DecimalType if useDoubleForDecimal => DoubleType + case date: DateType if useStringForDate => StringType case other => other } field.copy(dataType = newDataType) @@ -172,7 +176,9 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend sqlContext.dropTempTable(tempTableName) } - def createExternalTable(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = { + def createExternalTable(location: String, format: String, databaseName: String, + overwrite: Boolean, discoverPartitions: Boolean = true): Unit = { + val qualifiedTableName = databaseName + "." + name val tableExists = sqlContext.tableNames(databaseName).contains(name) if (overwrite) { @@ -183,6 +189,11 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend log.info(s"Creating external table $name in database $databaseName using data stored in $location.") sqlContext.createExternalTable(qualifiedTableName, location, format) } + if (partitionColumns.nonEmpty && discoverPartitions) { + println(s"Discovering partitions for table $name.") + log.info(s"Discovering partitions for table $name.") + sqlContext.sql(s"ALTER TABLE $databaseName.$name RECOVER PARTITIONS") + } } def createTemporaryTable(location: String, format: String): Unit = { @@ -190,6 +201,18 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend log.info(s"Creating temporary table $name using data stored in $location.") sqlContext.read.format(format).load(location).registerTempTable(name) } + + def analyzeTable(databaseName: String, analyzeColumns: Boolean = false): Unit = { + println(s"Analyzing table $name.") + log.info(s"Analyzing table $name.") + sqlContext.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS") + if (analyzeColumns) { + val allColumns = fields.map(_.name).mkString(", ") + println(s"Analyzing table $name columns $allColumns.") + log.info(s"Analyzing table $name columns $allColumns.") + sqlContext.sql(s"ANALYZE TABLE $databaseName.$name COMPUTE STATISTICS FOR COLUMNS $allColumns") + } + } } def genData( @@ -197,7 +220,6 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend format: String, overwrite: Boolean, partitionTables: Boolean, - useDoubleForDecimal: Boolean, clusterByPartitionColumns: Boolean, filterOutNullPartitionValues: Boolean, tableFilter: String = "", @@ -215,20 +237,16 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend } } - val withSpecifiedDataType = if (useDoubleForDecimal) { - tablesToBeGenerated.map(_.useDoubleForDecimal()) - } else { - tablesToBeGenerated - } - - withSpecifiedDataType.foreach { table => + tablesToBeGenerated.foreach { table => val tableLocation = s"$location/${table.name}" table.genData(tableLocation, format, overwrite, clusterByPartitionColumns, filterOutNullPartitionValues, numPartitions) } } - def createExternalTables(location: String, format: String, databaseName: String, overwrite: Boolean, tableFilter: String = ""): Unit = { + def createExternalTables(location: String, format: String, databaseName: String, + overwrite: Boolean, discoverPartitions: Boolean, tableFilter: String = ""): Unit = { + val filtered = if (tableFilter.isEmpty) { tables } else { @@ -238,7 +256,7 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend sqlContext.sql(s"CREATE DATABASE IF NOT EXISTS $databaseName") filtered.foreach { table => val tableLocation = s"$location/${table.name}" - table.createExternalTable(tableLocation, format, databaseName, overwrite) + table.createExternalTable(tableLocation, format, databaseName, overwrite, discoverPartitions) } sqlContext.sql(s"USE $databaseName") println(s"The current database has been set to $databaseName.") @@ -257,6 +275,17 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend } } + def analyzeTables(databaseName: String, analyzeColumns: Boolean = false, tableFilter: String = ""): Unit = { + val filtered = if (tableFilter.isEmpty) { + tables + } else { + tables.filter(_.name == tableFilter) + } + filtered.foreach { table => + table.analyzeTable(databaseName, analyzeColumns) + } + } + val tables = Seq( Table("catalog_sales", partitionColumns = "cs_sold_date_sk" :: Nil, @@ -731,5 +760,5 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend 'web_country .string, 'web_gmt_offset .decimal(5,2), 'web_tax_percentage .decimal(5,2)) - ) + ).map(_.convertTypes()) }