From 42a415e8d46f3354dc7b0b37c3a85061f62a777f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 22 Feb 2016 18:23:06 -0800 Subject: [PATCH] Extract Query class from Benchmark into its own top-level class and make SparkContext field transient This patch extracts `Query` into its own top-level class and makes its `sparkContext` field transient in order to fix `NotSerializableException`s. Author: Josh Rosen Closes #53 from JoshRosen/make-query-into-top-level-class. --- .../databricks/spark/sql/perf/Benchmark.scala | 157 +--------------- .../spark/sql/perf/Benchmarkable.scala | 4 +- .../com/databricks/spark/sql/perf/Query.scala | 172 ++++++++++++++++++ .../sql/perf/tpcds/TPCDS_1_4_Queries.scala | 2 +- 4 files changed, 181 insertions(+), 154 deletions(-) create mode 100644 src/main/scala/com/databricks/spark/sql/perf/Query.scala diff --git a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala index 1676f14..436b8bf 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala @@ -16,20 +16,17 @@ package com.databricks.spark.sql.perf -import java.util.UUID - -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent._ import scala.concurrent.duration._ -import scala.concurrent.ExecutionContext.Implicits.global +import scala.language.implicitConversions import scala.util.Try import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, AnalysisException, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedRelation} -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.SparkContext import com.databricks.spark.sql.perf.cpu._ @@ -106,6 +103,7 @@ abstract class Benchmark( /** * Starts an experiment run with a given set of executions to run. + * * @param executionsToRun a list of executions to run. * @param includeBreakdown If it is true, breakdown results of an execution will be recorded. * Setting it to true may significantly increase the time used to @@ -505,147 +503,4 @@ abstract class Benchmark( } } - /** Holds one benchmark query and its metadata. */ - class Query( - override val name: String, - buildDataFrame: => DataFrame, - val description: String = "", - val sqlText: Option[String] = None, - override val executionMode: ExecutionMode = ExecutionMode.ForeachResults) - extends Benchmarkable with Serializable { - - override def toString: String = { - try { - s""" - |== Query: $name == - |${buildDataFrame.queryExecution.analyzed} - """.stripMargin - } catch { - case e: Exception => - s""" - |== Query: $name == - | Can't be analyzed: $e - | - | $description - """.stripMargin - } - } - - lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect { - case UnresolvedRelation(tableIdentifier, _) => { - // We are ignoring the database name. - tableIdentifier.table - } - } - - def newDataFrame() = buildDataFrame - - protected override def doBenchmark( - includeBreakdown: Boolean, - description: String = "", - messages: ArrayBuffer[String]): BenchmarkResult = { - try { - val dataFrame = buildDataFrame - val queryExecution = dataFrame.queryExecution - // We are not counting the time of ScalaReflection.convertRowToScala. - val parsingTime = measureTimeMs { - queryExecution.logical - } - val analysisTime = measureTimeMs { - queryExecution.analyzed - } - val optimizationTime = measureTimeMs { - queryExecution.optimizedPlan - } - val planningTime = measureTimeMs { - queryExecution.executedPlan - } - - val breakdownResults = if (includeBreakdown) { - val depth = queryExecution.executedPlan.collect { case p: SparkPlan => p }.size - val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i))) - val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap - val timeMap = new mutable.HashMap[Int, Double] - - physicalOperators.reverse.map { - case (index, node) => - messages += s"Breakdown: ${node.simpleString}" - val newNode = buildDataFrame.queryExecution.executedPlan(index) - val executionTime = measureTimeMs { - newNode.execute().foreach((row: Any) => Unit) - } - timeMap += ((index, executionTime)) - - val childIndexes = node.children.map(indexMap) - val childTime = childIndexes.map(timeMap).sum - - messages += s"Breakdown time: $executionTime (+${executionTime - childTime})" - - BreakdownResult( - node.nodeName, - node.simpleString.replaceAll("#\\d+", ""), - index, - childIndexes, - executionTime, - executionTime - childTime) - } - } else { - Seq.empty[BreakdownResult] - } - - // The executionTime for the entire query includes the time of type conversion from catalyst - // to scala. - // The executionTime for the entire query includes the time of type conversion - // from catalyst to scala. - var result: Option[Long] = None - val executionTime = measureTimeMs { - executionMode match { - case ExecutionMode.CollectResults => dataFrame.rdd.collect() - case ExecutionMode.ForeachResults => dataFrame.rdd.foreach { row => Unit } - case ExecutionMode.WriteParquet(location) => - dataFrame.write.parquet(s"$location/$name.parquet") - case ExecutionMode.HashResults => - val columnStr = dataFrame.schema.map(_.name).mkString(",") - // SELECT SUM(HASH(col1, col2, ...)) FROM (benchmark query) - val row = - dataFrame - .selectExpr(s"hash($columnStr) as hashValue") - .groupBy() - .sum("hashValue") - .head() - result = if (row.isNullAt(0)) None else Some(row.getLong(0)) - } - } - - val joinTypes = dataFrame.queryExecution.executedPlan.collect { - case k if k.nodeName contains "Join" => k.nodeName - } - - BenchmarkResult( - name = name, - mode = executionMode.toString, - joinTypes = joinTypes, - tables = tablesInvolved, - parsingTime = parsingTime, - analysisTime = analysisTime, - optimizationTime = optimizationTime, - planningTime = planningTime, - executionTime = executionTime, - result = result, - queryExecution = dataFrame.queryExecution.toString, - breakDown = breakdownResults) - } catch { - case e: Exception => - BenchmarkResult( - name = name, - mode = executionMode.toString, - failure = Failure(e.getClass.getName, e.getMessage)) - } - } - - /** Change the ExecutionMode of this Query to HashResults, which is used to check the query result. */ - def checkResult: Query = { - new Query(name, buildDataFrame, description, sqlText, ExecutionMode.HashResults) - } - } } diff --git a/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala index 2271a97..142ff3f 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala @@ -25,8 +25,8 @@ import scala.collection.mutable.ArrayBuffer /** A trait to describe things that can be benchmarked. */ trait Benchmarkable { - val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate()) - val sparkContext = sqlContext.sparkContext + @transient protected[this] val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate()) + @transient protected[this] val sparkContext = sqlContext.sparkContext val name: String protected val executionMode: ExecutionMode diff --git a/src/main/scala/com/databricks/spark/sql/perf/Query.scala b/src/main/scala/com/databricks/spark/sql/perf/Query.scala new file mode 100644 index 0000000..498835f --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/Query.scala @@ -0,0 +1,172 @@ +/* + * Copyright 2015 Databricks Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.databricks.spark.sql.perf + +import scala.language.implicitConversions +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.execution.SparkPlan + + +/** Holds one benchmark query and its metadata. */ +class Query( + override val name: String, + buildDataFrame: => DataFrame, + val description: String = "", + val sqlText: Option[String] = None, + override val executionMode: ExecutionMode = ExecutionMode.ForeachResults) + extends Benchmarkable with Serializable { + + private implicit def toOption[A](a: A): Option[A] = Option(a) + + override def toString: String = { + try { + s""" + |== Query: $name == + |${buildDataFrame.queryExecution.analyzed} + """.stripMargin + } catch { + case e: Exception => + s""" + |== Query: $name == + | Can't be analyzed: $e + | + | $description + """.stripMargin + } + } + + lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect { + case UnresolvedRelation(tableIdentifier, _) => { + // We are ignoring the database name. + tableIdentifier.table + } + } + + def newDataFrame() = buildDataFrame + + protected override def doBenchmark( + includeBreakdown: Boolean, + description: String = "", + messages: ArrayBuffer[String]): BenchmarkResult = { + try { + val dataFrame = buildDataFrame + val queryExecution = dataFrame.queryExecution + // We are not counting the time of ScalaReflection.convertRowToScala. + val parsingTime = measureTimeMs { + queryExecution.logical + } + val analysisTime = measureTimeMs { + queryExecution.analyzed + } + val optimizationTime = measureTimeMs { + queryExecution.optimizedPlan + } + val planningTime = measureTimeMs { + queryExecution.executedPlan + } + + val breakdownResults = if (includeBreakdown) { + val depth = queryExecution.executedPlan.collect { case p: SparkPlan => p }.size + val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i))) + val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap + val timeMap = new mutable.HashMap[Int, Double] + + physicalOperators.reverse.map { + case (index, node) => + messages += s"Breakdown: ${node.simpleString}" + val newNode = buildDataFrame.queryExecution.executedPlan(index) + val executionTime = measureTimeMs { + newNode.execute().foreach((row: Any) => Unit) + } + timeMap += ((index, executionTime)) + + val childIndexes = node.children.map(indexMap) + val childTime = childIndexes.map(timeMap).sum + + messages += s"Breakdown time: $executionTime (+${executionTime - childTime})" + + BreakdownResult( + node.nodeName, + node.simpleString.replaceAll("#\\d+", ""), + index, + childIndexes, + executionTime, + executionTime - childTime) + } + } else { + Seq.empty[BreakdownResult] + } + + // The executionTime for the entire query includes the time of type conversion from catalyst + // to scala. + // The executionTime for the entire query includes the time of type conversion + // from catalyst to scala. + var result: Option[Long] = None + val executionTime = measureTimeMs { + executionMode match { + case ExecutionMode.CollectResults => dataFrame.rdd.collect() + case ExecutionMode.ForeachResults => dataFrame.rdd.foreach { row => Unit } + case ExecutionMode.WriteParquet(location) => + dataFrame.write.parquet(s"$location/$name.parquet") + case ExecutionMode.HashResults => + val columnStr = dataFrame.schema.map(_.name).mkString(",") + // SELECT SUM(HASH(col1, col2, ...)) FROM (benchmark query) + val row = + dataFrame + .selectExpr(s"hash($columnStr) as hashValue") + .groupBy() + .sum("hashValue") + .head() + result = if (row.isNullAt(0)) None else Some(row.getLong(0)) + } + } + + val joinTypes = dataFrame.queryExecution.executedPlan.collect { + case k if k.nodeName contains "Join" => k.nodeName + } + + BenchmarkResult( + name = name, + mode = executionMode.toString, + joinTypes = joinTypes, + tables = tablesInvolved, + parsingTime = parsingTime, + analysisTime = analysisTime, + optimizationTime = optimizationTime, + planningTime = planningTime, + executionTime = executionTime, + result = result, + queryExecution = dataFrame.queryExecution.toString, + breakDown = breakdownResults) + } catch { + case e: Exception => + BenchmarkResult( + name = name, + mode = executionMode.toString, + failure = Failure(e.getClass.getName, e.getMessage)) + } + } + + /** Change the ExecutionMode of this Query to HashResults, which is used to check the query result. */ + def checkResult: Query = { + new Query(name, buildDataFrame, description, sqlText, ExecutionMode.HashResults) + } +} \ No newline at end of file diff --git a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_1_4_Queries.scala b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_1_4_Queries.scala index 9042eec..63481f7 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_1_4_Queries.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/tpcds/TPCDS_1_4_Queries.scala @@ -16,7 +16,7 @@ package com.databricks.spark.sql.perf.tpcds -import com.databricks.spark.sql.perf.{ExecutionMode, Benchmark} +import com.databricks.spark.sql.perf.{Benchmark, ExecutionMode, Query} /** * This implements the official TPCDS v1.4 queries with only cosmetic modifications