diff --git a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala index 1e8d297..639b0cd 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala @@ -25,6 +25,7 @@ import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.{SparkContext, SparkEnv} import com.databricks.spark.sql.perf.cpu._ @@ -92,20 +93,20 @@ abstract class Benchmark( } /** - * Starts an experiment run with a given set of queries. - * @param queriesToRun a list of queries to be executed. - * @param includeBreakdown If it is true, breakdown results of a query will be recorded. + * Starts an experiment run with a given set of executions to run. + * @param executionsToRun a list of executions to run. + * @param includeBreakdown If it is true, breakdown results of an execution will be recorded. * Setting it to true may significantly increase the time used to - * execute a query. - * @param iterations The number of iterations to run of each query. + * run an execution. + * @param iterations The number of iterations to run of each execution. * @param variations [[Variation]]s used in this run. The cross product of all variations will be - * run for each query * iteration. + * run for each execution * iteration. * @param tags Tags of this run. * @return It returns a ExperimentStatus object that can be used to * track the progress of this experiment run. 
*/ def runExperiment( - queriesToRun: Seq[Query], + executionsToRun: Seq[Benchmarkable], includeBreakdown: Boolean = false, iterations: Int = 3, variations: Seq[Variation[_]] = Seq(Variation("StandardRun", Seq("true")) { _ => {} }), @@ -117,8 +118,8 @@ abstract class Benchmark( val currentMessages = new collection.mutable.ArrayBuffer[String]() // Stats for HTML status message. - @volatile var currentQuery = "" - @volatile var currentPlan = "" + @volatile var currentExecution = "" + @volatile var currentPlan = "" // for queries only @volatile var currentConfig = "" @volatile var failures = 0 @volatile var startTime = 0L @@ -135,29 +136,35 @@ abstract class Benchmark( val timestamp = System.currentTimeMillis() val combinations = cartesianProduct(variations.map(l => (0 until l.options.size).toList).toList) val resultsFuture = Future { - queriesToRun.flatMap { query => - query.newDataFrame().queryExecution.logical.collect { - case UnresolvedRelation(t, _) => t.table + + // If we're running queries, create tables for them + executionsToRun + .collect { case query: Query => query } + .flatMap { query => + query.newDataFrame().queryExecution.logical.collect { + case UnresolvedRelation(t, _) => t.table + } } - }.distinct.foreach { name => - try { - sqlContext.table(name) - currentMessages += s"Table $name exists." - } catch { - case ae: AnalysisException => - val table = allTables - .find(_.name == name) - .getOrElse(sys.error(s"Couldn't read table $name and its not defined as a Benchmark.Table.")) + .distinct + .foreach { name => + try { + sqlContext.table(name) + currentMessages += s"Table $name exists." 
+ } catch { + case ae: AnalysisException => + val table = allTables + .find(_.name == name) + .getOrElse(sys.error(s"Couldn't read table $name and its not defined as a Benchmark.Table.")) - currentMessages += s"Creating table: $name" - table.data - .write - .mode("overwrite") - .saveAsTable(name) + currentMessages += s"Creating table: $name" + table.data + .write + .mode("overwrite") + .saveAsTable(name) + } } - } - + // Run the benchmarks! val results = (1 to iterations).flatMap { i => combinations.map { setup => val currentOptions = variations.asInstanceOf[Seq[Variation[Any]]].zip(setup).map { @@ -172,24 +179,30 @@ abstract class Benchmark( iteration = i, tags = currentOptions.toMap ++ tags, configuration = currentConfiguration, - queriesToRun.flatMap { q => - val setup = s"iteration: $i, ${currentOptions.map { case (k, v) => s"$k=$v"}.mkString(", ")}" - currentMessages += s"Running query ${q.name} $setup" - currentQuery = q.name - currentPlan = q.newDataFrame().queryExecution.executedPlan.toString + executionsToRun.flatMap { q => + val setup = s"iteration: $i, ${currentOptions.map { case (k, v) => s"$k=$v"}.mkString(", ")}" + currentMessages += s"Running execution ${q.name} $setup" + + currentExecution = q.name + currentPlan = q match { + case query: Query => query.newDataFrame().queryExecution.executedPlan.toString() + case _ => "" + } startTime = System.currentTimeMillis() val singleResult = q.benchmark(includeBreakdown, setup, currentMessages) singleResult.failure.foreach { f => failures += 1 - currentMessages += s"Query '${q.name}' failed: ${f.message}" + currentMessages += s"Execution '${q.name}' failed: ${f.message}" + } + singleResult.executionTime.foreach { time => + currentMessages += s"Execution time: ${time / 1000}s" } - singleResult.executionTime.foreach(time => - currentMessages += s"Exec time: ${time / 1000}s") currentResults += singleResult singleResult :: Nil }) + currentRuns += result result @@ -269,28 +282,37 @@ abstract class Benchmark( 
s"""Permalink: table("sqlPerformance").where('timestamp === ${timestamp}L)""" - def html = + def html: String = { + val maybeQueryPlan: String = + if (currentPlan.nonEmpty) { + s""" + |
+ |${currentPlan.replaceAll("\n", "
")}
+ |
+ """.stripMargin
+ } else {
+ ""
+ }
s"""
|
- |${currentPlan.replaceAll("\n", "
")}
- |
- |
+ |$maybeQueryPlan
|
|${tail()}
|
""".stripMargin
+ }
}
new ExperimentStatus
}
@@ -420,13 +442,88 @@ abstract class Benchmark(
}
}
/** A trait to describe things that can be benchmarked. */
trait Benchmarkable {
  /** Human-readable identifier used in job descriptions and result records. */
  val name: String
  /** How results are materialized for this execution (e.g. collect, foreach). */
  protected val executionMode: ExecutionMode

  /**
   * Runs this execution once and returns its result.
   *
   * @param includeBreakdown if true, per-operator breakdown results are recorded
   *                         (may significantly increase the run time).
   * @param description free-form text appended to the Spark job description.
   * @param messages buffer collecting progress/status messages for display.
   * @return the timing (or failure) information for this run.
   */
  final def benchmark(
      includeBreakdown: Boolean,
      description: String = "",
      messages: ArrayBuffer[String]): BenchmarkResult = {
    sparkContext.setJobDescription(s"Execution: $name, $description")
    beforeBenchmark()
    try {
      doBenchmark(includeBreakdown, description, messages)
    } finally {
      // Run cleanup even when the subclass implementation throws, so a single
      // failed execution does not leak leftover blocks into subsequent runs.
      afterBenchmark(sqlContext.sparkContext)
    }
  }

  /** Hook for subclasses to prepare state before timing starts. */
  protected def beforeBenchmark(): Unit = { }

  private def afterBenchmark(sc: SparkContext): Unit = {
    // Best-effort clean up of weakly referenced RDDs, shuffles, and broadcasts
    System.gc()
    // Remove any leftover blocks that still exist
    sc.getExecutorStorageStatus
      .flatMap { status => status.blocks.map { case (bid, _) => bid } }
      .foreach { bid => SparkEnv.get.blockManager.master.removeBlock(bid) }
  }

  /** Subclass-specific measurement logic; called once per `benchmark` invocation. */
  protected def doBenchmark(
      includeBreakdown: Boolean,
      description: String = "",
      messages: ArrayBuffer[String]): BenchmarkResult

  /** Evaluates `f` once and returns the elapsed wall-clock time in milliseconds. */
  protected def measureTimeMs[A](f: => A): Double = {
    val startTime = System.nanoTime()
    f
    val endTime = System.nanoTime()
    (endTime - startTime).toDouble / 1000000
  }
}
+
/** A class for benchmarking Spark perf results. */
class SparkPerfExecution(
    override val name: String,
    parameters: Map[String, String],
    prepare: () => Unit,
    run: () => Unit)
  extends Benchmarkable {

  protected override val executionMode: ExecutionMode = ExecutionMode.SparkPerfResults

  /** Runs the user-supplied preparation step before timing starts. */
  protected override def beforeBenchmark(): Unit = { prepare() }

  /**
   * Times the user-supplied `run` thunk and packages the outcome.
   *
   * @param includeBreakdown unused here; spark-perf executions have no query plan
   *                         to break down.
   * @param description unused here; the job description is set by `benchmark`.
   * @param messages unused here.
   * @return a [[BenchmarkResult]] with the execution time, or with failure
   *         details if the run threw a non-fatal exception.
   */
  protected override def doBenchmark(
      includeBreakdown: Boolean,
      description: String = "",
      messages: ArrayBuffer[String]): BenchmarkResult = {
    import scala.util.control.NonFatal
    try {
      val timeMs = measureTimeMs(run())
      BenchmarkResult(
        name = name,
        mode = executionMode.toString,
        parameters = parameters,
        executionTime = Some(timeMs))
    } catch {
      // NonFatal (rather than a bare Exception match) lets fatal errors such as
      // OutOfMemoryError and InterruptedException propagate instead of being
      // recorded as an ordinary benchmark failure.
      case NonFatal(e) =>
        BenchmarkResult(
          name = name,
          mode = executionMode.toString,
          parameters = parameters,
          // e.getMessage may be null (e.g. NullPointerException); guard it so
          // the failure record never carries a null message.
          failure = Some(Failure(
            e.getClass.getSimpleName,
            Option(e.getMessage).getOrElse(""))))
    }
  }
}
+
/** Holds one benchmark query and its metadata. */
class Query(
- val name: String,
+ override val name: String,
buildDataFrame: => DataFrame,
val description: String = "",
val sqlText: Option[String] = None,
- val executionMode: ExecutionMode = ExecutionMode.ForeachResults) extends Serializable {
+ override val executionMode: ExecutionMode = ExecutionMode.ForeachResults)
+ extends Benchmarkable with Serializable {
override def toString =
s"""
@@ -443,32 +540,24 @@ abstract class Benchmark(
def newDataFrame() = buildDataFrame
- def benchmarkMs[A](f: => A): Double = {
- val startTime = System.nanoTime()
- val ret = f
- val endTime = System.nanoTime()
- (endTime - startTime).toDouble / 1000000
- }
-
- def benchmark(
+ protected override def doBenchmark(
includeBreakdown: Boolean,
description: String = "",
- messages: ArrayBuffer[String]) = {
+ messages: ArrayBuffer[String]): BenchmarkResult = {
try {
val dataFrame = buildDataFrame
- sparkContext.setJobDescription(s"Query: $name, $description")
val queryExecution = dataFrame.queryExecution
// We are not counting the time of ScalaReflection.convertRowToScala.
- val parsingTime = benchmarkMs {
+ val parsingTime = measureTimeMs {
queryExecution.logical
}
- val analysisTime = benchmarkMs {
+ val analysisTime = measureTimeMs {
queryExecution.analyzed
}
- val optimizationTime = benchmarkMs {
+ val optimizationTime = measureTimeMs {
queryExecution.optimizedPlan
}
- val planningTime = benchmarkMs {
+ val planningTime = measureTimeMs {
queryExecution.executedPlan
}
@@ -482,7 +571,7 @@ abstract class Benchmark(
case (index, node) =>
messages += s"Breakdown: ${node.simpleString}"
val newNode = buildDataFrame.queryExecution.executedPlan(index)
- val executionTime = benchmarkMs {
+ val executionTime = measureTimeMs {
newNode.execute().foreach((row: Any) => Unit)
}
timeMap += ((index, executionTime))
@@ -509,7 +598,7 @@ abstract class Benchmark(
// The executionTime for the entire query includes the time of type conversion
// from catalyst to scala.
var result: Option[Long] = None
- val executionTime = benchmarkMs {
+ val executionTime = measureTimeMs {
executionMode match {
case ExecutionMode.CollectResults => dataFrame.rdd.collect()
case ExecutionMode.ForeachResults => dataFrame.rdd.foreach { row => Unit }
diff --git a/src/main/scala/com/databricks/spark/sql/perf/results.scala b/src/main/scala/com/databricks/spark/sql/perf/results.scala
index e902109..80e2e80 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/results.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/results.scala
@@ -49,6 +49,7 @@ case class BenchmarkConfiguration(
* The result of a query.
* @param name The name of the query.
* @param mode The ExecutionMode of this run.
+ * @param parameters Additional parameters that describe this query.
* @param joinTypes The type of join operations in the query.
* @param tables The tables involved in the query.
* @param parsingTime The time used to parse the query.
@@ -64,6 +65,7 @@ case class BenchmarkConfiguration(
case class BenchmarkResult(
name: String,
mode: String,
+ parameters: Map[String, String] = Map.empty[String, String],
joinTypes: Seq[String] = Nil,
tables: Seq[String] = Nil,
parsingTime: Option[Double] = None,