diff --git a/src/main/scala/com/databricks/spark/sql/perf/bigdata/Queries.scala b/src/main/scala/com/databricks/spark/sql/perf/bigdata/Queries.scala index 14b5720..12103b5 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/bigdata/Queries.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/bigdata/Queries.scala @@ -16,9 +16,9 @@ package com.databricks.spark.sql.perf.bigdata -import com.databricks.spark.sql.perf.Query +import com.databricks.spark.sql.perf.QuerySet -object Queries { +trait Queries extends QuerySet { val queries1to3 = Seq( Query( name = "q1A", diff --git a/src/main/scala/com/databricks/spark/sql/perf/query.scala b/src/main/scala/com/databricks/spark/sql/perf/query.scala index c7191a3..8f2fd1e 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/query.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/query.scala @@ -16,81 +16,120 @@ package com.databricks.spark.sql.perf -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -case class Query(name: String, sqlText: String, description: String, collectResults: Boolean) +object x { -case class QueryForTest( - query: Query, - includeBreakdown: Boolean, - @transient sqlContext: SQLContext) { - @transient val sparkContext = sqlContext.sparkContext - val name = query.name - def benchmarkMs[A](f: => A): Double = { - val startTime = System.nanoTime() - val ret = f - val endTime = System.nanoTime() - (endTime - startTime).toDouble / 1000000 - } +trait QuerySet { + val sqlContext: SQLContext + def sparkContext = sqlContext.sparkContext - def benchmark(description: String = "") = { - try { - sparkContext.setJobDescription(s"Query: ${query.name}, $description") - val dataFrame = sqlContext.sql(query.sqlText) - val queryExecution = dataFrame.queryExecution - // We are not counting the time of ScalaReflection.convertRowToScala. - val parsingTime = benchmarkMs { queryExecution.logical } - val analysisTime = benchmarkMs { queryExecution.analyzed } - val optimizationTime = benchmarkMs { queryExecution.optimizedPlan } - val planningTime = benchmarkMs { queryExecution.executedPlan } - val breakdownResults = if (includeBreakdown) { - val depth = queryExecution.executedPlan.treeString.split("\n").size - val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i))) - physicalOperators.map { - case (index, node) => - val executionTime = benchmarkMs { node.execute().map(_.copy()).foreach(row => Unit) } - BreakdownResult(node.nodeName, node.simpleString, index, executionTime) - } - } else { - Seq.empty[BreakdownResult] - } + object Query { + def apply( + name: String, + sqlText: String, + description: String, + collectResults: Boolean = true): Query = { + new Query(name, sqlContext.sql(sqlText), description, collectResults, Some(sqlText)) + } - // The executionTime for the entire query includes the time of type conversion from catalyst to scala. - val executionTime = if (query.collectResults) { - benchmarkMs { dataFrame.rdd.collect() } - } else { - benchmarkMs { dataFrame.rdd.foreach {row => Unit } } - } - - val joinTypes = dataFrame.queryExecution.executedPlan.collect { - case k if k.nodeName contains "Join" => k.nodeName - } - - val tablesInvolved = dataFrame.queryExecution.logical collect { - case UnresolvedRelation(tableIdentifier, _) => { - // We are ignoring the database name. - tableIdentifier.last - } - } - - BenchmarkResult( - name = query.name, - joinTypes = joinTypes, - tables = tablesInvolved, - parsingTime = parsingTime, - analysisTime = analysisTime, - optimizationTime = optimizationTime, - planningTime = planningTime, - executionTime = executionTime, - breakdownResults) - } catch { - case e: Exception => - throw new RuntimeException( - s"Failed to benchmark query ${query.name}", e) + def apply( + name: String, + dataFrameBuilder: => DataFrame, + description: String): Query = { + new Query(name, dataFrameBuilder, description, true, None) } } -} + + class Query( + val name: String, + dataFrameBuilder: => DataFrame, + val description: String, + val collectResults: Boolean, + val sqlText: Option[String]) { + + val tablesInvolved = dataFrameBuilder.queryExecution.logical collect { + case UnresolvedRelation(tableIdentifier, _) => { + // We are ignoring the database name. + tableIdentifier.last + } + } + + def benchmarkMs[A](f: => A): Double = { + val startTime = System.nanoTime() + val ret = f + val endTime = System.nanoTime() + (endTime - startTime).toDouble / 1000000 + } + + def benchmark(includeBreakdown: Boolean, description: String = "") = { + try { + val dataFrame = dataFrameBuilder + sparkContext.setJobDescription(s"Query: $name, $description") + val queryExecution = dataFrame.queryExecution + // We are not counting the time of ScalaReflection.convertRowToScala. + val parsingTime = benchmarkMs { + queryExecution.logical + } + val analysisTime = benchmarkMs { + queryExecution.analyzed + } + val optimizationTime = benchmarkMs { + queryExecution.optimizedPlan + } + val planningTime = benchmarkMs { + queryExecution.executedPlan + } + + val breakdownResults = if (includeBreakdown) { + val depth = queryExecution.executedPlan.treeString.split("\n").size + val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i))) + physicalOperators.map { + case (index, node) => + val executionTime = benchmarkMs { + node.execute().map(_.copy()).foreach(row => Unit) + } + BreakdownResult(node.nodeName, node.simpleString, index, executionTime) + } + } else { + Seq.empty[BreakdownResult] + } + + // The executionTime for the entire query includes the time of type conversion from catalyst to scala. + val executionTime = if (collectResults) { + benchmarkMs { + dataFrame.rdd.collect() + } + } else { + benchmarkMs { + dataFrame.rdd.foreach { row => Unit } + } + } + + val joinTypes = dataFrame.queryExecution.executedPlan.collect { + case k if k.nodeName contains "Join" => k.nodeName + } + + BenchmarkResult( + name = name, + joinTypes = joinTypes, + tables = tablesInvolved, + parsingTime = parsingTime, + analysisTime = analysisTime, + optimizationTime = optimizationTime, + planningTime = planningTime, + executionTime = executionTime, + breakdownResults) + } catch { + case e: Exception => + throw new RuntimeException( + s"Failed to benchmark query $name", e) + } + } + } + +} \ No newline at end of file diff --git a/src/main/scala/com/databricks/spark/sql/perf/runBenchmarks.scala b/src/main/scala/com/databricks/spark/sql/perf/runBenchmarks.scala index fb4d69a..3920c6f 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/runBenchmarks.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/runBenchmarks.scala @@ -127,16 +127,14 @@ case class ExperimentRun( * is a short string describing the scale of the dataset. */ abstract class Dataset( - @transient sqlContext: SQLContext, + @transient val sqlContext: SQLContext, sparkVersion: String, dataLocation: String, tables: Seq[Table], - scaleFactor: String) extends Serializable { + scaleFactor: String) extends Serializable with QuerySet { val datasetName: String - @transient val sparkContext = sqlContext.sparkContext - def createTablesForTest(tables: Seq[Table]): Seq[TableForTest] val tablesForTest: Seq[TableForTest] = createTablesForTest(tables) @@ -181,7 +179,7 @@ abstract class Dataset( /** * Starts an experiment run with a given set of queries. - * @param queries Queries to be executed. + * @param queriesToRun Queries to be executed. * @param resultsLocation The location of performance results. * @param includeBreakdown If it is true, breakdown results of a query will be recorded. * Setting it to true may significantly increase the time used to @@ -193,15 +191,13 @@ abstract class Dataset( * track the progress of this experiment run. */ def runExperiment( - queries: Seq[Query], + queriesToRun: Seq[Query], resultsLocation: String, includeBreakdown: Boolean = false, iterations: Int = 3, variations: Seq[Variation[_]] = Seq(Variation("StandardRun", Seq("")) { _ => {} }), tags: Map[String, String] = Map.empty) = { - val queriesToRun = queries.map(query => QueryForTest(query, includeBreakdown, sqlContext)) - class ExperimentStatus { val currentResults = new collection.mutable.ArrayBuffer[BenchmarkResult]() val currentRuns = new collection.mutable.ArrayBuffer[ExperimentRun]() @@ -237,7 +233,7 @@ abstract class Dataset( currentMessages += s"Running query ${q.name} $setup" currentQuery = q.name - val singleResult = try q.benchmark(setup) :: Nil catch { + val singleResult = try q.benchmark(includeBreakdown, setup) :: Nil catch { case e: Exception => currentMessages += s"Failed to run query ${q.name}: $e" Nil @@ -287,6 +283,7 @@ abstract class Dataset( "Running" } + override def toString = s""" |=== $status Experiment === diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/ImpalaKitQueries.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/ImpalaKitQueries.scala index c530232..df621bf 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/ImpalaKitQueries.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/ImpalaKitQueries.scala @@ -16,9 +16,9 @@ package com.databricks.spark.sql.perf.tpcds.queries -import com.databricks.spark.sql.perf.Query +import com.databricks.spark.sql.perf.QuerySet -object ImpalaKitQueries { +trait ImpalaKitQueries extends QuerySet { // Queries are from // https://github.com/cloudera/impala-tpcds-kit/tree/master/queries-sql92-modified/queries val queries = Seq( diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/SimpleQueries.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/SimpleQueries.scala index fd4f4b6..98aec64 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/SimpleQueries.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/SimpleQueries.scala @@ -16,9 +16,9 @@ package com.databricks.spark.sql.perf.tpcds.queries -import com.databricks.spark.sql.perf.Query +import com.databricks.spark.sql.perf.QuerySet -object SimpleQueries { +trait SimpleQueries extends QuerySet{ val q7Derived = Seq( ("q7-simpleScan", """ diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/package.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/package.scala index 4b154ac..65f72fe 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/package.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/queries/package.scala @@ -15,8 +15,3 @@ */ package com.databricks.spark.sql.perf.tpcds - -package object queries { - val impalaKitQueries = ImpalaKitQueries.impalaKitQueries - val q7DerivedQueries = SimpleQueries.q7Derived -}