From f8aa93d968887b5d89c254c93c9fe5963d60a221 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 8 Dec 2015 16:04:42 -0800 Subject: [PATCH] Initial set of tests for Datasets Author: Michael Armbrust Closes #42 from marmbrus/dataset-tests. --- build.sbt | 2 + .../databricks/spark/sql/perf/Benchmark.scala | 12 +- .../spark/sql/perf/DatasetPerformance.scala | 122 ++++++++++++++++++ src/test/scala/DatasetPerformanceSuite.scala | 17 +++ 4 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 src/main/scala/com/databricks/spark/sql/perf/DatasetPerformance.scala create mode 100644 src/test/scala/DatasetPerformanceSuite.scala diff --git a/build.sbt b/build.sbt index bd58136..0a04482 100644 --- a/build.sbt +++ b/build.sbt @@ -25,6 +25,8 @@ libraryDependencies += "com.twitter" %% "util-jvm" % "6.23.0" % "provided" libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.1" % "test" +fork := true + // Your username to login to Databricks Cloud dbcUsername := sys.env.getOrElse("DBC_USERNAME", sys.error("Please set DBC_USERNAME")) diff --git a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala index b0b9d55..67008a1 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala @@ -16,13 +16,15 @@ package com.databricks.spark.sql.perf +import org.apache.spark.rdd.RDD + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql.{Dataset, AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.{SparkContext, SparkEnv} @@ -484,6 +486,14 @@ abstract class Benchmark( } } + object RDDCount { + def apply( + name: String, + rdd: RDD[_]) = { + new SparkPerfExecution(name, Map.empty, () => Unit, () => rdd.count()) + } + } + /** A class for benchmarking Spark perf results. */ class SparkPerfExecution( override val name: String, diff --git a/src/main/scala/com/databricks/spark/sql/perf/DatasetPerformance.scala b/src/main/scala/com/databricks/spark/sql/perf/DatasetPerformance.scala new file mode 100644 index 0000000..d58ecfc --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/DatasetPerformance.scala @@ -0,0 +1,122 @@ +package com.databricks.spark.sql.perf + +import org.apache.spark.sql.expressions.Aggregator + +case class Data(id: Long) + +case class SumAndCount(var sum: Long, var count: Int) + +trait DatasetPerformance extends Benchmark { + + import sqlContext.implicits._ + + val numLongs = 100000000 + val ds = sqlContext.range(1, numLongs) + val rdd = sparkContext.range(1, numLongs) + + val smallNumLongs = 1000000 + val smallds = sqlContext.range(1, smallNumLongs) + val smallrdd = sparkContext.range(1, smallNumLongs) + + def allBenchmarks = range ++ backToBackFilters ++ backToBackMaps ++ computeAverage + + val range = Seq( + new Query( + "DS: range", + ds.as[Data].toDF(), + executionMode = ExecutionMode.ForeachResults), + new Query( + "DF: range", + ds.toDF(), + executionMode = ExecutionMode.ForeachResults), + RDDCount( + "RDD: range", + rdd.map(Data)) + ) + + val backToBackFilters = Seq( + new Query( + "DS: back-to-back filters", + ds.as[Data] + .filter(_.id % 100 != 0) + .filter(_.id % 101 != 0) + .filter(_.id % 102 != 0) + .filter(_.id % 103 != 0).toDF()), + new Query( + "DF: back-to-back filters", + ds.toDF() + .filter("id % 100 != 0") + .filter("id % 101 != 0") + .filter("id % 102 != 0") + .filter("id % 103 != 0")), + RDDCount( + "RDD: back-to-back filters", + rdd.map(Data) + .filter(_.id % 100 != 0) + .filter(_.id % 101 != 0) + .filter(_.id % 102 != 0) + .filter(_.id % 103 != 0)) + ) + + val backToBackMaps = Seq( + new Query( + "DS: back-to-back maps", + ds.as[Data] + .map(d => Data(d.id + 1L)) + .map(d => Data(d.id + 1L)) + .map(d => Data(d.id + 1L)) + .map(d => Data(d.id + 1L)).toDF()), + new Query( + "DF: back-to-back maps", + ds.toDF() + .select($"id" + 1 as 'id) + .select($"id" + 1 as 'id) + .select($"id" + 1 as 'id) + .select($"id" + 1 as 'id)), + RDDCount( + "RDD: back-to-back maps", + rdd.map(Data) + .map(d => Data(d.id + 1L)) + .map(d => Data(d.id + 1L)) + .map(d => Data(d.id + 1L)) + .map(d => Data(d.id + 1L))) + ) + + val average = new Aggregator[Long, SumAndCount, Double] { + override def zero: SumAndCount = SumAndCount(0, 0) + + override def reduce(b: SumAndCount, a: Long): SumAndCount = { + b.count += 1 + b.sum += a + b + } + + override def finish(reduction: SumAndCount): Double = reduction.sum.toDouble / reduction.count + + override def merge(b1: SumAndCount, b2: SumAndCount): SumAndCount = { + b1.count += b2.count + b1.sum += b2.sum + b1 + } + }.toColumn + + val computeAverage = Seq( + new Query( + "DS: average", + smallds.as[Long].select(average).toDF(), + executionMode = ExecutionMode.CollectResults), + new Query( + "DF: average", + smallds.toDF().selectExpr("avg(id)"), + executionMode = ExecutionMode.CollectResults), + new SparkPerfExecution( + "RDD: average", + Map.empty, + prepare = () => Unit, + run = () => { + val sumAndCount = + smallrdd.map(i => (i, 1)).reduce((a, b) => (a._1 + b._1, a._2 + b._2)) + sumAndCount._1.toDouble / sumAndCount._2 + }) + ) +} diff --git a/src/test/scala/DatasetPerformanceSuite.scala b/src/test/scala/DatasetPerformanceSuite.scala new file mode 100644 index 0000000..b5179c2 --- /dev/null +++ b/src/test/scala/DatasetPerformanceSuite.scala @@ -0,0 +1,17 @@ +package com.databricks.spark.sql.perf + +import com.databricks.spark.sql.perf.{Benchmark, DatasetPerformance} +import org.apache.spark.sql.hive.test.TestHive +import org.scalatest.FunSuite + +class DatasetPerformanceSuite extends FunSuite { + test("run benchmark") { + val benchmark = new Benchmark(TestHive) with DatasetPerformance { + override val numLongs = 100 + } + import benchmark._ + + val exp = runExperiment(allBenchmarks) + exp.waitForFinish(10000) + } +}