From 08cb68ca2034cb732e3ad6ace52d04f1f40e148e Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 9 Sep 2015 21:49:50 -0700 Subject: [PATCH] Make it easier to write benchmarks in notebooks Author: Michael Armbrust Closes #19 from marmbrus/notebookTests. --- .../databricks/spark/sql/perf/Benchmark.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala index 2808ece..3ca6736 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala @@ -303,7 +303,10 @@ abstract class Benchmark( import reflect.runtime._ import universe._ + @transient private val runtimeMirror = universe.runtimeMirror(getClass.getClassLoader) + + @transient val myType = runtimeMirror.classSymbol(getClass).toType def singleTables = @@ -320,6 +323,7 @@ abstract class Benchmark( .filter(_.asMethod.returnType =:= typeOf[Seq[Table]]) .flatMap(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Table]]) + @transient lazy val allTables: Seq[Table] = (singleTables ++ groupedTables).toSeq def singleQueries = @@ -336,6 +340,7 @@ abstract class Benchmark( .filter(_.asMethod.returnType =:= typeOf[Seq[Query]]) .flatMap(method => runtimeMirror.reflect(this).reflectMethod(method).apply().asInstanceOf[Seq[Query]]) + @transient lazy val allQueries = (singleQueries ++ groupedQueries).toSeq def html: String = { @@ -368,8 +373,8 @@ abstract class Benchmark( """.stripMargin } - trait ExecutionMode - object ExecutionMode { + trait ExecutionMode extends Serializable + case object ExecutionMode { // Benchmark run by collecting queries results (e.g. rdd.collect()) case object CollectResults extends ExecutionMode { override def toString: String = "collect" @@ -393,7 +398,7 @@ abstract class Benchmark( } /** Factory object for benchmark queries. */ - object Query { + case object Query { def apply( name: String, sqlText: String, @@ -414,9 +419,9 @@ abstract class Benchmark( class Query( val name: String, buildDataFrame: => DataFrame, - val description: String, - val sqlText: Option[String], - val executionMode: ExecutionMode) { + val description: String = "", + val sqlText: Option[String] = None, + val executionMode: ExecutionMode = ExecutionMode.ForeachResults) extends Serializable { override def toString = s"""