diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala index 32605bd..e9ff623 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala @@ -91,10 +91,11 @@ class MLPipelineStageBenchmarkable( val additionalTests = test.testAdditionalMethods(param, model).map { tuple => val (additionalMethodTime, _) = measureTime { tuple._2() } - tuple._1 -> additionalMethodTime.toMillis.toDouble - } + MLMetric(tuple._1, additionalMethodTime.toMillis, false) + }.toArray - val mlMetrics = Array(metricTrainingTime, metricTraining, metricTestTime, metricTest) + val mlMetrics = Array(metricTrainingTime, metricTraining, metricTestTime, metricTest) ++ + additionalTests val paramsMap = params.toMap val benchmarkId = name.split('.').last + "_" + paramsMap.hashCode.abs