From 5c1990e4ffc78e11c6082a68b2099e247d0a120a Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Mon, 27 Jun 2016 13:32:38 -0700 Subject: [PATCH] no normalization --- .../databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala | 4 ++++ .../spark/sql/perf/mllib/MLTransformerBenchmarkable.scala | 5 +++-- .../sql/perf/mllib/classification/LogisticRegression.scala | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala index 3839ab5..ad714e5 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala @@ -30,6 +30,10 @@ trait BenchmarkAlgorithm extends Logging { ctx: MLBenchContext, trainingSet: DataFrame): Transformer + /** + * The unnormalized score of the training procedure on a dataset. The normalization is + * performed by the caller. + */ @throws[Exception]("if scoring fails") def score( ctx: MLBenchContext, diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala index 951cd26..8794c07 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala @@ -16,6 +16,7 @@ class MLTransformerBenchmarkable( private var testData: DataFrame = null private var trainingData: DataFrame = null + private var testDataCount: Option[Long] = None private val param = MLBenchContext(params, sqlContext) override val name = test.name @@ -27,7 +28,7 @@ class MLTransformerBenchmarkable( try { testData = test.testDataSet(param) testData.cache() - testData.count() + testDataCount = Some(testData.count()) trainingData = test.trainingDataSet(param) trainingData.cache() trainingData.count() @@ -57,7 +58,7 @@ class MLTransformerBenchmarkable( trainingTime = Some(trainingTime.toMillis), trainingMetric = Some(scoreTraining), testTime = Some(scoreTestTime.toMillis), - testMetric = Some(scoreTest)) + testMetric = Some(scoreTest / testDataCount.get)) BenchmarkResult( name = name, diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala index 5e81496..284905d 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala @@ -38,6 +38,7 @@ object LogisticRegression extends BenchmarkAlgorithm import ctx.params._ val lr = new ml.classification.LogisticRegression() .setTol(tol) + .setMaxIter(maxIter) .setRegParam(regParam) lr.fit(trainingSet) }