Merge pull request #74 from jkbradley/dt-tests
Decision tree, random forest, GBT classification perf tests
This commit is contained in:
commit
3d3443791c
@ -2,8 +2,8 @@ package com.databricks.spark.sql.perf.mllib
|
||||
|
||||
import com.typesafe.scalalogging.slf4j.Logging
|
||||
|
||||
import org.apache.spark.ml.Transformer
|
||||
import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, Evaluator}
|
||||
import org.apache.spark.ml.{Estimator, Transformer}
|
||||
import org.apache.spark.ml.evaluation.Evaluator
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.sql.functions._
|
||||
|
||||
@ -25,10 +25,10 @@ trait BenchmarkAlgorithm extends Logging {
|
||||
|
||||
def testDataSet(ctx: MLBenchContext): DataFrame
|
||||
|
||||
@throws[Exception]("if training fails")
|
||||
def train(
|
||||
ctx: MLBenchContext,
|
||||
trainingSet: DataFrame): Transformer
|
||||
/**
|
||||
* Create an [[Estimator]] with params set from the given [[MLBenchContext]].
|
||||
*/
|
||||
def getEstimator(ctx: MLBenchContext): Estimator[_]
|
||||
|
||||
/**
|
||||
* The unnormalized score of the training procedure on a dataset. The normalization is
|
||||
|
||||
@ -2,10 +2,12 @@ package com.databricks.spark.sql.perf.mllib
|
||||
|
||||
import com.databricks.spark.sql.perf._
|
||||
import com.typesafe.scalalogging.slf4j.Logging
|
||||
import org.apache.spark.sql._
|
||||
|
||||
import org.apache.spark.sql._
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.ml.Transformer
|
||||
|
||||
class MLTransformerBenchmarkable(
|
||||
params: MLParams,
|
||||
test: BenchmarkAlgorithm,
|
||||
@ -44,7 +46,11 @@ class MLTransformerBenchmarkable(
|
||||
description: String,
|
||||
messages: ArrayBuffer[String]): BenchmarkResult = {
|
||||
try {
|
||||
val (trainingTime, model) = measureTime(test.train(param, trainingData))
|
||||
val (trainingTime, model: Transformer) = measureTime {
|
||||
logger.info(s"$this: train: trainingSet=${trainingData.schema}")
|
||||
val estimator = test.getEstimator(param)
|
||||
estimator.fit(trainingData)
|
||||
}
|
||||
logger.info(s"model: $model")
|
||||
val (_, scoreTraining) = measureTime {
|
||||
test.score(param, trainingData, model)
|
||||
|
||||
@ -0,0 +1,58 @@
|
||||
package com.databricks.spark.sql.perf.mllib.classification
|
||||
|
||||
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
|
||||
import org.apache.spark.ml.classification.DecisionTreeClassifier
|
||||
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import com.databricks.spark.sql.perf.mllib._
|
||||
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
|
||||
|
||||
/**
 * Common base for the single-tree and random-forest classification benchmarks.
 *
 * Feature vectors mix categorical and continuous columns as laid out by
 * [[TreeOrForestClassification.getFeatureArity]]; the "true" model handed to the
 * mixed-in traits is a randomly generated decision tree of the configured depth.
 * Subclasses only need to supply `getEstimator`.
 */
abstract class TreeOrForestClassification extends BenchmarkAlgorithm
  with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {

  /** Generates the feature-only DataFrame from which the training set is derived. */
  override protected def initialData(ctx: MLBenchContext) = {
    val params = ctx.params
    DataGenerator.generateMixedFeatures(
      sql = ctx.sqlContext,
      numExamples = params.numExamples,
      seed = ctx.seed(),
      numPartitions = params.numPartitions,
      featureArity = TreeOrForestClassification.getFeatureArity(ctx))
  }

  /** Builds a random decision tree with `depth` levels and `numClasses` labels. */
  override protected def trueModel(ctx: MLBenchContext): Transformer = {
    val params = ctx.params
    ModelBuilder.newDecisionTreeClassificationModel(
      depth = params.depth,
      numClasses = params.numClasses,
      featureArity = TreeOrForestClassification.getFeatureArity(ctx),
      seed = ctx.seed())
  }

  /** Scores with a default-configured [[MulticlassClassificationEvaluator]]. */
  override protected def evaluator(ctx: MLBenchContext): Evaluator = {
    new MulticlassClassificationEvaluator()
  }
}
|
||||
|
||||
/** Benchmark for a single [[DecisionTreeClassifier]]. */
object DecisionTreeClassification extends TreeOrForestClassification {

  /** Creates a classifier configured with the benchmark's `depth` param and the context seed. */
  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    val classifier = new DecisionTreeClassifier()
    classifier.setMaxDepth(ctx.params.depth)
    classifier.setSeed(ctx.seed())
    classifier
  }
}
|
||||
|
||||
object TreeOrForestClassification {

  /**
   * Get feature arity for tree and tree ensemble tests.
   * Currently, this is hard-coded as (with q = numFeatures / 4, integer division):
   *  - the first q features are binary (arity 2)
   *  - the next q features are high-arity categorical (arity 20)
   *  - the remaining numFeatures - 2 * q features are continuous
   *
   * I.e., roughly 1/4 binary, 1/4 high-arity and 1/2 continuous. When numFeatures is not a
   * multiple of 4 the leftover features are continuous, and when numFeatures < 4 every feature
   * is continuous.
   *
   * @return Array of length numFeatures, where 0 indicates continuous feature and
   *         value > 0 indicates a categorical feature of that arity.
   */
  def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
    val numFeatures: Int = ctx.params.numFeatures
    val fourthFeatures = numFeatures / 4
    Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
      Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
      Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
  }
}
|
||||
@ -0,0 +1,40 @@
|
||||
package com.databricks.spark.sql.perf.mllib.classification
|
||||
|
||||
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
|
||||
import org.apache.spark.ml.classification.GBTClassifier
|
||||
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib._
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
|
||||
|
||||
/**
 * Benchmark for [[GBTClassifier]].
 *
 * Uses the same mixed categorical/continuous feature layout as the tree and forest
 * benchmarks (see [[TreeOrForestClassification.getFeatureArity]]).
 */
object GBTClassification extends BenchmarkAlgorithm
  with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {

  /** Generates the feature-only DataFrame from which the training set is derived. */
  override protected def initialData(ctx: MLBenchContext) = {
    val params = ctx.params
    DataGenerator.generateMixedFeatures(
      sql = ctx.sqlContext,
      numExamples = params.numExamples,
      seed = ctx.seed(),
      numPartitions = params.numPartitions,
      featureArity = TreeOrForestClassification.getFeatureArity(ctx))
  }

  /** Builds the random tree used as the "true" model. */
  override protected def trueModel(ctx: MLBenchContext): Transformer = {
    val params = ctx.params
    // We add +1 to the depth to make it more likely that many iterations of boosting are needed
    // to model the true tree.
    ModelBuilder.newDecisionTreeClassificationModel(
      depth = params.depth + 1,
      numClasses = params.numClasses,
      featureArity = TreeOrForestClassification.getFeatureArity(ctx),
      seed = ctx.seed())
  }

  /** Creates a GBT classifier from the `depth` and `maxIter` params and the context seed. */
  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    // TODO: subsamplingRate, featureSubsetStrategy
    // TODO: cacheNodeIds, checkpoint?
    val gbt = new GBTClassifier()
    gbt.setMaxDepth(ctx.params.depth)
    gbt.setMaxIter(ctx.params.maxIter)
    gbt.setSeed(ctx.seed())
    gbt
  }

  /** Scores with a default-configured [[MulticlassClassificationEvaluator]]. */
  override protected def evaluator(ctx: MLBenchContext): Evaluator = {
    new MulticlassClassificationEvaluator()
  }
}
|
||||
@ -3,19 +3,19 @@ package com.databricks.spark.sql.perf.mllib.classification
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import com.databricks.spark.sql.perf.mllib._
|
||||
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, Evaluator}
|
||||
|
||||
import org.apache.spark.ml.{Transformer, ModelBuilder}
|
||||
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
|
||||
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
|
||||
import org.apache.spark.ml
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.sql.DataFrame
|
||||
|
||||
|
||||
object LogisticRegression extends BenchmarkAlgorithm
|
||||
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
|
||||
|
||||
override protected def initialData(ctx: MLBenchContext) = {
|
||||
import ctx.params._
|
||||
DataGenerator.generateFeatures(
|
||||
DataGenerator.generateContinuousFeatures(
|
||||
ctx.sqlContext,
|
||||
numExamples,
|
||||
ctx.seed(),
|
||||
@ -32,15 +32,12 @@ object LogisticRegression extends BenchmarkAlgorithm
|
||||
ModelBuilder.newLogisticRegressionModel(coefficients, intercept)
|
||||
}
|
||||
|
||||
override def train(ctx: MLBenchContext,
|
||||
trainingSet: DataFrame): Transformer = {
|
||||
logger.info(s"$this: train: trainingSet=${trainingSet.schema}")
|
||||
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
|
||||
import ctx.params._
|
||||
val lr = new ml.classification.LogisticRegression()
|
||||
new ml.classification.LogisticRegression()
|
||||
.setTol(tol)
|
||||
.setMaxIter(maxIter)
|
||||
.setRegParam(regParam)
|
||||
lr.fit(trainingSet)
|
||||
}
|
||||
|
||||
override protected def evaluator(ctx: MLBenchContext): Evaluator =
|
||||
|
||||
@ -0,0 +1,21 @@
|
||||
package com.databricks.spark.sql.perf.mllib.classification
|
||||
|
||||
import org.apache.spark.ml.Estimator
|
||||
import org.apache.spark.ml.classification.RandomForestClassifier
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib._
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
|
||||
|
||||
/** Benchmark for [[RandomForestClassifier]]; data and true model come from the base class. */
object RandomForestClassification extends TreeOrForestClassification {

  /**
   * Creates a random forest classifier. Note that the `maxIter` param is reused as the
   * number of trees.
   */
  override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
    // TODO: subsamplingRate, featureSubsetStrategy
    // TODO: cacheNodeIds, checkpoint?
    val forest = new RandomForestClassifier()
    forest.setMaxDepth(ctx.params.depth)
    forest.setNumTrees(ctx.params.maxIter)
    forest.setSeed(ctx.seed())
    forest
  }
}
|
||||
@ -1,14 +1,18 @@
|
||||
package com.databricks.spark.sql.perf.mllib.clustering
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.{MLBenchContext, TestFromTraining, BenchmarkAlgorithm}
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import scala.collection.mutable.{HashMap => MHashMap}
|
||||
|
||||
import org.apache.commons.math3.random.Well19937c
|
||||
import org.apache.spark.ml.Transformer
|
||||
|
||||
import org.apache.spark.ml.Estimator
|
||||
import org.apache.spark.ml
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql._
|
||||
import org.apache.spark.ml.linalg.{Vectors, Vector}
|
||||
import scala.collection.mutable.{HashMap => MHashMap}
|
||||
import org.apache.spark.ml.linalg.{Vector, Vectors}
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
|
||||
|
||||
object LDA extends BenchmarkAlgorithm with TestFromTraining {
|
||||
// The LDA model is package private, no need to expose it.
|
||||
@ -40,16 +44,14 @@ object LDA extends BenchmarkAlgorithm with TestFromTraining {
|
||||
ctx.sqlContext.createDataFrame(data).toDF("docIndex", "features")
|
||||
}
|
||||
|
||||
override def train(ctx: MLBenchContext,
|
||||
trainingSet: DataFrame): Transformer = {
|
||||
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
|
||||
import ctx.params._
|
||||
new ml.clustering.LDA()
|
||||
.setK(k)
|
||||
.setSeed(randomSeed.toLong)
|
||||
.setMaxIter(maxIter)
|
||||
.setOptimizer(optimizer)
|
||||
.fit(trainingSet)
|
||||
.setK(k)
|
||||
.setSeed(randomSeed.toLong)
|
||||
.setMaxIter(maxIter)
|
||||
.setOptimizer(optimizer)
|
||||
}
|
||||
|
||||
// TODO(?) add a scoring method here.
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,80 @@
|
||||
package com.databricks.spark.sql.perf.mllib.data
|
||||
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.mllib.random._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{SQLContext, DataFrame}
|
||||
|
||||
|
||||
/** Factory methods producing single-column ("features") DataFrames of random vectors. */
object DataGenerator {

  /**
   * Generate `numExamples` vectors of `numFeatures` continuous features.
   * This is the all-continuous special case of [[generateMixedFeatures]].
   */
  def generateContinuousFeatures(
      sql: SQLContext,
      numExamples: Long,
      seed: Long,
      numPartitions: Int,
      numFeatures: Int): DataFrame = {
    // Arity 0 marks every feature as continuous.
    val allContinuous = Array.fill[Int](numFeatures)(0)
    generateMixedFeatures(sql, numExamples, seed, numPartitions, allContinuous)
  }

  /**
   * Generate a mix of continuous and categorical features.
   * @param featureArity  Array of length numFeatures, where 0 indicates a continuous feature and
   *                      a value > 0 indicates a categorical feature with that arity.
   */
  def generateMixedFeatures(
      sql: SQLContext,
      numExamples: Long,
      seed: Long,
      numPartitions: Int,
      featureArity: Array[Int]): DataFrame = {
    val generator = new FeaturesGenerator(featureArity)
    val vectors: RDD[Vector] =
      RandomRDDs.randomRDD(sql.sparkContext, generator, numExamples, numPartitions, seed)
    // Wrap each vector in a Tuple1 so it can become a one-column DataFrame.
    val rows = vectors.map(Tuple1.apply)
    sql.createDataFrame(rows).toDF("features")
  }
}
|
||||
|
||||
|
||||
/**
 * Generator for a feature vector which can include a mix of categorical and continuous features.
 *
 * Continuous features are drawn uniformly from [-1, 1); a categorical feature of arity k takes
 * a value in {0, ..., k-1}.
 *
 * @param featureArity  Length numFeatures, where 0 indicates continuous feature and > 0
 *                      indicates a categorical feature of that arity.
 */
class FeaturesGenerator(val featureArity: Array[Int])
  extends RandomDataGenerator[Vector] {

  featureArity.foreach { arity =>
    require(arity >= 0, s"FeaturesGenerator given categorical arity = $arity, " +
      s"but arity should be >= 0.")
  }

  /** Total number of features in each generated vector. */
  val numFeatures: Int = featureArity.length

  private val rng = new java.util.Random()

  /**
   * Generates vector with features in the order given by [[featureArity]]
   */
  override def nextValue(): Vector = {
    val arr = new Array[Double](numFeatures)
    // Plain while loop: this runs once per generated example.
    var j = 0
    while (j < numFeatures) {
      val arity = featureArity(j)
      arr(j) = if (arity == 0) {
        2 * rng.nextDouble() - 1 // centered uniform data
      } else {
        rng.nextInt(arity).toDouble
      }
      j += 1
    }
    Vectors.dense(arr)
  }

  override def setSeed(seed: Long): Unit = {
    rng.setSeed(seed)
  }

  /** Returns a new generator with the same arities and its own (unseeded) RNG. */
  override def copy(): FeaturesGenerator = new FeaturesGenerator(featureArity)
}
|
||||
@ -1,102 +0,0 @@
|
||||
package com.databricks.spark.sql.perf.mllib.data
|
||||
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.mllib.random._
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.sql.{SQLContext, DataFrame}
|
||||
|
||||
/** Factory for single-column ("features") DataFrames of random continuous vectors. */
object DataGenerator {

  /** Generate `numExamples` vectors holding `numFeatures` continuous features each. */
  def generateFeatures(
      sql: SQLContext,
      numExamples: Long,
      seed: Long,
      numPartitions: Int,
      numFeatures: Int): DataFrame = {
    // No categorical features: every column is continuous.
    val generator = new FeaturesGenerator(Array.empty[Int], numFeatures)
    val vectors: RDD[Vector] =
      RandomRDDs.randomRDD(sql.sparkContext, generator, numExamples, numPartitions, seed)
    val rows = vectors.map(Tuple1.apply)
    sql.createDataFrame(rows).toDF("features")
  }
}
|
||||
|
||||
/**
 * Generator for labeled points whose label and features are all binary (0.0 or 1.0).
 *
 * @param numFeatures  Number of features in each generated point.
 * @param threshold  Probability with which the label, and independently each feature,
 *                   is 0.0 rather than 1.0.
 */
class BinaryLabeledDataGenerator(
    private val numFeatures: Int,
    private val threshold: Double) extends RandomDataGenerator[LabeledPoint] {

  private val rng = new java.util.Random()

  override def nextValue(): LabeledPoint = {
    val y = if (rng.nextDouble() < threshold) 0.0 else 1.0
    val x = Array.fill[Double](numFeatures) {
      if (rng.nextDouble() < threshold) 0.0 else 1.0
    }
    // mllib's LabeledPoint needs an mllib vector; the `Vectors` imported in this file is the
    // ml.linalg one, so the mllib factory is referenced by its full name.
    LabeledPoint(y, org.apache.spark.mllib.linalg.Vectors.dense(x))
  }

  override def setSeed(seed: Long): Unit = {
    rng.setSeed(seed)
  }

  override def copy(): BinaryLabeledDataGenerator =
    new BinaryLabeledDataGenerator(numFeatures, threshold)

}
|
||||
|
||||
|
||||
/**
 * Generator for a feature vector which can include a mix of categorical and continuous features.
 * @param categoricalArities  Specifies the number of categories for each categorical feature.
 * @param numContinuous  Number of continuous features.  Feature values are uniform in [-1,1).
 */
class FeaturesGenerator(val categoricalArities: Array[Int], val numContinuous: Int)
  extends RandomDataGenerator[Vector] {

  categoricalArities.foreach { arity =>
    require(arity >= 2, s"FeaturesGenerator given categorical arity = $arity, " +
      s"but arity should be >= 2.")
  }

  /** Total number of features: categorical ones followed by continuous ones. */
  val numFeatures: Int = categoricalArities.length + numContinuous

  private val rng = new java.util.Random()

  /**
   * Generates vector with categorical features first, and continuous features in [-1,1) second.
   */
  override def nextValue(): Vector = {
    // Feature ordering matches getCategoricalFeaturesInfo.
    val arr = new Array[Double](numFeatures)
    var j = 0
    while (j < categoricalArities.length) {
      arr(j) = rng.nextInt(categoricalArities(j)).toDouble
      j += 1
    }
    while (j < numFeatures) {
      // Generating some centered data
      arr(j) = 2 * rng.nextDouble() - 1
      j += 1
    }
    Vectors.dense(arr)
  }

  override def setSeed(seed: Long): Unit = {
    rng.setSeed(seed)
  }

  override def copy(): FeaturesGenerator = new FeaturesGenerator(categoricalArities, numContinuous)

  /**
   * @return categoricalFeaturesInfo Map storing arity of categorical features.
   *         E.g., an entry (n -> k) indicates that feature n is categorical
   *         with k categories indexed from 0: {0, 1, ..., k-1}.
   */
  def getCategoricalFeaturesInfo: Map[Int, Int] = {
    // Categorical features are indexed from 0 because of the implementation of nextValue().
    categoricalArities.zipWithIndex.map(_.swap).toMap
  }
}
|
||||
@ -1,13 +1,13 @@
|
||||
package com.databricks.spark.sql.perf.mllib.regression
|
||||
|
||||
import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator}
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.ml.regression.GeneralizedLinearRegression
|
||||
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import com.databricks.spark.sql.perf.mllib._
|
||||
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator}
|
||||
import org.apache.spark.ml.linalg.Vectors
|
||||
import org.apache.spark.ml.regression.GeneralizedLinearRegression
|
||||
import org.apache.spark.ml.{ModelBuilder, Transformer}
|
||||
import org.apache.spark.sql._
|
||||
|
||||
|
||||
object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with
|
||||
@ -15,7 +15,7 @@ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with
|
||||
|
||||
override protected def initialData(ctx: MLBenchContext) = {
|
||||
import ctx.params._
|
||||
DataGenerator.generateFeatures(
|
||||
DataGenerator.generateContinuousFeatures(
|
||||
ctx.sqlContext,
|
||||
numExamples,
|
||||
ctx.seed(),
|
||||
@ -36,18 +36,14 @@ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with
|
||||
m
|
||||
}
|
||||
|
||||
override def train(
|
||||
ctx: MLBenchContext,
|
||||
trainingSet: DataFrame): Transformer = {
|
||||
logger.info(s"$this: train: trainingSet=${trainingSet.schema}")
|
||||
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
|
||||
import ctx.params._
|
||||
val glr = new GeneralizedLinearRegression()
|
||||
new GeneralizedLinearRegression()
|
||||
.setLink(link)
|
||||
.setFamily(family)
|
||||
.setRegParam(regParam)
|
||||
.setMaxIter(maxIter)
|
||||
.setTol(tol)
|
||||
glr.fit(trainingSet)
|
||||
}
|
||||
|
||||
override protected def evaluator(ctx: MLBenchContext): Evaluator =
|
||||
|
||||
@ -113,13 +113,15 @@ case class MLParams(
|
||||
numTestExamples: Option[Long] = None,
|
||||
numPartitions: Option[Int] = None,
|
||||
// *** Specialized and sorted by name ***
|
||||
depth: Option[Int] = None,
|
||||
elasticNetParam: Option[Double] = None,
|
||||
family: Option[String] = None,
|
||||
link: Option[String] = None,
|
||||
k: Option[Int] = None,
|
||||
ldaDocLength: Option[Int] = None,
|
||||
ldaNumVocabulary: Option[Int] = None,
|
||||
link: Option[String] = None,
|
||||
maxIter: Option[Int] = None,
|
||||
numClasses: Option[Int] = None,
|
||||
numFeatures: Option[Int] = None,
|
||||
optimizer: Option[String] = None,
|
||||
regParam: Option[Double] = None,
|
||||
|
||||
@ -33,3 +33,26 @@ benchmarks:
|
||||
tol: 0.0
|
||||
maxIter: 10
|
||||
regParam: 0.1
|
||||
- name: classification.DecisionTreeClassification
|
||||
params:
|
||||
numExamples: 100
|
||||
numTestExamples: 10
|
||||
depth: 3
|
||||
numClasses: 4
|
||||
numFeatures: 5
|
||||
- name: classification.RandomForestClassification
|
||||
params:
|
||||
numExamples: 100
|
||||
numTestExamples: 10
|
||||
depth: 3
|
||||
numClasses: 4
|
||||
numFeatures: 5
|
||||
maxIter: 3
|
||||
- name: classification.GBTClassification
|
||||
params:
|
||||
numExamples: 100
|
||||
numTestExamples: 10
|
||||
depth: 3
|
||||
numClasses: 4
|
||||
numFeatures: 5
|
||||
maxIter: 3
|
||||
|
||||
@ -1,8 +1,13 @@
|
||||
package org.apache.spark.ml
|
||||
|
||||
import org.apache.spark.ml.classification.LogisticRegressionModel
|
||||
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, LogisticRegressionModel}
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
|
||||
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
|
||||
import org.apache.spark.ml.tree._
|
||||
import org.apache.spark.mllib.random.RandomDataGenerator
|
||||
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
|
||||
|
||||
|
||||
/**
|
||||
* Helper for creating MLlib models which have private constructors.
|
||||
@ -19,4 +24,194 @@ object ModelBuilder {
|
||||
coefficients: Vector,
|
||||
intercept: Double): GeneralizedLinearRegressionModel =
|
||||
new GeneralizedLinearRegressionModel("glr-uid", coefficients, intercept)
|
||||
}
|
||||
|
||||
/**
 * Create a classification tree model around a randomly generated balanced tree.
 *
 * @param depth  Depth of the generated tree.
 * @param numClasses  Number of label classes; must be >= 2.
 * @param featureArity  Per-feature arity passed to [[TreeBuilder.randomBalancedDecisionTree]];
 *                      its length is used as the model's number of features.
 * @param seed  Seed for the random tree structure.
 */
def newDecisionTreeClassificationModel(
    depth: Int,
    numClasses: Int,
    featureArity: Array[Int],
    seed: Long): DecisionTreeClassificationModel = {
  require(numClasses >= 2,
    s"DecisionTreeClassificationModel requires numClasses >= 2, but was given $numClasses")
  val root = TreeBuilder.randomBalancedDecisionTree(
    depth = depth,
    labelType = numClasses,
    featureArity = featureArity,
    seed = seed)
  new DecisionTreeClassificationModel(
    root,
    numFeatures = featureArity.length,
    numClasses = numClasses)
}
|
||||
|
||||
/**
 * Create a regression tree model around a randomly generated balanced tree.
 *
 * @param depth  Depth of the generated tree.
 * @param featureArity  Per-feature arity passed to [[TreeBuilder.randomBalancedDecisionTree]];
 *                      its length is used as the model's number of features.
 * @param seed  Seed for the random tree structure.
 */
def newDecisionTreeRegressionModel(
    depth: Int,
    featureArity: Array[Int],
    seed: Long): DecisionTreeRegressionModel = {
  // labelType = 0 selects regression in TreeBuilder.
  val root = TreeBuilder.randomBalancedDecisionTree(
    depth = depth,
    labelType = 0,
    featureArity = featureArity,
    seed = seed)
  new DecisionTreeRegressionModel(root, numFeatures = featureArity.length)
}
|
||||
}
|
||||
|
||||
/**
 * Helpers for creating random decision trees.
 */
object TreeBuilder {

  /**
   * Generator for a pair of distinct class labels from the set {0,...,numClasses-1}.
   * Pairs are useful for trees to make sure sibling leaf nodes make different predictions.
   * @param numClasses  Number of classes.
   */
  private class ClassLabelPairGenerator(val numClasses: Int)
    extends RandomDataGenerator[(Double, Double)] {

    require(numClasses >= 2,
      s"ClassLabelPairGenerator given label numClasses = $numClasses, but numClasses should be >= 2.")

    private val rng = new java.util.Random()

    override def nextValue(): (Double, Double) = {
      val left = rng.nextInt(numClasses)
      // Redraw until the right label differs from the left one.
      var right = rng.nextInt(numClasses)
      while (right == left) {
        right = rng.nextInt(numClasses)
      }
      (left.toDouble, right.toDouble)
    }

    override def setSeed(seed: Long): Unit = {
      rng.setSeed(seed)
    }

    override def copy(): ClassLabelPairGenerator = new ClassLabelPairGenerator(numClasses)
  }

  /**
   * Generator for a pair of real-valued labels, each drawn uniformly from [0, 1).
   * Pairs are useful for trees to make sure sibling leaf nodes make different predictions.
   */
  private class RealLabelPairGenerator() extends RandomDataGenerator[(Double, Double)] {

    private val rng = new java.util.Random()

    override def nextValue(): (Double, Double) =
      (rng.nextDouble(), rng.nextDouble())

    override def setSeed(seed: Long): Unit = {
      rng.setSeed(seed)
    }

    override def copy(): RealLabelPairGenerator = new RealLabelPairGenerator()
  }

  /**
   * Creates a random decision tree structure.
   * @param depth  Depth of tree to build.  Must be >= 0 and <= numFeatures.
   * @param labelType  Value 0 indicates regression.  Integers >= 2 indicate numClasses for
   *                   classification.
   * @param featureArity  Array of length numFeatures indicating feature type.
   *                      Value 0 indicates continuous feature.
   *                      Other values >= 2 indicate a categorical feature,
   *                      where the value is the number of categories.
   * @param seed  Seed for the tree structure and leaf labels.
   * @return root node of tree
   */
  def randomBalancedDecisionTree(
      depth: Int,
      labelType: Int,
      featureArity: Array[Int],
      seed: Long): Node = {
    require(depth >= 0, s"randomBalancedDecisionTree given depth < 0.")
    val numFeatures = featureArity.length
    require(depth <= numFeatures,
      s"randomBalancedDecisionTree requires depth <= featureArity.size," +
      s" but depth = $depth and featureArity.size = $numFeatures")
    val isRegression = labelType == 0
    if (!isRegression) {
      require(labelType >= 2, s"labelType must be >= 2 for classification. 0 indicates regression.")
    }

    val rng = new scala.util.Random()
    rng.setSeed(seed)

    val labelGenerator: RandomDataGenerator[(Double, Double)] = if (isRegression) {
      new RealLabelPairGenerator()
    } else {
      new ClassLabelPairGenerator(labelType)
    }
    // The label generator gets its own seed derived from the tree RNG.
    labelGenerator.setSeed(rng.nextLong)
    // We use a dummy impurityCalculator for all nodes.
    val impurityCalculator = if (isRegression) {
      ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0))
    } else {
      ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0))
    }

    randomBalancedDecisionTreeHelper(depth, featureArity, impurityCalculator,
      labelGenerator, Set.empty, rng)
  }

  /**
   * Create an internal node.  Either create the leaf nodes beneath it, or recurse as needed.
   * @param subtreeDepth  Depth of subtree to build.  Depth 0 means this is a leaf node.
   * @param featureArity  Indicates feature type.  Value 0 indicates continuous feature.
   *                      Other values >= 2 indicate a categorical feature,
   *                      where the value is the number of categories.
   * @param impurityCalculator  Dummy impurity calculator to use at all tree nodes
   * @param labelGenerator  Generates pairs of distinct labels.
   * @param usedFeatures  Features appearing in the path from the tree root to the node
   *                      being constructed.
   * @param rng  Source of randomness for feature choice and split parameters.
   * @return root node of the generated subtree
   */
  private def randomBalancedDecisionTreeHelper(
      subtreeDepth: Int,
      featureArity: Array[Int],
      impurityCalculator: ImpurityCalculator,
      labelGenerator: RandomDataGenerator[(Double, Double)],
      usedFeatures: Set[Int],
      rng: scala.util.Random): Node = {

    if (subtreeDepth == 0) {
      // This case only happens for a depth 0 tree.
      new LeafNode(prediction = 0.0, impurity = 0.0, impurityStats = impurityCalculator)
    } else {
      val numFeatures = featureArity.length
      // Should not happen.
      assert(usedFeatures.size < numFeatures, s"randomBalancedDecisionTreeHelper ran out of " +
        s"features for splits.")

      // Make node internal.
      // Rejection sampling is kept as-is so trees generated from a given seed do not change.
      var feature: Int = rng.nextInt(numFeatures)
      while (usedFeatures.contains(feature)) {
        feature = rng.nextInt(numFeatures)
      }
      val arity = featureArity(feature)
      val split: Split = if (arity == 0) {
        // continuous feature
        new ContinuousSplit(featureIndex = feature, threshold = rng.nextDouble())
      } else {
        // categorical feature
        // Fail with a clear message rather than the "bound must be positive" error that
        // nextInt would otherwise raise for an arity below 2.
        require(arity >= 2, s"randomBalancedDecisionTreeHelper given categorical arity = $arity " +
          s"for feature $feature, but categorical arity should be >= 2.")
        // Put nCatsSplit categories on left, and the rest on the right.
        // nCatsSplit is in {1,...,arity-1}.
        val nCatsSplit = rng.nextInt(arity - 1) + 1
        val splitCategories: Array[Double] =
          rng.shuffle(Range(0, arity).toList).take(nCatsSplit).map(_.toDouble).toArray
        new CategoricalSplit(featureIndex = feature,
          _leftCategories = splitCategories, numCategories = arity)
      }

      val (leftChild: Node, rightChild: Node) = if (subtreeDepth == 1) {
        // Add leaf nodes.  Assign these jointly so they make different predictions.
        val (leftPrediction, rightPrediction) = labelGenerator.nextValue()
        val leftLeaf = new LeafNode(prediction = leftPrediction, impurity = 0.0,
          impurityStats = impurityCalculator)
        val rightLeaf = new LeafNode(prediction = rightPrediction, impurity = 0.0,
          impurityStats = impurityCalculator)
        (leftLeaf, rightLeaf)
      } else {
        val pathFeatures = usedFeatures + feature
        // The left subtree is built first; this order fixes how the shared RNG is consumed.
        val leftSubtree = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity,
          impurityCalculator, labelGenerator, pathFeatures, rng)
        val rightSubtree = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity,
          impurityCalculator, labelGenerator, pathFeatures, rng)
        (leftSubtree, rightSubtree)
      }
      new InternalNode(prediction = 0.0, impurity = 0.0, gain = 0.0, leftChild = leftChild,
        rightChild = rightChild, split = split, impurityStats = impurityCalculator)
    }
  }
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user