set the solver

This commit is contained in:
Timothy Hunter 2016-07-05 13:46:19 -07:00
parent def20479a1
commit ce7e20ae6d
2 changed files with 5 additions and 4 deletions

View File

@ -1,8 +1,8 @@
package com.databricks.spark.sql.perf.mllib.regression
import org.apache.spark.ml
import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{LinearRegression, GeneralizedLinearRegression}
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
@ -34,7 +34,8 @@ object LinearRegression extends BenchmarkAlgorithm with TestFromTraining with
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
import ctx.params._
new LinearRegression()
new ml.regression.LinearRegression()
.setSolver("l-bfgs")
.setRegParam(regParam)
.setMaxIter(maxIter)
.setTol(tol)

View File

@ -63,9 +63,9 @@ benchmarks:
numClasses: 4
numFeatures: 5
maxIter: 3
- name: classification.LogisticRegression
- name: regression.LinearRegression
params:
numFeatures: 100
regParam: 0.1
tol: [0.2, 0.1]
tol: [0.0]
maxIter: 10