From ce7e20ae6d5a85856f1b82c55581bbf586aca9ae Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Tue, 5 Jul 2016 13:46:19 -0700 Subject: [PATCH] set the solver --- .../spark/sql/perf/mllib/regression/LinearRegression.scala | 5 +++-- src/main/scala/configs/mllib-small.yaml | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/LinearRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/LinearRegression.scala index 4f81f4c..8acbb51 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/LinearRegression.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/LinearRegression.scala @@ -1,8 +1,8 @@ package com.databricks.spark.sql.perf.mllib.regression +import org.apache.spark.ml import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator} import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.regression.{LinearRegression, GeneralizedLinearRegression} import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} import com.databricks.spark.sql.perf.mllib.OptionImplicits._ @@ -34,7 +34,8 @@ object LinearRegression extends BenchmarkAlgorithm with TestFromTraining with override def getEstimator(ctx: MLBenchContext): Estimator[_] = { import ctx.params._ - new LinearRegression() + new ml.regression.LinearRegression() + .setSolver("l-bfgs") .setRegParam(regParam) .setMaxIter(maxIter) .setTol(tol) diff --git a/src/main/scala/configs/mllib-small.yaml b/src/main/scala/configs/mllib-small.yaml index 8851bec..f9896f0 100644 --- a/src/main/scala/configs/mllib-small.yaml +++ b/src/main/scala/configs/mllib-small.yaml @@ -63,9 +63,9 @@ benchmarks: numClasses: 4 numFeatures: 5 maxIter: 3 - - name: classification.LogisticRegression + - name: regression.LinearRegression params: numFeatures: 100 regParam: 0.1 - tol: [0.2, 0.1] + tol: [0.0] maxIter: 10