[ML-4154] Added testing for before/after of ml benchmarks. (#162)

This PR adds a unit test which runs the beforeBenchmark & afterBenchmark methods for the benchmarks included in mllib-small.yaml.
This commit is contained in:
Bago Amirbekian 2018-07-12 16:43:54 -07:00 committed by Joseph Bradley
parent 107495afe2
commit 8e8c08d75b
3 changed files with 50 additions and 15 deletions

View File

@ -54,6 +54,21 @@ object MLLib extends Logging {
run(yamlFile = configFile)
}
/**
 * Builds the benchmark configuration, preferring the file path when one is given.
 *
 * @param yamlFile   path to a YAML config file; used when non-null
 * @param yamlConfig raw YAML config text; required when `yamlFile` is null
 * @return the parsed [[YamlConfig]]
 */
private[mllib] def getConf(yamlFile: String = null, yamlConfig: String = null): YamlConfig = {
  if (yamlFile != null) {
    YamlConfig.readFile(yamlFile)
  } else {
    // No file supplied — fall back to the inline YAML string, which must be present.
    require(yamlConfig != null)
    YamlConfig.readString(yamlConfig)
  }
}
/**
 * Instantiates one benchmarkable pipeline stage per runnable benchmark in the config.
 *
 * @param conf parsed benchmark configuration
 * @return benchmarkables sharing the module-wide SQL context
 */
private[mllib] def getBenchmarks(conf: YamlConfig): Seq[MLPipelineStageBenchmarkable] = {
  // All benchmarkables share the single SQL context owned by MLBenchmarks.
  val sqlContext = com.databricks.spark.sql.perf.mllib.MLBenchmarks.sqlContext
  conf.runnableBenchmarks.map { description =>
    new MLPipelineStageBenchmarkable(description.params, description.benchmark, sqlContext)
  }
}
/**
* Runs all the experiments and blocks on completion
*
@ -62,20 +77,12 @@ object MLLib extends Logging {
*/
def run(yamlFile: String = null, yamlConfig: String = null): DataFrame = {
logger.info("Starting run")
val conf: YamlConfig = Option(yamlFile).map(YamlConfig.readFile).getOrElse {
require(yamlConfig != null)
YamlConfig.readString(yamlConfig)
}
val conf = getConf(yamlFile, yamlConfig)
val sparkConf = new SparkConf().setAppName("MLlib QA").setMaster("local[2]")
val sc = SparkContext.getOrCreate(sparkConf)
sc.setLogLevel("INFO")
val b = new com.databricks.spark.sql.perf.mllib.MLLib()
val sqlContext = com.databricks.spark.sql.perf.mllib.MLBenchmarks.sqlContext
val benchmarksDescriptions = conf.runnableBenchmarks
val benchmarks = benchmarksDescriptions.map { mlb =>
new MLPipelineStageBenchmarkable(mlb.params, mlb.benchmark, sqlContext)
}
val benchmarks = getBenchmarks(conf)
println(s"${benchmarks.size} benchmarks identified:")
val str = benchmarks.map(_.prettyPrint).mkString("\n")
println(str)

View File

@ -27,7 +27,7 @@ class MLPipelineStageBenchmarkable(
override protected val executionMode: ExecutionMode = ExecutionMode.SparkPerfResults
override protected def beforeBenchmark(): Unit = {
override protected[mllib] def beforeBenchmark(): Unit = {
logger.info(s"$this beforeBenchmark")
try {
testData = test.testDataSet(param)
@ -43,7 +43,7 @@ class MLPipelineStageBenchmarkable(
}
}
override protected def afterBenchmark(sc: SparkContext): Unit = {
override protected[mllib] def afterBenchmark(sc: SparkContext): Unit = {
// Best-effort clean up of weakly referenced RDDs, shuffles, and broadcasts
// Remove any leftover blocks that still exist
sc.getExecutorStorageStatus

View File

@ -1,10 +1,30 @@
package com.databricks.spark.sql.perf.mllib
import org.scalatest.FunSuite
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.spark.sql.Row
import org.apache.spark.sql.{Row, SparkSession}
class MLLibTest extends FunSuite {
class MLLibSuite extends FunSuite with BeforeAndAfterAll {
private var sparkSession: SparkSession = _
// Starts (or reuses) a local two-core Spark session shared by all tests in this suite.
override def beforeAll(): Unit = {
  super.beforeAll()
  val builder = SparkSession.builder.master("local[2]").appName("MLlib QA")
  sparkSession = builder.getOrCreate()
}
// Tears down the shared Spark session; always delegates to the superclass, even on failure.
override def afterAll(): Unit = {
  try {
    Option(sparkSession).foreach(_.stop())
    // The driver doesn't unbind its port immediately on shutdown; clearing the
    // property avoids a rebind to the same port by a later suite.
    System.clearProperty("spark.driver.port")
    sparkSession = null
  } finally {
    super.afterAll()
  }
}
test("test MlLib benchmarks with mllib-small.yaml.") {
val results = MLLib.run(yamlConfig = MLLib.smallConfig)
@ -20,4 +40,12 @@ class MLLibTest extends FunSuite {
fail("Unable to run all benchmarks successfully, see console output for more info.")
}
}
// Exercises the setup/teardown hooks of every benchmark defined in the small config,
// without running the timed benchmark bodies themselves.
test("test before & after benchmark methods for pipeline benchmarks.") {
  val conf = MLLib.getConf(yamlConfig = MLLib.smallConfig)
  for (benchmark <- MLLib.getBenchmarks(conf)) {
    benchmark.beforeBenchmark()
    benchmark.afterBenchmark(sparkSession.sparkContext)
  }
}
}