Extract Query class from Benchmark into its own top-level class and make SparkContext field transient
This patch extracts `Query` into its own top-level class and makes its `sparkContext` field transient in order to fix `NotSerializableException`s. Author: Josh Rosen <rosenville@gmail.com> Closes #53 from JoshRosen/make-query-into-top-level-class.
This commit is contained in:
parent
7e38b77c50
commit
42a415e8d4
@ -16,20 +16,17 @@
|
||||
|
||||
package com.databricks.spark.sql.perf
|
||||
|
||||
import java.util.UUID
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
import scala.concurrent._
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
import scala.language.implicitConversions
|
||||
import scala.util.Try
|
||||
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{Dataset, AnalysisException, DataFrame, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedRelation}
|
||||
import org.apache.spark.sql.execution.SparkPlan
|
||||
import org.apache.spark.{SparkContext, SparkEnv}
|
||||
import org.apache.spark.sql.{DataFrame, SQLContext}
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
|
||||
import org.apache.spark.SparkContext
|
||||
|
||||
import com.databricks.spark.sql.perf.cpu._
|
||||
|
||||
@ -106,6 +103,7 @@ abstract class Benchmark(
|
||||
|
||||
/**
|
||||
* Starts an experiment run with a given set of executions to run.
|
||||
*
|
||||
* @param executionsToRun a list of executions to run.
|
||||
* @param includeBreakdown If it is true, breakdown results of an execution will be recorded.
|
||||
* Setting it to true may significantly increase the time used to
|
||||
@ -505,147 +503,4 @@ abstract class Benchmark(
|
||||
}
|
||||
}
|
||||
|
||||
/** Holds one benchmark query and its metadata. */
class Query(
    override val name: String,
    buildDataFrame: => DataFrame,
    val description: String = "",
    val sqlText: Option[String] = None,
    override val executionMode: ExecutionMode = ExecutionMode.ForeachResults)
  extends Benchmarkable with Serializable {

  /**
   * Renders the query name followed by its analyzed logical plan. If analysis
   * fails (e.g. referenced tables are missing), falls back to the analysis
   * error plus the query description instead of throwing.
   */
  override def toString: String = {
    try {
      s"""
         |== Query: $name ==
         |${buildDataFrame.queryExecution.analyzed}
      """.stripMargin
    } catch {
      case e: Exception =>
        s"""
           |== Query: $name ==
           | Can't be analyzed: $e
           |
           | $description
        """.stripMargin
    }
  }

  // Names of all tables referenced by the *unresolved* logical plan.
  // Evaluating this forces `buildDataFrame`, so it is lazy.
  lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect {
    case UnresolvedRelation(tableIdentifier, _) => {
      // We are ignoring the database name.
      tableIdentifier.table
    }
  }

  /** Re-evaluates the by-name `buildDataFrame` argument, yielding a fresh DataFrame. */
  def newDataFrame() = buildDataFrame

  /**
   * Executes the query once, timing each phase of query processing
   * (parsing, analysis, optimization, planning, execution).
   *
   * @param includeBreakdown if true, each physical operator is re-executed
   *                         individually and its timing recorded.
   * @param description unused here; part of the Benchmarkable contract.
   * @param messages buffer collecting human-readable progress messages.
   * @return a [[BenchmarkResult]]; on any exception, a result carrying a
   *         [[Failure]] instead of timings.
   */
  protected override def doBenchmark(
      includeBreakdown: Boolean,
      description: String = "",
      messages: ArrayBuffer[String]): BenchmarkResult = {
    try {
      val dataFrame = buildDataFrame
      val queryExecution = dataFrame.queryExecution
      // We are not counting the time of ScalaReflection.convertRowToScala.
      val parsingTime = measureTimeMs {
        queryExecution.logical
      }
      val analysisTime = measureTimeMs {
        queryExecution.analyzed
      }
      val optimizationTime = measureTimeMs {
        queryExecution.optimizedPlan
      }
      val planningTime = measureTimeMs {
        queryExecution.executedPlan
      }

      val breakdownResults = if (includeBreakdown) {
        val depth = queryExecution.executedPlan.collect { case p: SparkPlan => p }.size
        val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i)))
        val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap
        val timeMap = new mutable.HashMap[Int, Double]

        // Walk the plan bottom-up (reverse order) so every child is timed
        // before its parent, letting us report time exclusive of children.
        physicalOperators.reverse.map {
          case (index, node) =>
            messages += s"Breakdown: ${node.simpleString}"
            val newNode = buildDataFrame.queryExecution.executedPlan(index)
            val executionTime = measureTimeMs {
              // Force full evaluation of this subtree; rows are discarded.
              // (`()` rather than the `Unit` companion object.)
              newNode.execute().foreach((row: Any) => ())
            }
            timeMap += ((index, executionTime))

            val childIndexes = node.children.map(indexMap)
            val childTime = childIndexes.map(timeMap).sum

            messages += s"Breakdown time: $executionTime (+${executionTime - childTime})"

            BreakdownResult(
              node.nodeName,
              node.simpleString.replaceAll("#\\d+", ""),
              index,
              childIndexes,
              executionTime,
              executionTime - childTime)
        }
      } else {
        Seq.empty[BreakdownResult]
      }

      // The executionTime for the entire query includes the time of type
      // conversion from catalyst to scala.
      var result: Option[Long] = None
      val executionTime = measureTimeMs {
        executionMode match {
          case ExecutionMode.CollectResults => dataFrame.rdd.collect()
          case ExecutionMode.ForeachResults => dataFrame.rdd.foreach { row => () }
          case ExecutionMode.WriteParquet(location) =>
            dataFrame.write.parquet(s"$location/$name.parquet")
          case ExecutionMode.HashResults =>
            val columnStr = dataFrame.schema.map(_.name).mkString(",")
            // SELECT SUM(HASH(col1, col2, ...)) FROM (benchmark query)
            val row =
              dataFrame
                .selectExpr(s"hash($columnStr) as hashValue")
                .groupBy()
                .sum("hashValue")
                .head()
            result = if (row.isNullAt(0)) None else Some(row.getLong(0))
        }
      }

      val joinTypes = dataFrame.queryExecution.executedPlan.collect {
        case k if k.nodeName contains "Join" => k.nodeName
      }

      BenchmarkResult(
        name = name,
        mode = executionMode.toString,
        joinTypes = joinTypes,
        tables = tablesInvolved,
        parsingTime = parsingTime,
        analysisTime = analysisTime,
        optimizationTime = optimizationTime,
        planningTime = planningTime,
        executionTime = executionTime,
        result = result,
        queryExecution = dataFrame.queryExecution.toString,
        breakDown = breakdownResults)
    } catch {
      case e: Exception =>
        BenchmarkResult(
          name = name,
          mode = executionMode.toString,
          failure = Failure(e.getClass.getName, e.getMessage))
    }
  }

  /**
   * Returns a copy of this Query whose ExecutionMode is HashResults,
   * which is used to check the query result. This Query is not modified.
   */
  def checkResult: Query = {
    new Query(name, buildDataFrame, description, sqlText, ExecutionMode.HashResults)
  }
}
|
||||
}
|
||||
|
||||
@ -25,8 +25,8 @@ import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
/** A trait to describe things that can be benchmarked. */
|
||||
trait Benchmarkable {
|
||||
val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
|
||||
val sparkContext = sqlContext.sparkContext
|
||||
@transient protected[this] val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
|
||||
@transient protected[this] val sparkContext = sqlContext.sparkContext
|
||||
|
||||
val name: String
|
||||
protected val executionMode: ExecutionMode
|
||||
|
||||
172
src/main/scala/com/databricks/spark/sql/perf/Query.scala
Normal file
172
src/main/scala/com/databricks/spark/sql/perf/Query.scala
Normal file
@ -0,0 +1,172 @@
|
||||
/*
|
||||
* Copyright 2015 Databricks Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package com.databricks.spark.sql.perf
|
||||
|
||||
import scala.language.implicitConversions
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
|
||||
import org.apache.spark.sql.execution.SparkPlan
|
||||
|
||||
|
||||
/** Holds one benchmark query and its metadata. */
class Query(
    override val name: String,
    buildDataFrame: => DataFrame,
    val description: String = "",
    val sqlText: Option[String] = None,
    override val executionMode: ExecutionMode = ExecutionMode.ForeachResults)
  extends Benchmarkable with Serializable {

  // NOTE(review): implicit widening of any value to Option is convenient for
  // callers passing bare values, but it is an implicit conversion on arbitrary
  // types — kept private to limit its blast radius.
  private implicit def toOption[A](a: A): Option[A] = Option(a)

  /**
   * Renders the query name followed by its analyzed logical plan. If analysis
   * fails (e.g. referenced tables are missing), falls back to the analysis
   * error plus the query description instead of throwing.
   */
  override def toString: String = {
    try {
      s"""
         |== Query: $name ==
         |${buildDataFrame.queryExecution.analyzed}
      """.stripMargin
    } catch {
      case e: Exception =>
        s"""
           |== Query: $name ==
           | Can't be analyzed: $e
           |
           | $description
        """.stripMargin
    }
  }

  // Names of all tables referenced by the *unresolved* logical plan.
  // Evaluating this forces `buildDataFrame`, so it is lazy.
  lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect {
    case UnresolvedRelation(tableIdentifier, _) => {
      // We are ignoring the database name.
      tableIdentifier.table
    }
  }

  /** Re-evaluates the by-name `buildDataFrame` argument, yielding a fresh DataFrame. */
  def newDataFrame() = buildDataFrame

  /**
   * Executes the query once, timing each phase of query processing
   * (parsing, analysis, optimization, planning, execution).
   *
   * @param includeBreakdown if true, each physical operator is re-executed
   *                         individually and its timing recorded.
   * @param description unused here; part of the Benchmarkable contract.
   * @param messages buffer collecting human-readable progress messages.
   * @return a [[BenchmarkResult]]; on any exception, a result carrying a
   *         [[Failure]] instead of timings.
   */
  protected override def doBenchmark(
      includeBreakdown: Boolean,
      description: String = "",
      messages: ArrayBuffer[String]): BenchmarkResult = {
    try {
      val dataFrame = buildDataFrame
      val queryExecution = dataFrame.queryExecution
      // We are not counting the time of ScalaReflection.convertRowToScala.
      val parsingTime = measureTimeMs {
        queryExecution.logical
      }
      val analysisTime = measureTimeMs {
        queryExecution.analyzed
      }
      val optimizationTime = measureTimeMs {
        queryExecution.optimizedPlan
      }
      val planningTime = measureTimeMs {
        queryExecution.executedPlan
      }

      val breakdownResults = if (includeBreakdown) {
        val depth = queryExecution.executedPlan.collect { case p: SparkPlan => p }.size
        val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i)))
        val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap
        val timeMap = new mutable.HashMap[Int, Double]

        // Walk the plan bottom-up (reverse order) so every child is timed
        // before its parent, letting us report time exclusive of children.
        physicalOperators.reverse.map {
          case (index, node) =>
            messages += s"Breakdown: ${node.simpleString}"
            val newNode = buildDataFrame.queryExecution.executedPlan(index)
            val executionTime = measureTimeMs {
              // Force full evaluation of this subtree; rows are discarded.
              // (`()` rather than the `Unit` companion object.)
              newNode.execute().foreach((row: Any) => ())
            }
            timeMap += ((index, executionTime))

            val childIndexes = node.children.map(indexMap)
            val childTime = childIndexes.map(timeMap).sum

            messages += s"Breakdown time: $executionTime (+${executionTime - childTime})"

            BreakdownResult(
              node.nodeName,
              node.simpleString.replaceAll("#\\d+", ""),
              index,
              childIndexes,
              executionTime,
              executionTime - childTime)
        }
      } else {
        Seq.empty[BreakdownResult]
      }

      // The executionTime for the entire query includes the time of type
      // conversion from catalyst to scala.
      var result: Option[Long] = None
      val executionTime = measureTimeMs {
        executionMode match {
          case ExecutionMode.CollectResults => dataFrame.rdd.collect()
          case ExecutionMode.ForeachResults => dataFrame.rdd.foreach { row => () }
          case ExecutionMode.WriteParquet(location) =>
            dataFrame.write.parquet(s"$location/$name.parquet")
          case ExecutionMode.HashResults =>
            val columnStr = dataFrame.schema.map(_.name).mkString(",")
            // SELECT SUM(HASH(col1, col2, ...)) FROM (benchmark query)
            val row =
              dataFrame
                .selectExpr(s"hash($columnStr) as hashValue")
                .groupBy()
                .sum("hashValue")
                .head()
            result = if (row.isNullAt(0)) None else Some(row.getLong(0))
        }
      }

      val joinTypes = dataFrame.queryExecution.executedPlan.collect {
        case k if k.nodeName contains "Join" => k.nodeName
      }

      BenchmarkResult(
        name = name,
        mode = executionMode.toString,
        joinTypes = joinTypes,
        tables = tablesInvolved,
        parsingTime = parsingTime,
        analysisTime = analysisTime,
        optimizationTime = optimizationTime,
        planningTime = planningTime,
        executionTime = executionTime,
        result = result,
        queryExecution = dataFrame.queryExecution.toString,
        breakDown = breakdownResults)
    } catch {
      case e: Exception =>
        BenchmarkResult(
          name = name,
          mode = executionMode.toString,
          failure = Failure(e.getClass.getName, e.getMessage))
    }
  }

  /**
   * Returns a copy of this Query whose ExecutionMode is HashResults,
   * which is used to check the query result. This Query is not modified.
   */
  def checkResult: Query = {
    new Query(name, buildDataFrame, description, sqlText, ExecutionMode.HashResults)
  }
}
|
||||
@ -16,7 +16,7 @@
|
||||
|
||||
package com.databricks.spark.sql.perf.tpcds
|
||||
|
||||
import com.databricks.spark.sql.perf.{ExecutionMode, Benchmark}
|
||||
import com.databricks.spark.sql.perf.{Benchmark, ExecutionMode, Query}
|
||||
|
||||
/**
|
||||
* This implements the official TPCDS v1.4 queries with only cosmetic modifications
|
||||
|
||||
Loading…
Reference in New Issue
Block a user