Refactor to work in notebooks

This commit is contained in:
Michael Armbrust 2015-07-03 11:26:06 -07:00
parent 3eca8d2947
commit eb3dd30c35
6 changed files with 119 additions and 88 deletions

View File

@@ -16,9 +16,9 @@
package com.databricks.spark.sql.perf.bigdata
import com.databricks.spark.sql.perf.Query
import com.databricks.spark.sql.perf.QuerySet
object Queries {
trait Queries extends QuerySet {
val queries1to3 = Seq(
Query(
name = "q1A",

View File

@@ -16,81 +16,120 @@
package com.databricks.spark.sql.perf
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
case class Query(name: String, sqlText: String, description: String, collectResults: Boolean)
object x {
case class QueryForTest(
query: Query,
includeBreakdown: Boolean,
@transient sqlContext: SQLContext) {
@transient val sparkContext = sqlContext.sparkContext
val name = query.name
/** Evaluates `f` exactly once and returns its wall-clock duration in milliseconds. */
def benchmarkMs[A](f: => A): Double = {
  val begin = System.nanoTime()
  f
  val elapsedNanos = System.nanoTime() - begin
  // Nanoseconds to (fractional) milliseconds.
  elapsedNanos / 1e6
}
trait QuerySet {
val sqlContext: SQLContext
def sparkContext = sqlContext.sparkContext
def benchmark(description: String = "") = {
try {
sparkContext.setJobDescription(s"Query: ${query.name}, $description")
val dataFrame = sqlContext.sql(query.sqlText)
val queryExecution = dataFrame.queryExecution
// We are not counting the time of ScalaReflection.convertRowToScala.
val parsingTime = benchmarkMs { queryExecution.logical }
val analysisTime = benchmarkMs { queryExecution.analyzed }
val optimizationTime = benchmarkMs { queryExecution.optimizedPlan }
val planningTime = benchmarkMs { queryExecution.executedPlan }
val breakdownResults = if (includeBreakdown) {
val depth = queryExecution.executedPlan.treeString.split("\n").size
val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i)))
physicalOperators.map {
case (index, node) =>
val executionTime = benchmarkMs { node.execute().map(_.copy()).foreach(row => Unit) }
BreakdownResult(node.nodeName, node.simpleString, index, executionTime)
}
} else {
Seq.empty[BreakdownResult]
}
object Query {
def apply(
name: String,
sqlText: String,
description: String,
collectResults: Boolean = true): Query = {
new Query(name, sqlContext.sql(sqlText), description, collectResults, Some(sqlText))
}
// The executionTime for the entire query includes the time of type conversion from catalyst to scala.
val executionTime = if (query.collectResults) {
benchmarkMs { dataFrame.rdd.collect() }
} else {
benchmarkMs { dataFrame.rdd.foreach {row => Unit } }
}
val joinTypes = dataFrame.queryExecution.executedPlan.collect {
case k if k.nodeName contains "Join" => k.nodeName
}
val tablesInvolved = dataFrame.queryExecution.logical collect {
case UnresolvedRelation(tableIdentifier, _) => {
// We are ignoring the database name.
tableIdentifier.last
}
}
BenchmarkResult(
name = query.name,
joinTypes = joinTypes,
tables = tablesInvolved,
parsingTime = parsingTime,
analysisTime = analysisTime,
optimizationTime = optimizationTime,
planningTime = planningTime,
executionTime = executionTime,
breakdownResults)
} catch {
case e: Exception =>
throw new RuntimeException(
s"Failed to benchmark query ${query.name}", e)
/**
 * Builds a [[Query]] from a lazily constructed DataFrame.
 * Results are collected by default and no SQL text is recorded.
 */
def apply(
    name: String,
    dataFrameBuilder: => DataFrame,
    description: String): Query =
  new Query(name, dataFrameBuilder, description, collectResults = true, sqlText = None)
}
}
/**
 * A named query that can be benchmarked against a SQLContext.
 *
 * @param name             identifier used in job descriptions and results.
 * @param dataFrameBuilder by-name DataFrame constructor; NOTE it is evaluated once at
 *                         construction time (for `tablesInvolved`) and again on every
 *                         `benchmark` call.
 * @param description      human-readable description of the query.
 * @param collectResults   when true, `benchmark` times `rdd.collect()` (includes
 *                         catalyst-to-scala conversion); otherwise a `foreach` no-op.
 * @param sqlText          original SQL text, when the query came from a SQL string.
 */
class Query(
val name: String,
dataFrameBuilder: => DataFrame,
val description: String,
val collectResults: Boolean,
val sqlText: Option[String]) {
// Table names referenced by the (unresolved) logical plan, computed eagerly at construction.
val tablesInvolved = dataFrameBuilder.queryExecution.logical collect {
case UnresolvedRelation(tableIdentifier, _) => {
// We are ignoring the database name.
tableIdentifier.last
}
}
// Runs `f` once and returns its wall-clock duration in milliseconds.
def benchmarkMs[A](f: => A): Double = {
val startTime = System.nanoTime()
val ret = f
val endTime = System.nanoTime()
(endTime - startTime).toDouble / 1000000
}
/**
 * Executes the query once, timing each planning phase and (optionally) each physical
 * operator, and returns a BenchmarkResult. Wraps any failure in a RuntimeException
 * tagged with the query name.
 *
 * NOTE(review): `sparkContext` is not defined in this class — presumably provided by an
 * enclosing scope (e.g. the QuerySet trait); confirm against the full file.
 */
def benchmark(includeBreakdown: Boolean, description: String = "") = {
try {
val dataFrame = dataFrameBuilder
sparkContext.setJobDescription(s"Query: $name, $description")
val queryExecution = dataFrame.queryExecution
// We are not counting the time of ScalaReflection.convertRowToScala.
// Each phase below is lazy, so timing the first access measures that phase's cost.
val parsingTime = benchmarkMs {
queryExecution.logical
}
val analysisTime = benchmarkMs {
queryExecution.analyzed
}
val optimizationTime = benchmarkMs {
queryExecution.optimizedPlan
}
val planningTime = benchmarkMs {
queryExecution.executedPlan
}
// Per-operator timings: executes every subtree of the physical plan independently.
val breakdownResults = if (includeBreakdown) {
val depth = queryExecution.executedPlan.treeString.split("\n").size
val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i)))
physicalOperators.map {
case (index, node) =>
val executionTime = benchmarkMs {
node.execute().map(_.copy()).foreach(row => Unit)
}
BreakdownResult(node.nodeName, node.simpleString, index, executionTime)
}
} else {
Seq.empty[BreakdownResult]
}
// The executionTime for the entire query includes the time of type conversion from catalyst to scala.
val executionTime = if (collectResults) {
benchmarkMs {
dataFrame.rdd.collect()
}
} else {
benchmarkMs {
dataFrame.rdd.foreach { row => Unit }
}
}
// Names of all join operators that appear in the executed physical plan.
val joinTypes = dataFrame.queryExecution.executedPlan.collect {
case k if k.nodeName contains "Join" => k.nodeName
}
BenchmarkResult(
name = name,
joinTypes = joinTypes,
tables = tablesInvolved,
parsingTime = parsingTime,
analysisTime = analysisTime,
optimizationTime = optimizationTime,
planningTime = planningTime,
executionTime = executionTime,
breakdownResults)
} catch {
case e: Exception =>
throw new RuntimeException(
s"Failed to benchmark query $name", e)
}
}
}
}

View File

@@ -127,16 +127,14 @@ case class ExperimentRun(
* is a short string describing the scale of the dataset.
*/
abstract class Dataset(
@transient sqlContext: SQLContext,
@transient val sqlContext: SQLContext,
sparkVersion: String,
dataLocation: String,
tables: Seq[Table],
scaleFactor: String) extends Serializable {
scaleFactor: String) extends Serializable with QuerySet {
val datasetName: String
@transient val sparkContext = sqlContext.sparkContext
def createTablesForTest(tables: Seq[Table]): Seq[TableForTest]
val tablesForTest: Seq[TableForTest] = createTablesForTest(tables)
@@ -181,7 +179,7 @@ abstract class Dataset(
/**
* Starts an experiment run with a given set of queries.
* @param queries Queries to be executed.
* @param queriesToRun Queries to be executed.
* @param resultsLocation The location of performance results.
* @param includeBreakdown If it is true, breakdown results of a query will be recorded.
* Setting it to true may significantly increase the time used to
@@ -193,15 +191,13 @@ abstract class Dataset(
* track the progress of this experiment run.
*/
def runExperiment(
queries: Seq[Query],
queriesToRun: Seq[Query],
resultsLocation: String,
includeBreakdown: Boolean = false,
iterations: Int = 3,
variations: Seq[Variation[_]] = Seq(Variation("StandardRun", Seq("")) { _ => {} }),
tags: Map[String, String] = Map.empty) = {
val queriesToRun = queries.map(query => QueryForTest(query, includeBreakdown, sqlContext))
class ExperimentStatus {
val currentResults = new collection.mutable.ArrayBuffer[BenchmarkResult]()
val currentRuns = new collection.mutable.ArrayBuffer[ExperimentRun]()
@@ -237,7 +233,7 @@ abstract class Dataset(
currentMessages += s"Running query ${q.name} $setup"
currentQuery = q.name
val singleResult = try q.benchmark(setup) :: Nil catch {
val singleResult = try q.benchmark(includeBreakdown, setup) :: Nil catch {
case e: Exception =>
currentMessages += s"Failed to run query ${q.name}: $e"
Nil
@@ -287,6 +283,7 @@ abstract class Dataset(
"Running"
}
override def toString =
s"""
|=== $status Experiment ===

View File

@@ -16,9 +16,9 @@
package com.databricks.spark.sql.perf.tpcds.queries
import com.databricks.spark.sql.perf.Query
import com.databricks.spark.sql.perf.QuerySet
object ImpalaKitQueries {
trait ImpalaKitQueries extends QuerySet {
// Queries are from
// https://github.com/cloudera/impala-tpcds-kit/tree/master/queries-sql92-modified/queries
val queries = Seq(

View File

@@ -16,9 +16,9 @@
package com.databricks.spark.sql.perf.tpcds.queries
import com.databricks.spark.sql.perf.Query
import com.databricks.spark.sql.perf.QuerySet
object SimpleQueries {
trait SimpleQueries extends QuerySet{
val q7Derived = Seq(
("q7-simpleScan",
"""

View File

@@ -15,8 +15,3 @@
*/
package com.databricks.spark.sql.perf.tpcds
// Convenience aliases re-exporting the bundled query sets at the package level.
// NOTE(review): ImpalaKitQueries and SimpleQueries live in sibling files — confirm
// they expose `impalaKitQueries` and `q7Derived` members with these names.
package object queries {
val impalaKitQueries = ImpalaKitQueries.impalaKitQueries
val q7DerivedQueries = SimpleQueries.q7Derived
}