Merge pull request #77 from thunterdb/1607-linear

Linear regression
This commit is contained in:
Timothy Hunter 2016-07-05 15:41:35 -07:00 committed by GitHub
commit 93c0407bbe
3 changed files with 61 additions and 2 deletions

View File

@ -0,0 +1,46 @@
package com.databricks.spark.sql.perf.mllib.regression
import org.apache.spark.ml
import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
object LinearRegression extends BenchmarkAlgorithm with TestFromTraining with
TrainingSetFromTransformer with ScoringWithEvaluator {
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
DataGenerator.generateContinuousFeatures(
ctx.sqlContext,
numExamples,
ctx.seed(),
numPartitions,
numFeatures)
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
val rng = ctx.newGenerator()
val coefficients =
Vectors.dense(Array.fill[Double](ctx.params.numFeatures)(2 * rng.nextDouble() - 1))
// Small intercept to prevent some skew in the data.
val intercept = 0.01 * (2 * rng.nextDouble - 1)
ModelBuilder.newLinearRegressionModel(coefficients, intercept)
}
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
import ctx.params._
new ml.regression.LinearRegression()
.setSolver("l-bfgs")
.setRegParam(regParam)
.setMaxIter(maxIter)
.setTol(tol)
}
override protected def evaluator(ctx: MLBenchContext): Evaluator =
new RegressionEvaluator()
}

View File

@ -63,3 +63,11 @@ benchmarks:
numClasses: 4
numFeatures: 5
maxIter: 3
- name: regression.LinearRegression
params:
numExamples: 100
numTestExamples: 100
numFeatures: 100
regParam: 0.1
tol: [0.0]
maxIter: 10

View File

@ -2,8 +2,7 @@ package org.apache.spark.ml
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, LogisticRegressionModel}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.regression.{LinearRegressionModel, GeneralizedLinearRegressionModel, DecisionTreeRegressionModel}
import org.apache.spark.ml.tree._
import org.apache.spark.mllib.random.RandomDataGenerator
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
@ -20,6 +19,12 @@ object ModelBuilder {
new LogisticRegressionModel("lr", coefficients, intercept)
}
def newLinearRegressionModel(
coefficients: Vector,
intercept: Double): LinearRegressionModel = {
new LinearRegressionModel("linr", coefficients, intercept)
}
def newGLR(
coefficients: Vector,
intercept: Double): GeneralizedLinearRegressionModel =