Extract Query class from Benchmark into its own top-level class and make SparkContext field transient

This patch extracts `Query` into its own top-level class and makes its `sparkContext` field transient in order to fix `NotSerializableException`s.

Author: Josh Rosen <rosenville@gmail.com>

Closes #53 from JoshRosen/make-query-into-top-level-class.
This commit is contained in:
Josh Rosen 2016-02-22 18:23:06 -08:00
parent 7e38b77c50
commit 42a415e8d4
4 changed files with 181 additions and 154 deletions

View File

@@ -16,20 +16,17 @@
package com.databricks.spark.sql.perf
import java.util.UUID
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent._
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.language.implicitConversions
import scala.util.Try
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, AnalysisException, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedRelation}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.SparkContext
import com.databricks.spark.sql.perf.cpu._
@@ -106,6 +103,7 @@ abstract class Benchmark(
/**
* Starts an experiment run with a given set of executions to run.
*
* @param executionsToRun a list of executions to run.
* @param includeBreakdown If it is true, breakdown results of an execution will be recorded.
* Setting it to true may significantly increase the time used to
@@ -505,147 +503,4 @@ abstract class Benchmark(
}
}
/** Holds one benchmark query and its metadata. */
class Query(
override val name: String,
buildDataFrame: => DataFrame,
val description: String = "",
val sqlText: Option[String] = None,
override val executionMode: ExecutionMode = ExecutionMode.ForeachResults)
extends Benchmarkable with Serializable {
/**
 * Renders the query name together with its analyzed logical plan.
 * If analysis fails (e.g. referenced tables are not registered),
 * falls back to reporting the failure plus the human-readable description.
 */
override def toString: String = {
try {
s"""
|== Query: $name ==
|${buildDataFrame.queryExecution.analyzed}
""".stripMargin
} catch {
case e: Exception =>
s"""
|== Query: $name ==
| Can't be analyzed: $e
|
| $description
""".stripMargin
}
}
// Table names referenced by the unresolved logical plan of this query.
lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect {
case UnresolvedRelation(tableIdentifier, _) => {
// We are ignoring the database name.
tableIdentifier.table
}
}
// Re-evaluates the by-name constructor argument, producing a fresh DataFrame.
def newDataFrame() = buildDataFrame
/**
 * Runs this query once, timing each phase (parsing, analysis, optimization,
 * planning, execution) and packaging the timings into a BenchmarkResult.
 *
 * @param includeBreakdown when true, every physical operator in the executed
 *                         plan is additionally timed individually.
 * @param messages sink for human-readable progress/breakdown messages.
 */
protected override def doBenchmark(
includeBreakdown: Boolean,
description: String = "",
messages: ArrayBuffer[String]): BenchmarkResult = {
try {
val dataFrame = buildDataFrame
val queryExecution = dataFrame.queryExecution
// We are not counting the time of ScalaReflection.convertRowToScala.
val parsingTime = measureTimeMs {
queryExecution.logical
}
val analysisTime = measureTimeMs {
queryExecution.analyzed
}
val optimizationTime = measureTimeMs {
queryExecution.optimizedPlan
}
val planningTime = measureTimeMs {
queryExecution.executedPlan
}
val breakdownResults = if (includeBreakdown) {
val depth = queryExecution.executedPlan.collect { case p: SparkPlan => p }.size
val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i)))
val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap
val timeMap = new mutable.HashMap[Int, Double]
// Walk operators bottom-up so children are already timed when a parent
// is measured; a node's own cost is its total minus its children's.
physicalOperators.reverse.map {
case (index, node) =>
messages += s"Breakdown: ${node.simpleString}"
val newNode = buildDataFrame.queryExecution.executedPlan(index)
val executionTime = measureTimeMs {
newNode.execute().foreach((row: Any) => Unit)
}
timeMap += ((index, executionTime))
val childIndexes = node.children.map(indexMap)
val childTime = childIndexes.map(timeMap).sum
messages += s"Breakdown time: $executionTime (+${executionTime - childTime})"
BreakdownResult(
node.nodeName,
// Strip expression ids (#123) so plan strings compare across runs.
node.simpleString.replaceAll("#\\d+", ""),
index,
childIndexes,
executionTime,
executionTime - childTime)
}
} else {
Seq.empty[BreakdownResult]
}
// The executionTime for the entire query includes the time of type conversion
// from catalyst to scala.
var result: Option[Long] = None
val executionTime = measureTimeMs {
executionMode match {
case ExecutionMode.CollectResults => dataFrame.rdd.collect()
case ExecutionMode.ForeachResults => dataFrame.rdd.foreach { row => Unit }
case ExecutionMode.WriteParquet(location) =>
dataFrame.write.parquet(s"$location/$name.parquet")
case ExecutionMode.HashResults =>
val columnStr = dataFrame.schema.map(_.name).mkString(",")
// SELECT SUM(HASH(col1, col2, ...)) FROM (benchmark query)
val row =
dataFrame
.selectExpr(s"hash($columnStr) as hashValue")
.groupBy()
.sum("hashValue")
.head()
result = if (row.isNullAt(0)) None else Some(row.getLong(0))
}
}
val joinTypes = dataFrame.queryExecution.executedPlan.collect {
case k if k.nodeName contains "Join" => k.nodeName
}
BenchmarkResult(
name = name,
mode = executionMode.toString,
joinTypes = joinTypes,
tables = tablesInvolved,
parsingTime = parsingTime,
analysisTime = analysisTime,
optimizationTime = optimizationTime,
planningTime = planningTime,
executionTime = executionTime,
result = result,
queryExecution = dataFrame.queryExecution.toString,
breakDown = breakdownResults)
} catch {
case e: Exception =>
// Failures are recorded in the result rather than propagated.
BenchmarkResult(
name = name,
mode = executionMode.toString,
failure = Failure(e.getClass.getName, e.getMessage))
}
}
/** Change the ExecutionMode of this Query to HashResults, which is used to check the query result. */
def checkResult: Query = {
new Query(name, buildDataFrame, description, sqlText, ExecutionMode.HashResults)
}
}
}

View File

@@ -25,8 +25,8 @@ import scala.collection.mutable.ArrayBuffer
/** A trait to describe things that can be benchmarked. */
trait Benchmarkable {
val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
val sparkContext = sqlContext.sparkContext
@transient protected[this] val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
@transient protected[this] val sparkContext = sqlContext.sparkContext
val name: String
protected val executionMode: ExecutionMode

View File

@@ -0,0 +1,172 @@
/*
* Copyright 2015 Databricks Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.databricks.spark.sql.perf
import scala.language.implicitConversions
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.SparkPlan
/** Holds one benchmark query and its metadata. */
class Query(
    override val name: String,
    buildDataFrame: => DataFrame,
    val description: String = "",
    val sqlText: Option[String] = None,
    override val executionMode: ExecutionMode = ExecutionMode.ForeachResults)
  extends Benchmarkable with Serializable {

  // Lets callers pass a bare value where an Option is expected.
  private implicit def toOption[A](a: A): Option[A] = Option(a)

  /**
   * Renders the query name together with its analyzed logical plan.
   * If analysis fails (e.g. referenced tables are not registered),
   * falls back to reporting the failure plus the human-readable description.
   */
  override def toString: String = {
    try {
      s"""
         |== Query: $name ==
         |${buildDataFrame.queryExecution.analyzed}
      """.stripMargin
    } catch {
      case e: Exception =>
        s"""
           |== Query: $name ==
           | Can't be analyzed: $e
           |
           | $description
        """.stripMargin
    }
  }

  // Table names referenced by the unresolved logical plan of this query.
  lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect {
    case UnresolvedRelation(tableIdentifier, _) =>
      // We are ignoring the database name.
      tableIdentifier.table
  }

  // Re-evaluates the by-name constructor argument, producing a fresh DataFrame.
  def newDataFrame() = buildDataFrame

  /**
   * Runs this query once, timing each phase (parsing, analysis, optimization,
   * planning, execution) and packaging the timings into a [[BenchmarkResult]].
   *
   * @param includeBreakdown when true, every physical operator in the executed
   *                         plan is additionally timed individually.
   * @param messages sink for human-readable progress/breakdown messages.
   */
  protected override def doBenchmark(
      includeBreakdown: Boolean,
      description: String = "",
      messages: ArrayBuffer[String]): BenchmarkResult = {
    try {
      val dataFrame = buildDataFrame
      val queryExecution = dataFrame.queryExecution
      // We are not counting the time of ScalaReflection.convertRowToScala.
      val parsingTime = measureTimeMs {
        queryExecution.logical
      }
      val analysisTime = measureTimeMs {
        queryExecution.analyzed
      }
      val optimizationTime = measureTimeMs {
        queryExecution.optimizedPlan
      }
      val planningTime = measureTimeMs {
        queryExecution.executedPlan
      }

      val breakdownResults = if (includeBreakdown) {
        val depth = queryExecution.executedPlan.collect { case p: SparkPlan => p }.size
        val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan(i)))
        val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap
        val timeMap = new mutable.HashMap[Int, Double]

        // Walk operators bottom-up so children are already timed when a parent
        // is measured; a node's own cost is its total minus its children's.
        physicalOperators.reverse.map {
          case (index, node) =>
            messages += s"Breakdown: ${node.simpleString}"
            val newNode = buildDataFrame.queryExecution.executedPlan(index)
            val executionTime = measureTimeMs {
              // Force full evaluation; `(_: Any) => ()` (the unit value) is the
              // correct no-op — the original `(row: Any) => Unit` returned the
              // Unit companion object, a pure-expression lint warning.
              newNode.execute().foreach((_: Any) => ())
            }
            timeMap += ((index, executionTime))
            val childIndexes = node.children.map(indexMap)
            val childTime = childIndexes.map(timeMap).sum
            messages += s"Breakdown time: $executionTime (+${executionTime - childTime})"
            BreakdownResult(
              node.nodeName,
              // Strip expression ids (#123) so plan strings compare across runs.
              node.simpleString.replaceAll("#\\d+", ""),
              index,
              childIndexes,
              executionTime,
              executionTime - childTime)
        }
      } else {
        Seq.empty[BreakdownResult]
      }

      // The executionTime for the entire query includes the time of type
      // conversion from catalyst to scala.
      var result: Option[Long] = None
      val executionTime = measureTimeMs {
        executionMode match {
          case ExecutionMode.CollectResults => dataFrame.rdd.collect()
          case ExecutionMode.ForeachResults => dataFrame.rdd.foreach { _ => () }
          case ExecutionMode.WriteParquet(location) =>
            dataFrame.write.parquet(s"$location/$name.parquet")
          case ExecutionMode.HashResults =>
            val columnStr = dataFrame.schema.map(_.name).mkString(",")
            // SELECT SUM(HASH(col1, col2, ...)) FROM (benchmark query)
            val row =
              dataFrame
                .selectExpr(s"hash($columnStr) as hashValue")
                .groupBy()
                .sum("hashValue")
                .head()
            result = if (row.isNullAt(0)) None else Some(row.getLong(0))
        }
      }

      val joinTypes = dataFrame.queryExecution.executedPlan.collect {
        case k if k.nodeName contains "Join" => k.nodeName
      }

      BenchmarkResult(
        name = name,
        mode = executionMode.toString,
        joinTypes = joinTypes,
        tables = tablesInvolved,
        parsingTime = parsingTime,
        analysisTime = analysisTime,
        optimizationTime = optimizationTime,
        planningTime = planningTime,
        executionTime = executionTime,
        result = result,
        queryExecution = dataFrame.queryExecution.toString,
        breakDown = breakdownResults)
    } catch {
      case e: Exception =>
        // Failures are recorded in the result rather than propagated.
        BenchmarkResult(
          name = name,
          mode = executionMode.toString,
          failure = Failure(e.getClass.getName, e.getMessage))
    }
  }

  /** Change the ExecutionMode of this Query to HashResults, which is used to check the query result. */
  def checkResult: Query = {
    new Query(name, buildDataFrame, description, sqlText, ExecutionMode.HashResults)
  }
}

View File

@@ -16,7 +16,7 @@
package com.databricks.spark.sql.perf.tpcds
import com.databricks.spark.sql.perf.{ExecutionMode, Benchmark}
import com.databricks.spark.sql.perf.{Benchmark, ExecutionMode, Query}
/**
* This implements the official TPCDS v1.4 queries with only cosmetic modifications