commit
93c0407bbe
@ -0,0 +1,46 @@
|
||||
package com.databricks.spark.sql.perf.mllib.regression
|
||||
|
||||
import org.apache.spark.ml
|
||||
import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator}
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import com.databricks.spark.sql.perf.mllib._
|
||||
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
|
||||
|
||||
/**
 * Benchmark for Spark ML linear regression.
 *
 * Training data is generated from a known ("true") linear model so that the
 * fitted estimator can be scored against ground truth via the evaluator.
 */
object LinearRegression extends BenchmarkAlgorithm with TestFromTraining with
  TrainingSetFromTransformer with ScoringWithEvaluator {

  /** Generates the continuous feature matrix used as training input. */
  override protected def initialData(ctx: MLBenchContext) = {
    import ctx.params._
    DataGenerator.generateContinuousFeatures(
      ctx.sqlContext,
      numExamples,
      ctx.seed(),
      numPartitions,
      numFeatures)
  }

  /**
   * Builds the ground-truth linear model used to label the training set.
   * Coefficients are drawn uniformly from [-1, 1).
   */
  override protected def trueModel(ctx: MLBenchContext): Transformer = {
    val rng = ctx.newGenerator()
    val coefficients =
      Vectors.dense(Array.fill[Double](ctx.params.numFeatures)(2 * rng.nextDouble() - 1))
    // Small intercept to prevent some skew in the data.
    // nextDouble() keeps its parentheses: the call is side-effecting (it
    // advances the RNG state), matching the call on the line above.
    val intercept = 0.01 * (2 * rng.nextDouble() - 1)
    ModelBuilder.newLinearRegressionModel(coefficients, intercept)
  }

  /** The estimator under benchmark: Spark ML linear regression using l-bfgs. */
  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    import ctx.params._
    new ml.regression.LinearRegression()
      .setSolver("l-bfgs")
      .setRegParam(regParam)
      .setMaxIter(maxIter)
      .setTol(tol)
  }

  /** Scores predictions with Spark's RegressionEvaluator (default metric). */
  override protected def evaluator(ctx: MLBenchContext): Evaluator =
    new RegressionEvaluator()
}
|
||||
@ -63,3 +63,11 @@ benchmarks:
|
||||
numClasses: 4
|
||||
numFeatures: 5
|
||||
maxIter: 3
|
||||
- name: regression.LinearRegression
|
||||
params:
|
||||
numExamples: 100
|
||||
numTestExamples: 100
|
||||
numFeatures: 100
|
||||
regParam: 0.1
|
||||
tol: [0.0]
|
||||
maxIter: 10
|
||||
|
||||
@ -2,8 +2,7 @@ package org.apache.spark.ml
|
||||
|
||||
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, LogisticRegressionModel}
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
|
||||
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
|
||||
import org.apache.spark.ml.regression.{LinearRegressionModel, GeneralizedLinearRegressionModel, DecisionTreeRegressionModel}
|
||||
import org.apache.spark.ml.tree._
|
||||
import org.apache.spark.mllib.random.RandomDataGenerator
|
||||
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||
@ -20,6 +19,12 @@ object ModelBuilder {
|
||||
new LogisticRegressionModel("lr", coefficients, intercept)
|
||||
}
|
||||
|
||||
/**
 * Builds a Spark ML `LinearRegressionModel` directly from the given parameters,
 * bypassing training. Used by benchmarks that need a known ground-truth model.
 *
 * @param coefficients linear coefficients, one per feature
 * @param intercept additive intercept term
 * @return a model constructed with the fixed uid "linr"
 */
def newLinearRegressionModel(
    coefficients: Vector,
    intercept: Double): LinearRegressionModel = {
  new LinearRegressionModel("linr", coefficients, intercept)
}
|
||||
|
||||
def newGLR(
|
||||
coefficients: Vector,
|
||||
intercept: Double): GeneralizedLinearRegressionModel =
|
||||
|
||||
Loading…
Reference in New Issue
Block a user