From 30c50dddbb44b1e2f63189a0393898a4175e780a Mon Sep 17 00:00:00 2001 From: Joseph Bradley Date: Sun, 8 Jul 2018 16:09:24 -0700 Subject: [PATCH] [ML-2918] Call count() in default score() to improve timing of transform() (#159) For Models and Transformers which are not tested with Evaluators, I think we are not timing transform() correctly here: spark-sql-perf/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala Line 65 in aa1587f transformer.transform(trainingData) Since transform() is lazy, we need to materialize it during timing. This PR currently just calls count() in the default implementation of score(). * call count() in score() * changed count to UDF --- .../spark/sql/perf/mllib/BenchmarkAlgorithm.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala index f5eb079..986b1a8 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala @@ -35,12 +35,21 @@ trait BenchmarkAlgorithm extends Logging { /** * The unnormalized score of the training procedure on a dataset. The normalization is * performed by the caller. + * This aggregates a no-op UDF over the transformed data to attempt to materialize the result for + * recording timing metrics. */ @throws[Exception]("if scoring fails") def score( ctx: MLBenchContext, testSet: DataFrame, model: Transformer): MLMetric = MLMetric.Invalid + model: Transformer): MLMetric = { + val output = model.transform(testSet) + // We create a useless UDF to make sure the entire DataFrame is instantiated. 
+ val fakeUDF = udf { (_: Any) => 0 } + val columns = testSet.columns + output.select(sum(fakeUDF(struct(columns.map(col) : _*)))).first() + MLMetric.Invalid + } def name: String = { this.getClass.getCanonicalName.replace("$", "")