From 8d9e8ce9a35b9bf9d14945b6451970c320180882 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Wed, 18 Nov 2015 11:12:01 -0800 Subject: [PATCH] Add another fact table and updates to load a single table at a time. Author: Nong Li Closes #31 from nongli/more_tables. --- .../spark/sql/perf/tpcds/SimpleQueries.scala | 16 ++++++ .../spark/sql/perf/tpcds/Tables.scala | 56 ++++++++++++++++--- 2 files changed, 65 insertions(+), 7 deletions(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/SimpleQueries.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/SimpleQueries.scala index 2f3e768..8c68c25 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/SimpleQueries.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/SimpleQueries.scala @@ -22,6 +22,22 @@ trait SimpleQueries extends Benchmark { import ExecutionMode._ + val targetedPerfQueries = Seq( + // Query to measure scan performance. + ("stores-sales-scan", + """ + |select * from store_sales where ss_item_sk = 1 + """.stripMargin), + ("fact-fact-join", + """ + | select count(*) from store_sales + | join store_returns + | on store_sales.ss_item_sk = store_returns.sr_item_sk + """.stripMargin) + ).map { case (name, sqlText) => + Query(name = name, sqlText = sqlText, description = "", executionMode = ForeachResults) + } + val q7Derived = Seq( ("q7-simpleScan", """ diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala index cdbba05..2e52552 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/Tables.scala @@ -198,13 +198,21 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend partitionTables: Boolean, useDoubleForDecimal: Boolean, clusterByPartitionColumns: Boolean, - filterOutNullPartitionValues: Boolean): Unit = { - val tablesToBeGenerated = if (partitionTables) { + filterOutNullPartitionValues: Boolean, + tableFilter: String = ""): Unit = { + var tablesToBeGenerated = if (partitionTables) { tables } else { tables.map(_.nonPartitioned) } + if (!tableFilter.isEmpty) { + tablesToBeGenerated = tablesToBeGenerated.filter(_.name == tableFilter) + if (tablesToBeGenerated.isEmpty) { + throw new RuntimeException("Bad table name filter: " + tableFilter) + } + } + val withSpecifiedDataType = if (useDoubleForDecimal) { tablesToBeGenerated.map(_.useDoubleForDecimal()) } else { @@ -213,13 +221,20 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend withSpecifiedDataType.foreach { table => val tableLocation = s"$location/${table.name}" - table.genData(tableLocation, format, overwrite, clusterByPartitionColumns, filterOutNullPartitionValues) + table.genData(tableLocation, format, overwrite, clusterByPartitionColumns, + filterOutNullPartitionValues) } } - def createExternalTables(location: String, format: String, databaseName: String, overwrite: Boolean): Unit = { + def createExternalTables(location: String, format: String, databaseName: String, overwrite: Boolean, tableFilter: String = ""): Unit = { + val filtered = if (tableFilter.isEmpty) { + tables + } else { + tables.filter(_.name == tableFilter) + } + sqlContext.sql(s"CREATE DATABASE IF NOT EXISTS $databaseName") - tables.foreach { table => + filtered.foreach { table => val tableLocation = s"$location/${table.name}" table.createExternalTable(tableLocation, format, databaseName, overwrite) } @@ -228,8 +243,13 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend logInfo(s"The current database has been set to $databaseName.") } - def createTemporaryTables(location: String, format: String): Unit = { - tables.foreach { table => + def createTemporaryTables(location: String, format: String, tableFilter: String = ""): Unit = { + val filtered = if (tableFilter.isEmpty) { + tables + } else { + tables.filter(_.name == tableFilter) + } + filtered.foreach { table => val tableLocation = s"$location/${table.name}" table.createTemporaryTable(tableLocation, format) } @@ -268,6 +288,28 @@ class Tables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int) extend 'ss_net_paid .decimal(7,2), 'ss_net_paid_inc_tax .decimal(7,2), 'ss_net_profit .decimal(7,2)), + Table("store_returns", + partitionColumns = "sr_returned_date_sk" ::Nil, + 'sr_returned_date_sk .long, + 'sr_return_time_sk .long, + 'sr_item_sk .long, + 'sr_customer_sk .long, + 'sr_cdemo_sk .long, + 'sr_hdemo_sk .long, + 'sr_addr_sk .long, + 'sr_store_sk .long, + 'sr_reason_sk .long, + 'sr_ticket_number .long, + 'sr_return_quantity .long, + 'sr_return_amt .decimal(7,2), + 'sr_return_tax .decimal(7,2), + 'sr_return_amt_inc_tax.decimal(7,2), + 'sr_fee .decimal(7,2), + 'sr_return_ship_cost .decimal(7,2), + 'sr_refunded_cash .decimal(7,2), + 'sr_reversed_charge .decimal(7,2), + 'sr_store_credit .decimal(7,2), + 'sr_net_loss .decimal(7,2)), Table("customer", partitionColumns = Nil, 'c_customer_sk .int,