Merge pull request #74 from jkbradley/dt-tests

Decision tree, random forest, GBT classification perf tests
This commit is contained in:
jkbradley 2016-07-01 17:40:16 -07:00 committed by GitHub
commit 3d3443791c
13 changed files with 465 additions and 147 deletions

View File

@ -2,8 +2,8 @@ package com.databricks.spark.sql.perf.mllib
import com.typesafe.scalalogging.slf4j.Logging
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, Evaluator}
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
@ -25,10 +25,10 @@ trait BenchmarkAlgorithm extends Logging {
def testDataSet(ctx: MLBenchContext): DataFrame
@throws[Exception]("if training fails")
def train(
ctx: MLBenchContext,
trainingSet: DataFrame): Transformer
/**
* Create an [[Estimator]] with params set from the given [[MLBenchContext]].
*/
def getEstimator(ctx: MLBenchContext): Estimator[_]
/**
* The unnormalized score of the training procedure on a dataset. The normalization is

View File

@ -2,10 +2,12 @@ package com.databricks.spark.sql.perf.mllib
import com.databricks.spark.sql.perf._
import com.typesafe.scalalogging.slf4j.Logging
import org.apache.spark.sql._
import org.apache.spark.sql._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.ml.Transformer
class MLTransformerBenchmarkable(
params: MLParams,
test: BenchmarkAlgorithm,
@ -44,7 +46,11 @@ class MLTransformerBenchmarkable(
description: String,
messages: ArrayBuffer[String]): BenchmarkResult = {
try {
val (trainingTime, model) = measureTime(test.train(param, trainingData))
val (trainingTime, model: Transformer) = measureTime {
logger.info(s"$this: train: trainingSet=${trainingData.schema}")
val estimator = test.getEstimator(param)
estimator.fit(trainingData)
}
logger.info(s"model: $model")
val (_, scoreTraining) = measureTime {
test.score(param, trainingData, model)

View File

@ -0,0 +1,58 @@
package com.databricks.spark.sql.perf.mllib.classification
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
abstract class TreeOrForestClassification extends BenchmarkAlgorithm
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions,
TreeOrForestClassification.getFeatureArity(ctx))
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses,
TreeOrForestClassification.getFeatureArity(ctx), ctx.seed())
}
override protected def evaluator(ctx: MLBenchContext): Evaluator =
new MulticlassClassificationEvaluator()
}
object DecisionTreeClassification extends TreeOrForestClassification {
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
import ctx.params._
new DecisionTreeClassifier()
.setMaxDepth(depth)
.setSeed(ctx.seed())
}
}
object TreeOrForestClassification {
/**
* Get feature arity for tree and tree ensemble tests.
* Currently, this is hard-coded as:
* - 1/2 binary features
* - 1/2 high-arity (20-category) features
* - 1/2 continuous features
* @return Array of length numFeatures, where 0 indicates continuous feature and
* value > 0 indicates a categorical feature of that arity.
*/
def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
val numFeatures = ctx.params.numFeatures
val fourthFeatures = numFeatures / 4
Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
}
}

View File

@ -0,0 +1,40 @@
package com.databricks.spark.sql.perf.mllib.classification
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
object GBTClassification extends BenchmarkAlgorithm
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions,
TreeOrForestClassification.getFeatureArity(ctx))
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
// We add +1 to the depth to make it more likely that many iterations of boosting are needed
// to model the true tree.
ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth + 1, ctx.params.numClasses,
TreeOrForestClassification.getFeatureArity(ctx), ctx.seed())
}
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
import ctx.params._
// TODO: subsamplingRate, featureSubsetStrategy
// TODO: cacheNodeIds, checkpoint?
new GBTClassifier()
.setMaxDepth(depth)
.setMaxIter(maxIter)
.setSeed(ctx.seed())
}
override protected def evaluator(ctx: MLBenchContext): Evaluator =
new MulticlassClassificationEvaluator()
}

View File

@ -3,19 +3,19 @@ package com.databricks.spark.sql.perf.mllib.classification
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, Evaluator}
import org.apache.spark.ml.{Transformer, ModelBuilder}
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
import org.apache.spark.ml
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.DataFrame
object LogisticRegression extends BenchmarkAlgorithm
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
DataGenerator.generateFeatures(
DataGenerator.generateContinuousFeatures(
ctx.sqlContext,
numExamples,
ctx.seed(),
@ -32,15 +32,12 @@ object LogisticRegression extends BenchmarkAlgorithm
ModelBuilder.newLogisticRegressionModel(coefficients, intercept)
}
override def train(ctx: MLBenchContext,
trainingSet: DataFrame): Transformer = {
logger.info(s"$this: train: trainingSet=${trainingSet.schema}")
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
import ctx.params._
val lr = new ml.classification.LogisticRegression()
new ml.classification.LogisticRegression()
.setTol(tol)
.setMaxIter(maxIter)
.setRegParam(regParam)
lr.fit(trainingSet)
}
override protected def evaluator(ctx: MLBenchContext): Evaluator =

View File

@ -0,0 +1,21 @@
package com.databricks.spark.sql.perf.mllib.classification
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.classification.RandomForestClassifier
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
object RandomForestClassification extends TreeOrForestClassification {
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
import ctx.params._
// TODO: subsamplingRate, featureSubsetStrategy
// TODO: cacheNodeIds, checkpoint?
new RandomForestClassifier()
.setMaxDepth(depth)
.setNumTrees(maxIter)
.setSeed(ctx.seed())
}
}

View File

@ -1,14 +1,18 @@
package com.databricks.spark.sql.perf.mllib.clustering
import com.databricks.spark.sql.perf.mllib.{MLBenchContext, TestFromTraining, BenchmarkAlgorithm}
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import scala.collection.mutable.{HashMap => MHashMap}
import org.apache.commons.math3.random.Well19937c
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.Estimator
import org.apache.spark.ml
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.ml.linalg.{Vectors, Vector}
import scala.collection.mutable.{HashMap => MHashMap}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
object LDA extends BenchmarkAlgorithm with TestFromTraining {
// The LDA model is package private, no need to expose it.
@ -40,16 +44,14 @@ object LDA extends BenchmarkAlgorithm with TestFromTraining {
ctx.sqlContext.createDataFrame(data).toDF("docIndex", "features")
}
override def train(ctx: MLBenchContext,
trainingSet: DataFrame): Transformer = {
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
import ctx.params._
new ml.clustering.LDA()
.setK(k)
.setSeed(randomSeed.toLong)
.setMaxIter(maxIter)
.setOptimizer(optimizer)
.fit(trainingSet)
.setK(k)
.setSeed(randomSeed.toLong)
.setMaxIter(maxIter)
.setOptimizer(optimizer)
}
// TODO(?) add a scoring method here.
}
}

View File

@ -0,0 +1,80 @@
package com.databricks.spark.sql.perf.mllib.data
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.random._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, DataFrame}
object DataGenerator {
def generateContinuousFeatures(
sql: SQLContext,
numExamples: Long,
seed: Long,
numPartitions: Int,
numFeatures: Int): DataFrame = {
val featureArity = Array.fill[Int](numFeatures)(0)
val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext,
new FeaturesGenerator(featureArity), numExamples, numPartitions, seed)
sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
}
/**
* Generate a mix of continuous and categorical features.
* @param featureArity Array of length numFeatures, where 0 indicates a continuous feature and
* a value > 0 indicates a categorical feature with that arity.
*/
def generateMixedFeatures(
sql: SQLContext,
numExamples: Long,
seed: Long,
numPartitions: Int,
featureArity: Array[Int]): DataFrame = {
val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext,
new FeaturesGenerator(featureArity), numExamples, numPartitions, seed)
sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
}
}
/**
* Generator for a feature vector which can include a mix of categorical and continuous features.
* @param featureArity Length numFeatures, where 0 indicates continuous feature and > 0
* indicates a categorical feature of that arity.
*/
class FeaturesGenerator(val featureArity: Array[Int])
extends RandomDataGenerator[Vector] {
featureArity.foreach { arity =>
require(arity >= 0, s"FeaturesGenerator given categorical arity = $arity, " +
s"but arity should be >= 0.")
}
val numFeatures = featureArity.length
private val rng = new java.util.Random()
/**
* Generates vector with features in the order given by [[featureArity]]
*/
override def nextValue(): Vector = {
val arr = new Array[Double](numFeatures)
var j = 0
while (j < featureArity.length) {
if (featureArity(j) == 0)
arr(j) = 2 * rng.nextDouble() - 1 // centered uniform data
else
arr(j) = rng.nextInt(featureArity(j))
j += 1
}
Vectors.dense(arr)
}
override def setSeed(seed: Long) {
rng.setSeed(seed)
}
override def copy(): FeaturesGenerator = new FeaturesGenerator(featureArity)
}

View File

@ -1,102 +0,0 @@
package com.databricks.spark.sql.perf.mllib.data
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.random._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, DataFrame}
object DataGenerator {
def generateFeatures(
sql: SQLContext,
numExamples: Long,
seed: Long,
numPartitions: Int,
numFeatures: Int): DataFrame = {
val categoricalArities = Array.empty[Int]
val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext,
new FeaturesGenerator(categoricalArities, numFeatures),
numExamples, numPartitions, seed)
sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
}
}
class BinaryLabeledDataGenerator(
private val numFeatures: Int,
private val threshold: Double) extends RandomDataGenerator[LabeledPoint] {
private val rng = new java.util.Random()
override def nextValue(): LabeledPoint = {
val y = if (rng.nextDouble() < threshold) 0.0 else 1.0
val x = Array.fill[Double](numFeatures) {
if (rng.nextDouble() < threshold) 0.0 else 1.0
}
???
// LabeledPoint(y, Vectors.dense(x))
}
override def setSeed(seed: Long) {
rng.setSeed(seed)
}
override def copy(): BinaryLabeledDataGenerator =
new BinaryLabeledDataGenerator(numFeatures, threshold)
}
/**
* Generator for a feature vector which can include a mix of categorical and continuous features.
* @param categoricalArities Specifies the number of categories for each categorical feature.
* @param numContinuous Number of continuous features. Feature values are in range [0,1].
*/
class FeaturesGenerator(val categoricalArities: Array[Int], val numContinuous: Int)
extends RandomDataGenerator[Vector] {
categoricalArities.foreach { arity =>
require(arity >= 2, s"FeaturesGenerator given categorical arity = $arity, " +
s"but arity should be >= 2.")
}
val numFeatures = categoricalArities.length + numContinuous
private val rng = new java.util.Random()
/**
* Generates vector with categorical features first, and continuous features in [0,1] second.
*/
override def nextValue(): Vector = {
// Feature ordering matches getCategoricalFeaturesInfo.
val arr = new Array[Double](numFeatures)
var j = 0
while (j < categoricalArities.length) {
arr(j) = rng.nextInt(categoricalArities(j))
j += 1
}
while (j < numFeatures) {
// Generating some centered data
arr(j) = 2 * rng.nextDouble() - 1
j += 1
}
Vectors.dense(arr)
}
override def setSeed(seed: Long) {
rng.setSeed(seed)
}
override def copy(): FeaturesGenerator = new FeaturesGenerator(categoricalArities, numContinuous)
/**
* @return categoricalFeaturesInfo Map storing arity of categorical features.
* E.g., an entry (n -> k) indicates that feature n is categorical
* with k categories indexed from 0: {0, 1, ..., k-1}.
*/
def getCategoricalFeaturesInfo: Map[Int, Int] = {
// Categorical features are indexed from 0 because of the implementation of nextValue().
categoricalArities.zipWithIndex.map(_.swap).toMap
}
}

View File

@ -1,13 +1,13 @@
package com.databricks.spark.sql.perf.mllib.regression
import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression
import org.apache.spark.ml.{ModelBuilder, Transformer}
import org.apache.spark.sql._
object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with
@ -15,7 +15,7 @@ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
DataGenerator.generateFeatures(
DataGenerator.generateContinuousFeatures(
ctx.sqlContext,
numExamples,
ctx.seed(),
@ -36,18 +36,14 @@ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with
m
}
override def train(
ctx: MLBenchContext,
trainingSet: DataFrame): Transformer = {
logger.info(s"$this: train: trainingSet=${trainingSet.schema}")
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
import ctx.params._
val glr = new GeneralizedLinearRegression()
new GeneralizedLinearRegression()
.setLink(link)
.setFamily(family)
.setRegParam(regParam)
.setMaxIter(maxIter)
.setTol(tol)
glr.fit(trainingSet)
}
override protected def evaluator(ctx: MLBenchContext): Evaluator =

View File

@ -113,13 +113,15 @@ case class MLParams(
numTestExamples: Option[Long] = None,
numPartitions: Option[Int] = None,
// *** Specialized and sorted by name ***
depth: Option[Int] = None,
elasticNetParam: Option[Double] = None,
family: Option[String] = None,
link: Option[String] = None,
k: Option[Int] = None,
ldaDocLength: Option[Int] = None,
ldaNumVocabulary: Option[Int] = None,
link: Option[String] = None,
maxIter: Option[Int] = None,
numClasses: Option[Int] = None,
numFeatures: Option[Int] = None,
optimizer: Option[String] = None,
regParam: Option[Double] = None,

View File

@ -33,3 +33,26 @@ benchmarks:
tol: 0.0
maxIter: 10
regParam: 0.1
- name: classification.DecisionTreeClassification
params:
numExamples: 100
numTestExamples: 10
depth: 3
numClasses: 4
numFeatures: 5
- name: classification.RandomForestClassification
params:
numExamples: 100
numTestExamples: 10
depth: 3
numClasses: 4
numFeatures: 5
maxIter: 3
- name: classification.GBTClassification
params:
numExamples: 100
numTestExamples: 10
depth: 3
numClasses: 4
numFeatures: 5
maxIter: 3

View File

@ -1,8 +1,13 @@
package org.apache.spark.ml
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, LogisticRegressionModel}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.mllib.random.RandomDataGenerator
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
/**
* Helper for creating MLlib models which have private constructors.
@ -19,4 +24,194 @@ object ModelBuilder {
coefficients: Vector,
intercept: Double): GeneralizedLinearRegressionModel =
new GeneralizedLinearRegressionModel("glr-uid", coefficients, intercept)
}
def newDecisionTreeClassificationModel(
depth: Int,
numClasses: Int,
featureArity: Array[Int],
seed: Long): DecisionTreeClassificationModel = {
require(numClasses >= 2, s"DecisionTreeClassificationModel requires numClasses >= 2," +
s" but was given $numClasses")
val rootNode = TreeBuilder.randomBalancedDecisionTree(depth = depth, labelType = numClasses,
featureArity = featureArity, seed = seed)
new DecisionTreeClassificationModel(rootNode, numFeatures = featureArity.length,
numClasses = numClasses)
}
def newDecisionTreeRegressionModel(
depth: Int,
featureArity: Array[Int],
seed: Long): DecisionTreeRegressionModel = {
val rootNode = TreeBuilder.randomBalancedDecisionTree(depth = depth, labelType = 0,
featureArity = featureArity, seed = seed)
new DecisionTreeRegressionModel(rootNode, numFeatures = featureArity.length)
}
}
/**
* Helpers for creating random decision trees.
*/
object TreeBuilder {
/**
* Generator for a pair of distinct class labels from the set {0,...,numClasses-1}.
* Pairs are useful for trees to make sure sibling leaf nodes make different predictions.
* @param numClasses Number of classes.
*/
private class ClassLabelPairGenerator(val numClasses: Int)
extends RandomDataGenerator[Pair[Double, Double]] {
require(numClasses >= 2,
s"ClassLabelPairGenerator given label numClasses = $numClasses, but numClasses should be >= 2.")
private val rng = new java.util.Random()
override def nextValue(): Pair[Double, Double] = {
val left = rng.nextInt(numClasses)
var right = rng.nextInt(numClasses)
while (right == left) {
right = rng.nextInt(numClasses)
}
new Pair[Double, Double](left, right)
}
override def setSeed(seed: Long): Unit = {
rng.setSeed(seed)
}
override def copy(): ClassLabelPairGenerator = new ClassLabelPairGenerator(numClasses)
}
/**
* Generator for a pair of real-valued labels.
* Pairs are useful for trees to make sure sibling leaf nodes make different predictions.
*/
private class RealLabelPairGenerator() extends RandomDataGenerator[Pair[Double, Double]] {
private val rng = new java.util.Random()
override def nextValue(): Pair[Double, Double] =
new Pair[Double, Double](rng.nextDouble(), rng.nextDouble())
override def setSeed(seed: Long): Unit = {
rng.setSeed(seed)
}
override def copy(): RealLabelPairGenerator = new RealLabelPairGenerator()
}
/**
* Creates a random decision tree structure.
* @param depth Depth of tree to build. Must be <= numFeatures.
* @param labelType Value 0 indicates regression. Integers >= 2 indicate numClasses for
* classification.
* @param featureArity Array of length numFeatures indicating feature type.
* Value 0 indicates continuous feature.
* Other values >= 2 indicate a categorical feature,
* where the value is the number of categories.
* @return root node of tree
*/
def randomBalancedDecisionTree(
depth: Int,
labelType: Int,
featureArity: Array[Int],
seed: Long): Node = {
require(depth >= 0, s"randomBalancedDecisionTree given depth < 0.")
val numFeatures = featureArity.length
require(depth <= numFeatures,
s"randomBalancedDecisionTree requires depth <= featureArity.size," +
s" but depth = $depth and featureArity.size = $numFeatures")
val isRegression = labelType == 0
if (!isRegression) {
require(labelType >= 2, s"labelType must be >= 2 for classification. 0 indicates regression.")
}
val rng = new scala.util.Random()
rng.setSeed(seed)
val labelGenerator = if (isRegression) {
new RealLabelPairGenerator()
} else {
new ClassLabelPairGenerator(labelType)
}
labelGenerator.setSeed(rng.nextLong)
// We use a dummy impurityCalculator for all nodes.
val impurityCalculator = if (isRegression) {
ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0))
} else {
ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0))
}
randomBalancedDecisionTreeHelper(depth, featureArity, impurityCalculator,
labelGenerator, Set.empty, rng)
}
/**
* Create an internal node. Either create the leaf nodes beneath it, or recurse as needed.
* @param subtreeDepth Depth of subtree to build. Depth 0 means this is a leaf node.
* @param featureArity Indicates feature type. Value 0 indicates continuous feature.
* Other values >= 2 indicate a categorical feature,
* where the value is the number of categories.
* @param impurityCalculator Dummy impurity calculator to use at all tree nodes
* @param usedFeatures Features appearing in the path from the tree root to the node
* being constructed.
* @param labelGenerator Generates pairs of distinct labels.
* @return
*/
private def randomBalancedDecisionTreeHelper(
subtreeDepth: Int,
featureArity: Array[Int],
impurityCalculator: ImpurityCalculator,
labelGenerator: RandomDataGenerator[Pair[Double, Double]],
usedFeatures: Set[Int],
rng: scala.util.Random): Node = {
if (subtreeDepth == 0) {
// This case only happens for a depth 0 tree.
return new LeafNode(prediction = 0.0, impurity = 0.0, impurityStats = impurityCalculator)
}
val numFeatures = featureArity.length
// Should not happen.
assert(usedFeatures.size < numFeatures, s"randomBalancedDecisionTreeSplitNode ran out of " +
s"features for splits.")
// Make node internal.
var feature: Int = rng.nextInt(numFeatures)
while (usedFeatures.contains(feature)) {
feature = rng.nextInt(numFeatures)
}
val split: Split = if (featureArity(feature) == 0) {
// continuous feature
new ContinuousSplit(featureIndex = feature, threshold = rng.nextDouble())
} else {
// categorical feature
// Put nCatsSplit categories on left, and the rest on the right.
// nCatsSplit is in {1,...,arity-1}.
val nCatsSplit = rng.nextInt(featureArity(feature) - 1) + 1
val splitCategories: Array[Double] =
rng.shuffle(Range(0,featureArity(feature)).toList).toArray.map(_.toDouble).take(nCatsSplit)
new CategoricalSplit(featureIndex = feature,
_leftCategories = splitCategories, numCategories = featureArity(feature))
}
val (leftChild: Node, rightChild: Node) = if (subtreeDepth == 1) {
// Add leaf nodes. Assign these jointly so they make different predictions.
val predictions = labelGenerator.nextValue()
val leftChild = new LeafNode(prediction = predictions._1, impurity = 0.0,
impurityStats = impurityCalculator)
val rightChild = new LeafNode(prediction = predictions._2, impurity = 0.0,
impurityStats = impurityCalculator)
(leftChild, rightChild)
} else {
val leftChild = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity,
impurityCalculator, labelGenerator, usedFeatures + feature, rng)
val rightChild = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity,
impurityCalculator, labelGenerator, usedFeatures + feature, rng)
(leftChild, rightChild)
}
new InternalNode(prediction = 0.0, impurity = 0.0, gain = 0.0, leftChild = leftChild,
rightChild = rightChild, split = split, impurityStats = impurityCalculator)
}
}