set the solver
This commit is contained in:
parent
def20479a1
commit
ce7e20ae6d
@ -1,8 +1,8 @@
|
||||
package com.databricks.spark.sql.perf.mllib.regression
|
||||
|
||||
import org.apache.spark.ml
|
||||
import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator}
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.ml.regression.{LinearRegression, GeneralizedLinearRegression}
|
||||
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
@ -34,7 +34,8 @@ object LinearRegression extends BenchmarkAlgorithm with TestFromTraining with
|
||||
|
||||
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
|
||||
import ctx.params._
|
||||
new LinearRegression()
|
||||
new ml.regression.LinearRegression()
|
||||
.setSolver("l-bfgs")
|
||||
.setRegParam(regParam)
|
||||
.setMaxIter(maxIter)
|
||||
.setTol(tol)
|
||||
|
||||
@ -63,9 +63,9 @@ benchmarks:
|
||||
numClasses: 4
|
||||
numFeatures: 5
|
||||
maxIter: 3
|
||||
- name: classification.LogisticRegression
|
||||
- name: regression.LinearRegression
|
||||
params:
|
||||
numFeatures: 100
|
||||
regParam: 0.1
|
||||
tol: [0.2, 0.1]
|
||||
tol: [0.0]
|
||||
maxIter: 10
|
||||
|
||||
Loading…
Reference in New Issue
Block a user