diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala index ad714e5..52415b4 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala @@ -2,8 +2,8 @@ package com.databricks.spark.sql.perf.mllib import com.typesafe.scalalogging.slf4j.Logging -import org.apache.spark.ml.Transformer -import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, Evaluator} +import org.apache.spark.ml.{Estimator, Transformer} +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -25,10 +25,10 @@ trait BenchmarkAlgorithm extends Logging { def testDataSet(ctx: MLBenchContext): DataFrame - @throws[Exception]("if training fails") - def train( - ctx: MLBenchContext, - trainingSet: DataFrame): Transformer + /** + * Create an [[Estimator]] with params set from the given [[MLBenchContext]]. + */ + def getEstimator(ctx: MLBenchContext): Estimator[_] /** * The unnormalized score of the training procedure on a dataset. The normalization is diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala index 8794c07..57c051b 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala @@ -2,10 +2,12 @@ package com.databricks.spark.sql.perf.mllib import com.databricks.spark.sql.perf._ import com.typesafe.scalalogging.slf4j.Logging -import org.apache.spark.sql._ +import org.apache.spark.sql._ import scala.collection.mutable.ArrayBuffer +import org.apache.spark.ml.Transformer + class MLTransformerBenchmarkable( params: MLParams, test: BenchmarkAlgorithm, @@ -44,7 +46,11 @@ class MLTransformerBenchmarkable( description: String, messages: ArrayBuffer[String]): BenchmarkResult = { try { - val (trainingTime, model) = measureTime(test.train(param, trainingData)) + val (trainingTime, model: Transformer) = measureTime { + logger.info(s"$this: train: trainingSet=${trainingData.schema}") + val estimator = test.getEstimator(param) + estimator.fit(trainingData) + } logger.info(s"model: $model") val (_, scoreTraining) = measureTime { test.score(param, trainingData, model) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala new file mode 100644 index 0000000..7ba333a --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala @@ -0,0 +1,58 @@ +package com.databricks.spark.sql.perf.mllib.classification + +import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} +import org.apache.spark.ml.classification.DecisionTreeClassifier +import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator} + +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ +import com.databricks.spark.sql.perf.mllib._ +import com.databricks.spark.sql.perf.mllib.data.DataGenerator + + +abstract class TreeOrForestClassification extends BenchmarkAlgorithm + with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { + + override protected def initialData(ctx: MLBenchContext) = { + import ctx.params._ + DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, + TreeOrForestClassification.getFeatureArity(ctx)) + } + + override protected def trueModel(ctx: MLBenchContext): Transformer = { + ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses, + TreeOrForestClassification.getFeatureArity(ctx), ctx.seed()) + } + + override protected def evaluator(ctx: MLBenchContext): Evaluator = + new MulticlassClassificationEvaluator() +} + +object DecisionTreeClassification extends TreeOrForestClassification { + + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { + import ctx.params._ + new DecisionTreeClassifier() + .setMaxDepth(depth) + .setSeed(ctx.seed()) + } +} + +object TreeOrForestClassification { + + /** + * Get feature arity for tree and tree ensemble tests. + * Currently, this is hard-coded as: + * - 1/2 binary features + * - 1/2 high-arity (20-category) features + * - 1/2 continuous features + * @return Array of length numFeatures, where 0 indicates continuous feature and + * value > 0 indicates a categorical feature of that arity. + */ + def getFeatureArity(ctx: MLBenchContext): Array[Int] = { + val numFeatures = ctx.params.numFeatures + val fourthFeatures = numFeatures / 4 + Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical + Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical + Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous + } +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala new file mode 100644 index 0000000..acdb105 --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala @@ -0,0 +1,40 @@ +package com.databricks.spark.sql.perf.mllib.classification + +import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} +import org.apache.spark.ml.classification.GBTClassifier +import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator} + +import com.databricks.spark.sql.perf.mllib._ +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ +import com.databricks.spark.sql.perf.mllib.data.DataGenerator + + +object GBTClassification extends BenchmarkAlgorithm + with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { + + override protected def initialData(ctx: MLBenchContext) = { + import ctx.params._ + DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, + TreeOrForestClassification.getFeatureArity(ctx)) + } + + override protected def trueModel(ctx: MLBenchContext): Transformer = { + // We add +1 to the depth to make it more likely that many iterations of boosting are needed + // to model the true tree. + ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth + 1, ctx.params.numClasses, + TreeOrForestClassification.getFeatureArity(ctx), ctx.seed()) + } + + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { + import ctx.params._ + // TODO: subsamplingRate, featureSubsetStrategy + // TODO: cacheNodeIds, checkpoint? + new GBTClassifier() + .setMaxDepth(depth) + .setMaxIter(maxIter) + .setSeed(ctx.seed()) + } + + override protected def evaluator(ctx: MLBenchContext): Evaluator = + new MulticlassClassificationEvaluator() +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala index 284905d..edb3a68 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala @@ -3,19 +3,19 @@ package com.databricks.spark.sql.perf.mllib.classification import com.databricks.spark.sql.perf.mllib.OptionImplicits._ import com.databricks.spark.sql.perf.mllib._ import com.databricks.spark.sql.perf.mllib.data.DataGenerator -import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, Evaluator} -import org.apache.spark.ml.{Transformer, ModelBuilder} +import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator} +import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} import org.apache.spark.ml import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.sql.DataFrame + object LogisticRegression extends BenchmarkAlgorithm with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { override protected def initialData(ctx: MLBenchContext) = { import ctx.params._ - DataGenerator.generateFeatures( + DataGenerator.generateContinuousFeatures( ctx.sqlContext, numExamples, ctx.seed(), @@ -32,15 +32,12 @@ object LogisticRegression extends BenchmarkAlgorithm ModelBuilder.newLogisticRegressionModel(coefficients, intercept) } - override def train(ctx: MLBenchContext, - trainingSet: DataFrame): Transformer = { - logger.info(s"$this: train: trainingSet=${trainingSet.schema}") + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { import ctx.params._ - val lr = new ml.classification.LogisticRegression() + new ml.classification.LogisticRegression() .setTol(tol) .setMaxIter(maxIter) .setRegParam(regParam) - lr.fit(trainingSet) } override protected def evaluator(ctx: MLBenchContext): Evaluator = diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala new file mode 100644 index 0000000..3aff023 --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala @@ -0,0 +1,21 @@ +package com.databricks.spark.sql.perf.mllib.classification + +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.classification.RandomForestClassifier + +import com.databricks.spark.sql.perf.mllib._ +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ + + +object RandomForestClassification extends TreeOrForestClassification { + + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { + import ctx.params._ + // TODO: subsamplingRate, featureSubsetStrategy + // TODO: cacheNodeIds, checkpoint? + new RandomForestClassifier() + .setMaxDepth(depth) + .setNumTrees(maxIter) + .setSeed(ctx.seed()) + } +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala index 3fe5688..a6daf4b 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala @@ -1,14 +1,18 @@ package com.databricks.spark.sql.perf.mllib.clustering -import com.databricks.spark.sql.perf.mllib.{MLBenchContext, TestFromTraining, BenchmarkAlgorithm} -import com.databricks.spark.sql.perf.mllib.OptionImplicits._ +import scala.collection.mutable.{HashMap => MHashMap} + import org.apache.commons.math3.random.Well19937c -import org.apache.spark.ml.Transformer + +import org.apache.spark.ml.Estimator import org.apache.spark.ml import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.ml.linalg.{Vectors, Vector} -import scala.collection.mutable.{HashMap => MHashMap} +import org.apache.spark.ml.linalg.{Vector, Vectors} + +import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining} +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ + object LDA extends BenchmarkAlgorithm with TestFromTraining { // The LDA model is package private, no need to expose it. @@ -40,16 +44,14 @@ object LDA extends BenchmarkAlgorithm with TestFromTraining { ctx.sqlContext.createDataFrame(data).toDF("docIndex", "features") } - override def train(ctx: MLBenchContext, - trainingSet: DataFrame): Transformer = { + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { import ctx.params._ new ml.clustering.LDA() - .setK(k) - .setSeed(randomSeed.toLong) - .setMaxIter(maxIter) - .setOptimizer(optimizer) - .fit(trainingSet) + .setK(k) + .setSeed(randomSeed.toLong) + .setMaxIter(maxIter) + .setOptimizer(optimizer) } // TODO(?) add a scoring method here. -} \ No newline at end of file +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/dataGeneration.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/dataGeneration.scala new file mode 100644 index 0000000..2461fbc --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/dataGeneration.scala @@ -0,0 +1,80 @@ +package com.databricks.spark.sql.perf.mllib.data + +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.mllib.random._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, DataFrame} + + +object DataGenerator { + + def generateContinuousFeatures( + sql: SQLContext, + numExamples: Long, + seed: Long, + numPartitions: Int, + numFeatures: Int): DataFrame = { + val featureArity = Array.fill[Int](numFeatures)(0) + val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext, + new FeaturesGenerator(featureArity), numExamples, numPartitions, seed) + sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features") + } + + /** + * Generate a mix of continuous and categorical features. + * @param featureArity Array of length numFeatures, where 0 indicates a continuous feature and + * a value > 0 indicates a categorical feature with that arity. + */ + def generateMixedFeatures( + sql: SQLContext, + numExamples: Long, + seed: Long, + numPartitions: Int, + featureArity: Array[Int]): DataFrame = { + val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext, + new FeaturesGenerator(featureArity), numExamples, numPartitions, seed) + sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features") + } +} + + +/** + * Generator for a feature vector which can include a mix of categorical and continuous features. + * @param featureArity Length numFeatures, where 0 indicates continuous feature and > 0 + * indicates a categorical feature of that arity. + */ +class FeaturesGenerator(val featureArity: Array[Int]) + extends RandomDataGenerator[Vector] { + + featureArity.foreach { arity => + require(arity >= 0, s"FeaturesGenerator given categorical arity = $arity, " + + s"but arity should be >= 0.") + } + + val numFeatures = featureArity.length + + private val rng = new java.util.Random() + + /** + * Generates vector with features in the order given by [[featureArity]] + */ + override def nextValue(): Vector = { + val arr = new Array[Double](numFeatures) + var j = 0 + while (j < featureArity.length) { + if (featureArity(j) == 0) + arr(j) = 2 * rng.nextDouble() - 1 // centered uniform data + else + arr(j) = rng.nextInt(featureArity(j)) + j += 1 + } + Vectors.dense(arr) + } + + override def setSeed(seed: Long) { + rng.setSeed(seed) + } + + override def copy(): FeaturesGenerator = new FeaturesGenerator(featureArity) +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/data_generation.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/data_generation.scala deleted file mode 100644 index ca36704..0000000 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/data_generation.scala +++ /dev/null @@ -1,102 +0,0 @@ -package com.databricks.spark.sql.perf.mllib.data - -import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.linalg.Vector -import org.apache.spark.mllib.random._ -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, DataFrame} - -object DataGenerator { - - def generateFeatures( - sql: SQLContext, - numExamples: Long, - seed: Long, - numPartitions: Int, - numFeatures: Int): DataFrame = { - val categoricalArities = Array.empty[Int] - val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext, - new FeaturesGenerator(categoricalArities, numFeatures), - numExamples, numPartitions, seed) - sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features") - } -} - -class BinaryLabeledDataGenerator( - private val numFeatures: Int, - private val threshold: Double) extends RandomDataGenerator[LabeledPoint] { - - private val rng = new java.util.Random() - - override def nextValue(): LabeledPoint = { - val y = if (rng.nextDouble() < threshold) 0.0 else 1.0 - val x = Array.fill[Double](numFeatures) { - if (rng.nextDouble() < threshold) 0.0 else 1.0 - } - ??? -// LabeledPoint(y, Vectors.dense(x)) - } - - override def setSeed(seed: Long) { - rng.setSeed(seed) - } - - override def copy(): BinaryLabeledDataGenerator = - new BinaryLabeledDataGenerator(numFeatures, threshold) - -} - - -/** - * Generator for a feature vector which can include a mix of categorical and continuous features. - * @param categoricalArities Specifies the number of categories for each categorical feature. - * @param numContinuous Number of continuous features. Feature values are in range [0,1]. - */ -class FeaturesGenerator(val categoricalArities: Array[Int], val numContinuous: Int) - extends RandomDataGenerator[Vector] { - - categoricalArities.foreach { arity => - require(arity >= 2, s"FeaturesGenerator given categorical arity = $arity, " + - s"but arity should be >= 2.") - } - - val numFeatures = categoricalArities.length + numContinuous - - private val rng = new java.util.Random() - - /** - * Generates vector with categorical features first, and continuous features in [0,1] second. - */ - override def nextValue(): Vector = { - // Feature ordering matches getCategoricalFeaturesInfo. - val arr = new Array[Double](numFeatures) - var j = 0 - while (j < categoricalArities.length) { - arr(j) = rng.nextInt(categoricalArities(j)) - j += 1 - } - while (j < numFeatures) { - // Generating some centered data - arr(j) = 2 * rng.nextDouble() - 1 - j += 1 - } - Vectors.dense(arr) - } - - override def setSeed(seed: Long) { - rng.setSeed(seed) - } - - override def copy(): FeaturesGenerator = new FeaturesGenerator(categoricalArities, numContinuous) - - /** - * @return categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. - */ - def getCategoricalFeaturesInfo: Map[Int, Int] = { - // Categorical features are indexed from 0 because of the implementation of nextValue(). - categoricalArities.zipWithIndex.map(_.swap).toMap - } -} \ No newline at end of file diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala index 6797e50..aea75fd 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala @@ -1,13 +1,13 @@ package com.databricks.spark.sql.perf.mllib.regression +import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator} +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.regression.GeneralizedLinearRegression +import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} + import com.databricks.spark.sql.perf.mllib.OptionImplicits._ import com.databricks.spark.sql.perf.mllib._ import com.databricks.spark.sql.perf.mllib.data.DataGenerator -import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.regression.GeneralizedLinearRegression -import org.apache.spark.ml.{ModelBuilder, Transformer} -import org.apache.spark.sql._ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with @@ -15,7 +15,7 @@ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with override protected def initialData(ctx: MLBenchContext) = { import ctx.params._ - DataGenerator.generateFeatures( + DataGenerator.generateContinuousFeatures( ctx.sqlContext, numExamples, ctx.seed(), @@ -36,18 +36,14 @@ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with m } - override def train( - ctx: MLBenchContext, - trainingSet: DataFrame): Transformer = { - logger.info(s"$this: train: trainingSet=${trainingSet.schema}") + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { import ctx.params._ - val glr = new GeneralizedLinearRegression() + new GeneralizedLinearRegression() .setLink(link) .setFamily(family) .setRegParam(regParam) .setMaxIter(maxIter) .setTol(tol) - glr.fit(trainingSet) } override protected def evaluator(ctx: MLBenchContext): Evaluator = diff --git a/src/main/scala/com/databricks/spark/sql/perf/results.scala b/src/main/scala/com/databricks/spark/sql/perf/results.scala index 1effa9c..62a2435 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/results.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/results.scala @@ -113,13 +113,15 @@ case class MLParams( numTestExamples: Option[Long] = None, numPartitions: Option[Int] = None, // *** Specialized and sorted by name *** + depth: Option[Int] = None, elasticNetParam: Option[Double] = None, family: Option[String] = None, - link: Option[String] = None, k: Option[Int] = None, ldaDocLength: Option[Int] = None, ldaNumVocabulary: Option[Int] = None, + link: Option[String] = None, maxIter: Option[Int] = None, + numClasses: Option[Int] = None, numFeatures: Option[Int] = None, optimizer: Option[String] = None, regParam: Option[Double] = None, diff --git a/src/main/scala/configs/mllib-small.yaml b/src/main/scala/configs/mllib-small.yaml index 5d4fac7..2fb1e6a 100644 --- a/src/main/scala/configs/mllib-small.yaml +++ b/src/main/scala/configs/mllib-small.yaml @@ -33,3 +33,26 @@ benchmarks: tol: 0.0 maxIter: 10 regParam: 0.1 + - name: classification.DecisionTreeClassification + params: + numExamples: 100 + numTestExamples: 10 + depth: 3 + numClasses: 4 + numFeatures: 5 + - name: classification.RandomForestClassification + params: + numExamples: 100 + numTestExamples: 10 + depth: 3 + numClasses: 4 + numFeatures: 5 + maxIter: 3 + - name: classification.GBTClassification + params: + numExamples: 100 + numTestExamples: 10 + depth: 3 + numClasses: 4 + numFeatures: 5 + maxIter: 3 diff --git a/src/main/scala/org/apache/spark/ml/ModelBuilder.scala b/src/main/scala/org/apache/spark/ml/ModelBuilder.scala index 089376a..7d0143c 100644 --- a/src/main/scala/org/apache/spark/ml/ModelBuilder.scala +++ b/src/main/scala/org/apache/spark/ml/ModelBuilder.scala @@ -1,8 +1,13 @@ package org.apache.spark.ml -import org.apache.spark.ml.classification.LogisticRegressionModel +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, LogisticRegressionModel} import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.random.RandomDataGenerator +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator + /** * Helper for creating MLlib models which have private constructors. @@ -19,4 +24,194 @@ object ModelBuilder { coefficients: Vector, intercept: Double): GeneralizedLinearRegressionModel = new GeneralizedLinearRegressionModel("glr-uid", coefficients, intercept) -} \ No newline at end of file + + def newDecisionTreeClassificationModel( + depth: Int, + numClasses: Int, + featureArity: Array[Int], + seed: Long): DecisionTreeClassificationModel = { + require(numClasses >= 2, s"DecisionTreeClassificationModel requires numClasses >= 2," + + s" but was given $numClasses") + val rootNode = TreeBuilder.randomBalancedDecisionTree(depth = depth, labelType = numClasses, + featureArity = featureArity, seed = seed) + new DecisionTreeClassificationModel(rootNode, numFeatures = featureArity.length, + numClasses = numClasses) + } + + def newDecisionTreeRegressionModel( + depth: Int, + featureArity: Array[Int], + seed: Long): DecisionTreeRegressionModel = { + val rootNode = TreeBuilder.randomBalancedDecisionTree(depth = depth, labelType = 0, + featureArity = featureArity, seed = seed) + new DecisionTreeRegressionModel(rootNode, numFeatures = featureArity.length) + } +} + +/** + * Helpers for creating random decision trees. + */ +object TreeBuilder { + + /** + * Generator for a pair of distinct class labels from the set {0,...,numClasses-1}. + * Pairs are useful for trees to make sure sibling leaf nodes make different predictions. + * @param numClasses Number of classes. + */ + private class ClassLabelPairGenerator(val numClasses: Int) + extends RandomDataGenerator[Pair[Double, Double]] { + + require(numClasses >= 2, + s"ClassLabelPairGenerator given label numClasses = $numClasses, but numClasses should be >= 2.") + + private val rng = new java.util.Random() + + override def nextValue(): Pair[Double, Double] = { + val left = rng.nextInt(numClasses) + var right = rng.nextInt(numClasses) + while (right == left) { + right = rng.nextInt(numClasses) + } + new Pair[Double, Double](left, right) + } + + override def setSeed(seed: Long): Unit = { + rng.setSeed(seed) + } + + override def copy(): ClassLabelPairGenerator = new ClassLabelPairGenerator(numClasses) + } + + + /** + * Generator for a pair of real-valued labels. + * Pairs are useful for trees to make sure sibling leaf nodes make different predictions. + */ + private class RealLabelPairGenerator() extends RandomDataGenerator[Pair[Double, Double]] { + + private val rng = new java.util.Random() + + override def nextValue(): Pair[Double, Double] = + new Pair[Double, Double](rng.nextDouble(), rng.nextDouble()) + + override def setSeed(seed: Long): Unit = { + rng.setSeed(seed) + } + + override def copy(): RealLabelPairGenerator = new RealLabelPairGenerator() + } + + /** + * Creates a random decision tree structure. + * @param depth Depth of tree to build. Must be <= numFeatures. + * @param labelType Value 0 indicates regression. Integers >= 2 indicate numClasses for + * classification. + * @param featureArity Array of length numFeatures indicating feature type. + * Value 0 indicates continuous feature. + * Other values >= 2 indicate a categorical feature, + * where the value is the number of categories. + * @return root node of tree + */ + def randomBalancedDecisionTree( + depth: Int, + labelType: Int, + featureArity: Array[Int], + seed: Long): Node = { + require(depth >= 0, s"randomBalancedDecisionTree given depth < 0.") + val numFeatures = featureArity.length + require(depth <= numFeatures, + s"randomBalancedDecisionTree requires depth <= featureArity.size," + + s" but depth = $depth and featureArity.size = $numFeatures") + val isRegression = labelType == 0 + if (!isRegression) { + require(labelType >= 2, s"labelType must be >= 2 for classification. 0 indicates regression.") + } + + val rng = new scala.util.Random() + rng.setSeed(seed) + + val labelGenerator = if (isRegression) { + new RealLabelPairGenerator() + } else { + new ClassLabelPairGenerator(labelType) + } + labelGenerator.setSeed(rng.nextLong) + // We use a dummy impurityCalculator for all nodes. + val impurityCalculator = if (isRegression) { + ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0)) + } else { + ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0)) + } + + randomBalancedDecisionTreeHelper(depth, featureArity, impurityCalculator, + labelGenerator, Set.empty, rng) + } + + /** + * Create an internal node. Either create the leaf nodes beneath it, or recurse as needed. + * @param subtreeDepth Depth of subtree to build. Depth 0 means this is a leaf node. + * @param featureArity Indicates feature type. Value 0 indicates continuous feature. + * Other values >= 2 indicate a categorical feature, + * where the value is the number of categories. + * @param impurityCalculator Dummy impurity calculator to use at all tree nodes + * @param usedFeatures Features appearing in the path from the tree root to the node + * being constructed. + * @param labelGenerator Generates pairs of distinct labels. + * @return + */ + private def randomBalancedDecisionTreeHelper( + subtreeDepth: Int, + featureArity: Array[Int], + impurityCalculator: ImpurityCalculator, + labelGenerator: RandomDataGenerator[Pair[Double, Double]], + usedFeatures: Set[Int], + rng: scala.util.Random): Node = { + + if (subtreeDepth == 0) { + // This case only happens for a depth 0 tree. + return new LeafNode(prediction = 0.0, impurity = 0.0, impurityStats = impurityCalculator) + } + + val numFeatures = featureArity.length + // Should not happen. + assert(usedFeatures.size < numFeatures, s"randomBalancedDecisionTreeSplitNode ran out of " + + s"features for splits.") + + // Make node internal. + var feature: Int = rng.nextInt(numFeatures) + while (usedFeatures.contains(feature)) { + feature = rng.nextInt(numFeatures) + } + val split: Split = if (featureArity(feature) == 0) { + // continuous feature + new ContinuousSplit(featureIndex = feature, threshold = rng.nextDouble()) + } else { + // categorical feature + // Put nCatsSplit categories on left, and the rest on the right. + // nCatsSplit is in {1,...,arity-1}. + val nCatsSplit = rng.nextInt(featureArity(feature) - 1) + 1 + val splitCategories: Array[Double] = + rng.shuffle(Range(0,featureArity(feature)).toList).toArray.map(_.toDouble).take(nCatsSplit) + new CategoricalSplit(featureIndex = feature, + _leftCategories = splitCategories, numCategories = featureArity(feature)) + } + + val (leftChild: Node, rightChild: Node) = if (subtreeDepth == 1) { + // Add leaf nodes. Assign these jointly so they make different predictions. + val predictions = labelGenerator.nextValue() + val leftChild = new LeafNode(prediction = predictions._1, impurity = 0.0, + impurityStats = impurityCalculator) + val rightChild = new LeafNode(prediction = predictions._2, impurity = 0.0, + impurityStats = impurityCalculator) + (leftChild, rightChild) + } else { + val leftChild = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity, + impurityCalculator, labelGenerator, usedFeatures + feature, rng) + val rightChild = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity, + impurityCalculator, labelGenerator, usedFeatures + feature, rng) + (leftChild, rightChild) + } + new InternalNode(prediction = 0.0, impurity = 0.0, gain = 0.0, leftChild = leftChild, + rightChild = rightChild, split = split, impurityStats = impurityCalculator) + } +}