From 8e8c08d75bce54d0908d92eb944ec111561c93f1 Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 12 Jul 2018 16:43:54 -0700 Subject: [PATCH] [ML-4154] Added testing for before/after of ml benchmarks. (#162) This PR adds a unit test which runs the beforeBenchmark & afterBenchmark methods for the benchmarks included in mllib-small.yaml. --- .../spark/sql/perf/mllib/MLLib.scala | 27 +++++++++------ .../mllib/MLPipelineStageBenchmarkable.scala | 4 +-- .../spark/sql/perf/mllib/MLLibSuite.scala | 34 +++++++++++++++++-- 3 files changed, 50 insertions(+), 15 deletions(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLLib.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLLib.scala index cfb7709..d325327 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLLib.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLLib.scala @@ -54,6 +54,21 @@ object MLLib extends Logging { run(yamlFile = configFile) } + private[mllib] def getConf(yamlFile: String = null, yamlConfig: String = null): YamlConfig = { + Option(yamlFile).map(YamlConfig.readFile).getOrElse { + require(yamlConfig != null) + YamlConfig.readString(yamlConfig) + } + } + + private[mllib] def getBenchmarks(conf: YamlConfig): Seq[MLPipelineStageBenchmarkable] = { + val sqlContext = com.databricks.spark.sql.perf.mllib.MLBenchmarks.sqlContext + val benchmarksDescriptions = conf.runnableBenchmarks + benchmarksDescriptions.map { mlb => + new MLPipelineStageBenchmarkable(mlb.params, mlb.benchmark, sqlContext) + } + } + /** * Runs all the experiments and blocks on completion * @@ -62,20 +77,12 @@ object MLLib extends Logging { */ def run(yamlFile: String = null, yamlConfig: String = null): DataFrame = { logger.info("Starting run") - val conf: YamlConfig = Option(yamlFile).map(YamlConfig.readFile).getOrElse { - require(yamlConfig != null) - YamlConfig.readString(yamlConfig) - } - + val conf = getConf(yamlFile, yamlConfig) val sparkConf = new 
SparkConf().setAppName("MLlib QA").setMaster("local[2]") val sc = SparkContext.getOrCreate(sparkConf) sc.setLogLevel("INFO") val b = new com.databricks.spark.sql.perf.mllib.MLLib() - val sqlContext = com.databricks.spark.sql.perf.mllib.MLBenchmarks.sqlContext - val benchmarksDescriptions = conf.runnableBenchmarks - val benchmarks = benchmarksDescriptions.map { mlb => - new MLPipelineStageBenchmarkable(mlb.params, mlb.benchmark, sqlContext) - } + val benchmarks = getBenchmarks(conf) println(s"${benchmarks.size} benchmarks identified:") val str = benchmarks.map(_.prettyPrint).mkString("\n") println(str) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala index e9ff623..2807aaf 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala @@ -27,7 +27,7 @@ class MLPipelineStageBenchmarkable( override protected val executionMode: ExecutionMode = ExecutionMode.SparkPerfResults - override protected def beforeBenchmark(): Unit = { + override protected[mllib] def beforeBenchmark(): Unit = { logger.info(s"$this beforeBenchmark") try { testData = test.testDataSet(param) @@ -43,7 +43,7 @@ class MLPipelineStageBenchmarkable( } } - override protected def afterBenchmark(sc: SparkContext): Unit = { + override protected[mllib] def afterBenchmark(sc: SparkContext): Unit = { // Best-effort clean up of weakly referenced RDDs, shuffles, and broadcasts // Remove any leftover blocks that still exist sc.getExecutorStorageStatus diff --git a/src/test/scala/com/databricks/spark/sql/perf/mllib/MLLibSuite.scala b/src/test/scala/com/databricks/spark/sql/perf/mllib/MLLibSuite.scala index f335bb9..c622155 100644 --- a/src/test/scala/com/databricks/spark/sql/perf/mllib/MLLibSuite.scala +++ 
b/src/test/scala/com/databricks/spark/sql/perf/mllib/MLLibSuite.scala @@ -1,10 +1,30 @@ package com.databricks.spark.sql.perf.mllib -import org.scalatest.FunSuite +import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Row, SparkSession} -class MLLibTest extends FunSuite { +class MLLibSuite extends FunSuite with BeforeAndAfterAll { + + private var sparkSession: SparkSession = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sparkSession = SparkSession.builder.master("local[2]").appName("MLlib QA").getOrCreate() + } + + override def afterAll(): Unit = { + try { + if (sparkSession != null) { + sparkSession.stop() + } + // To avoid RPC rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + sparkSession = null + } finally { + super.afterAll() + } + } test("test MlLib benchmarks with mllib-small.yaml.") { val results = MLLib.run(yamlConfig = MLLib.smallConfig) @@ -20,4 +40,12 @@ class MLLibTest extends FunSuite { fail("Unable to run all benchmarks successfully, see console output for more info.") } } + + test("test before & after benchmark methods for pipeline benchmarks.") { + val benchmarks = MLLib.getBenchmarks(MLLib.getConf(yamlConfig = MLLib.smallConfig)) + benchmarks.foreach { b => + b.beforeBenchmark() + b.afterBenchmark(sparkSession.sparkContext) + } + } }