diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala
index 95b7478..f5eb079 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala
@@ -7,6 +7,8 @@ import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 
+import com.databricks.spark.sql.perf._
+
 /**
  * The description of a benchmark for an ML algorithm. It follows a simple, standard procedure:
  *  - generate some test and training data
@@ -38,7 +40,7 @@ trait BenchmarkAlgorithm extends Logging {
   def score(
       ctx: MLBenchContext,
       testSet: DataFrame,
-      model: Transformer): Double = -1.0 // Not putting NaN because it is not valid JSON.
+      model: Transformer): MLMetric = MLMetric.Invalid
 
   def name: String = {
     this.getClass.getCanonicalName.replace("$", "")
@@ -67,9 +69,17 @@ trait ScoringWithEvaluator {
   final override def score(
       ctx: MLBenchContext,
       testSet: DataFrame,
-      model: Transformer): Double = {
-    val eval = model.transform(testSet)
-    evaluator(ctx).evaluate(eval)
+      model: Transformer): MLMetric = {
+    val results = model.transform(testSet)
+    val eval = evaluator(ctx)
+    val metricName = if (eval.hasParam("metricName")) {
+      val param = eval.getParam("metricName")
+      eval.getOrDefault(param).toString
+    } else {
+      eval.getClass.getSimpleName
+    }
+    val metricValue = eval.evaluate(results)
+    MLMetric(metricName, metricValue, eval.isLargerBetter)
   }
 }
 
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala
index d598186..25ddd42 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala
@@ -72,9 +72,17 @@ class MLPipelineStageBenchmarkable(
     val (scoreTrainTime, scoreTraining) = measureTime {
       test.score(param, trainingData, model)
     }
+    val metricTrainingTime = MLMetric("training.time", scoreTrainTime.toMillis, false)
+    val metricTraining = MLMetric("training." + scoreTraining.metricName,
+      scoreTraining.metricValue,
+      scoreTraining.isLargerBetter)
     val (scoreTestTime, scoreTest) = measureTime {
       test.score(param, testData, model)
     }
+    val metricTestTime = MLMetric("test.time", scoreTestTime.toMillis, false)
+    val metricTest = MLMetric("test." + scoreTest.metricName,
+      scoreTest.metricValue,
+      scoreTest.isLargerBetter)
 
     logger.info(s"$this doBenchmark: Trained model in ${trainingTime.toMillis / 1000.0}" +
       s" s, Scored training dataset in ${scoreTrainTime.toMillis / 1000.0} s," +
@@ -86,19 +94,14 @@ class MLPipelineStageBenchmarkable(
         tuple._1 -> additionalMethodTime.toMillis.toDouble
       }
 
-      val ml = MLResult(
-        trainingTime = Some(trainingTime.toMillis),
-        trainingMetric = Some(scoreTraining),
-        testTime = Some(scoreTestTime.toMillis),
-        testMetric = Some(scoreTest / testDataCount.get),
-        additionalTests = additionalTests)
+      val mlMetrics = Array(metricTrainingTime, metricTraining, metricTestTime, metricTest)
 
       BenchmarkResult(
         name = name,
         mode = executionMode.toString,
         parameters = params.toMap,
         executionTime = Some(trainingTime.toMillis),
-        mlResult = Some(ml))
+        mlResult = Some(mlMetrics))
     } catch {
       case e: Exception =>
         BenchmarkResult(
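Note: with the two files above, every successful run now emits four metrics that follow a `training.`/`test.` name-prefix convention, with the evaluator's own metric name appended. A hypothetical illustration of the resulting array, assuming the `MLMetric` case class added in `results.scala` below and a regression evaluator whose metric is `rmse` (the values are invented):

```scala
// Hypothetical contents of mlMetrics for one run.
val mlMetrics = Array(
  MLMetric("training.time", 1532.0, false), // millis spent scoring the training set; lower is better
  MLMetric("training.rmse", 0.42, false),   // evaluator metric on the training set
  MLMetric("test.time", 204.0, false),      // millis spent scoring the test set
  MLMetric("test.rmse", 0.47, false))       // evaluator metric on the test set
```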
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala
index 3e1b995..a59d29e 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala
@@ -9,7 +9,6 @@ import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions.{col, split}
 
-import com.databricks.spark.sql.perf.MLResult
 import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
 import com.databricks.spark.sql.perf.mllib.OptionImplicits._
 import com.databricks.spark.sql.perf.mllib.data.DataGenerator
diff --git a/src/main/scala/com/databricks/spark/sql/perf/results.scala b/src/main/scala/com/databricks/spark/sql/perf/results.scala
index 4ac99d7..209a732 100644
--- a/src/main/scala/com/databricks/spark/sql/perf/results.scala
+++ b/src/main/scala/com/databricks/spark/sql/perf/results.scala
@@ -83,7 +83,7 @@ case class BenchmarkResult(
     breakDown: Seq[BreakdownResult] = Nil,
     queryExecution: Option[String] = None,
     failure: Option[Failure] = None,
-    mlResult: Option[MLResult] = None)
+    mlResult: Option[Array[MLMetric]] = None)
 
 /**
  * The execution time of a subtree of the query plan tree of a specific query.
@@ -223,19 +223,17 @@ object MLParams {
 }
 
 /**
- * Result information specific to MLlib.
+ * Metrics specific to MLlib benchmarks.
  *
- * @param trainingTime (MLlib) Training time.
- *                     executionTime is set to the same value to match Spark Core tests.
- * @param trainingMetric (MLlib) Training metric, such as accuracy
- * @param testTime (MLlib) Test time (for prediction on test set, or on training set if there
- *                 is no test set).
- * @param testMetric (MLlib) Test metric, such as accuracy
- * @param additionalTests (MLlib) Additional methods test results.
+ * @param metricName the name of the metric
+ * @param metricValue the value of the metric
+ * @param isLargerBetter whether a larger metric value is better
  */
-case class MLResult(
-    trainingTime: Option[Double] = None,
-    trainingMetric: Option[Double] = None,
-    testTime: Option[Double] = None,
-    testMetric: Option[Double] = None,
-    additionalTests: Map[String, Double])
+case class MLMetric(
+    metricName: String,
+    metricValue: Double,
+    isLargerBetter: Boolean)
+
+object MLMetric {
+  val Invalid = MLMetric("Invalid", 0.0, false)
+}
\ No newline at end of file
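Since each `MLMetric` carries its own `isLargerBetter` flag, consumers of the results can compare runs without hard-coding the direction of every metric. A minimal, self-contained sketch; the `better` helper is hypothetical and not part of this patch:

```scala
// Mirrors the case class added in results.scala above.
case class MLMetric(metricName: String, metricValue: Double, isLargerBetter: Boolean)

// Hypothetical helper: picks the better of two runs of the same metric,
// honoring the direction encoded in isLargerBetter.
def better(a: MLMetric, b: MLMetric): MLMetric = {
  require(a.metricName == b.metricName, s"cannot compare ${a.metricName} with ${b.metricName}")
  val aWins =
    if (a.isLargerBetter) a.metricValue >= b.metricValue
    else a.metricValue <= b.metricValue
  if (aWins) a else b
}

better(MLMetric("test.rmse", 0.42, false), MLMetric("test.rmse", 0.47, false))
// => MLMetric("test.rmse", 0.42, false) -- lower RMSE wins
```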