diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala index f5eb079..986b1a8 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala @@ -35,12 +35,21 @@ trait BenchmarkAlgorithm extends Logging { /** * The unnormalized score of the training procedure on a dataset. The normalization is * performed by the caller. + * This calls `count()` on the transformed data to attempt to materialize the result for + * recording timing metrics. */ @throws[Exception]("if scoring fails") def score( ctx: MLBenchContext, testSet: DataFrame, - model: Transformer): MLMetric = MLMetric.Invalid + model: Transformer): MLMetric = { + val output = model.transform(testSet) + // We create a useless UDF to make sure the entire DataFrame is instantiated. + val fakeUDF = udf { (_: Any) => 0 } + val columns = testSet.columns + output.select(sum(fakeUDF(struct(columns.map(col) : _*)))).first() + MLMetric.Invalid + } def name: String = { this.getClass.getCanonicalName.replace("$", "")