From 33a1e55366ba9ecb601577a4e69a050c5c5d406b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 29 Jun 2016 17:06:27 -0700 Subject: [PATCH 1/5] partly done adding decision tree tests --- .../DecisionTreeClassification.scala | 45 ++++ .../databricks/spark/sql/perf/results.scala | 4 +- .../org/apache/spark/ml/ModelBuilder.scala | 198 +++++++++++++++++- 3 files changed, 244 insertions(+), 3 deletions(-) create mode 100644 src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala new file mode 100644 index 0000000..4eebd74 --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala @@ -0,0 +1,45 @@ +package com.databricks.spark.sql.perf.mllib.classification + +import org.apache.spark.ml.ModelBuilder +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.classification.DecisionTreeClassifier +import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, Evaluator} +import org.apache.spark.sql.DataFrame + +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ +import com.databricks.spark.sql.perf.mllib._ +import com.databricks.spark.sql.perf.mllib.data.DataGenerator + + +object DecisionTreeClassification extends BenchmarkAlgorithm + with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { + + override protected def initialData(ctx: MLBenchContext) = { + import ctx.params._ + DataGenerator.generateFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, + numFeatures) + } + + override protected def trueModel(ctx: MLBenchContext): Transformer = { + //val rng = ctx.newGenerator() + val numFeatures = ctx.params.numFeatures.get + val fourthFeatures = numFeatures / 4 + val featureArity: Array[Int] = + Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical + Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical + Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous + ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth.get, ctx.params.numClasses.get, + featureArity, ctx.seed()) + } + + override def train(ctx: MLBenchContext, trainingSet: DataFrame): Transformer = { + logger.info(s"$this: train: trainingSet=${trainingSet.schema}") + import ctx.params._ + new DecisionTreeClassifier() + .setMaxDepth(depth.get) + .fit(trainingSet) + } + + override protected def evaluator(ctx: MLBenchContext): Evaluator = + new MulticlassClassificationEvaluator() +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/results.scala b/src/main/scala/com/databricks/spark/sql/perf/results.scala index 1effa9c..62a2435 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/results.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/results.scala @@ -113,13 +113,15 @@ case class MLParams( numTestExamples: Option[Long] = None, numPartitions: Option[Int] = None, // *** Specialized and sorted by name *** + depth: Option[Int] = None, elasticNetParam: Option[Double] = None, family: Option[String] = None, - link: Option[String] = None, k: Option[Int] = None, ldaDocLength: Option[Int] = None, ldaNumVocabulary: Option[Int] = None, + link: Option[String] = None, maxIter: Option[Int] = None, + numClasses: Option[Int] = None, numFeatures: Option[Int] = None, optimizer: Option[String] = None, regParam: Option[Double] = None, diff --git a/src/main/scala/org/apache/spark/ml/ModelBuilder.scala b/src/main/scala/org/apache/spark/ml/ModelBuilder.scala index 089376a..0466ba6 100644 --- a/src/main/scala/org/apache/spark/ml/ModelBuilder.scala +++ b/src/main/scala/org/apache/spark/ml/ModelBuilder.scala @@ -1,8 +1,13 @@ package org.apache.spark.ml -import org.apache.spark.ml.classification.LogisticRegressionModel +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, LogisticRegressionModel} import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.GeneralizedLinearRegressionModel +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.random.RandomDataGenerator +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator + /** * Helper for creating MLlib models which have private constructors. @@ -19,4 +24,193 @@ object ModelBuilder { coefficients: Vector, intercept: Double): GeneralizedLinearRegressionModel = new GeneralizedLinearRegressionModel("glr-uid", coefficients, intercept) -} \ No newline at end of file + + def newDecisionTreeClassificationModel( + depth: Int, + numClasses: Int, + featureArity: Array[Int], + seed: Long): DecisionTreeClassificationModel = { + require(numClasses >= 2, s"DecisionTreeClassificationModel requires numClasses >= 2," + + s" but was given $numClasses") + val rootNode = TreeBuilder.randomBalancedDecisionTree(depth = depth, labelType = numClasses, + featureArity = featureArity, seed = seed) + new DecisionTreeClassificationModel(rootNode, numFeatures = featureArity.length, + numClasses = numClasses) + } + + def newDecisionTreeRegressionModel( + depth: Int, + featureArity: Array[Int], + seed: Long): DecisionTreeRegressionModel = { + val rootNode = TreeBuilder.randomBalancedDecisionTree(depth = depth, labelType = 0, + featureArity = featureArity, seed = seed) + new DecisionTreeRegressionModel(rootNode, numFeatures = featureArity.length) + } +} + +/** + * Helpers for creating random decision trees. + */ +object TreeBuilder { + + /** + * Generator for a pair of distinct class labels from the set {0,...,numClasses-1}. + * Pairs are useful for trees to make sure sibling leaf nodes make different predictions. + * @param numClasses Number of classes. + */ + private class ClassLabelPairGenerator(val numClasses: Int) + extends RandomDataGenerator[Pair[Double, Double]] { + + require(numClasses >= 2, + s"ClassLabelPairGenerator given label numClasses = $numClasses, but numClasses should be >= 2.") + + private val rng = new java.util.Random() + + override def nextValue(): Pair[Double, Double] = { + val left = rng.nextInt(numClasses) + var right = rng.nextInt(numClasses) + while (right == left) { + right = rng.nextInt(numClasses) + } + new Pair[Double, Double](left, right) + } + + override def setSeed(seed: Long) { + rng.setSeed(seed) + } + + override def copy(): ClassLabelPairGenerator = new ClassLabelPairGenerator(numClasses) + } + + + /** + * Generator for a pair of real-valued labels. + * Pairs are useful for trees to make sure sibling leaf nodes make different predictions. + */ + private class RealLabelPairGenerator() extends RandomDataGenerator[Pair[Double, Double]] { + + private val rng = new java.util.Random() + + override def nextValue(): Pair[Double, Double] = + new Pair[Double, Double](rng.nextDouble(), rng.nextDouble()) + + override def setSeed(seed: Long) { + rng.setSeed(seed) + } + + override def copy(): RealLabelPairGenerator = new RealLabelPairGenerator() + } + + /** + * Creates a random decision tree structure. + * @param depth Depth of tree to build. Must be <= numFeatures. + * @param labelType Value 0 indicates regression. Integers >= 2 indicate numClasses for + * classification. + * @param featureArity Array of length numFeatures indicating feature type. + * Value 0 indicates continuous feature. + * Other values >= 2 indicate a categorical feature, + * where the value is the number of categories. + * @return root node of tree + */ + def randomBalancedDecisionTree( + depth: Int, + labelType: Int, + featureArity: Array[Int], + seed: Long): Node = { + require(depth >= 0, s"randomBalancedDecisionTree given depth < 0.") + val numFeatures = featureArity.length + require(depth <= numFeatures, + s"randomBalancedDecisionTree requires depth <= featureArity.size," + + s" but depth = $depth and featureArity.size = $numFeatures") + val isRegression = labelType == 0 + if (!isRegression) { + require(labelType >= 2, s"labelType must be >= 2 for classification. 0 indicates regression.") + } + + val rng = new scala.util.Random() + rng.setSeed(seed) + + val labelGenerator = if (isRegression) { + new RealLabelPairGenerator() + } else { + new ClassLabelPairGenerator(labelType) + } + // We use a dummy impurityCalculator for all nodes. + val impurityCalculator = if (isRegression) { + ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0)) + } else { + ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0)) + } + + randomBalancedDecisionTreeHelper(depth, featureArity, impurityCalculator, + labelGenerator, Set.empty, rng) + } + + /** + * Create an internal node. Either create the leaf nodes beneath it, or recurse as needed. + * @param subtreeDepth Depth of subtree to build. Depth 0 means this is a leaf node. + * @param featureArity Indicates feature type. Value 0 indicates continuous feature. + * Other values >= 2 indicate a categorical feature, + * where the value is the number of categories. + * @param impurityCalculator Dummy impurity calculator to use at all tree nodes + * @param usedFeatures Features appearing in the path from the tree root to the node + * being constructed. + * @param labelGenerator Generates pairs of distinct labels. + * @return + */ + private def randomBalancedDecisionTreeHelper( + subtreeDepth: Int, + featureArity: Array[Int], + impurityCalculator: ImpurityCalculator, + labelGenerator: RandomDataGenerator[Pair[Double, Double]], + usedFeatures: Set[Int], + rng: scala.util.Random): Node = { + + if (subtreeDepth == 0) { + // This case only happens for a depth 0 tree. + return new LeafNode(prediction = 0.0, impurity = 0.0, impurityStats = impurityCalculator) + } + + val numFeatures = featureArity.length + // Should not happen. + assert(usedFeatures.size < numFeatures, s"randomBalancedDecisionTreeSplitNode ran out of " + + s"features for splits.") + + // Make node internal. + var feature: Int = rng.nextInt(numFeatures) + while (usedFeatures.contains(feature)) { + feature = rng.nextInt(numFeatures) + } + val split: Split = if (featureArity(feature) == 0) { + // continuous feature + new ContinuousSplit(featureIndex = feature, threshold = rng.nextDouble()) + } else { + // categorical feature + // Put nCatsSplit categories on left, and the rest on the right. + // nCatsSplit is in {1,...,arity-1}. + val nCatsSplit = rng.nextInt(featureArity(feature) - 1) + 1 + val splitCategories: Array[Double] = + rng.shuffle(Range(0,featureArity(feature))).map(_.toDouble).toArray.take(nCatsSplit) + new CategoricalSplit(featureIndex = feature, + _leftCategories = splitCategories, numCategories = featureArity(feature)) + } + + val (leftChild: Node, rightChild: Node) = if (subtreeDepth == 1) { + // Add leaf nodes. Assign these jointly so they make different predictions. + val predictions = labelGenerator.nextValue() + val leftChild = new LeafNode(prediction = predictions._1, impurity = 0.0, + impurityStats = impurityCalculator) + val rightChild = new LeafNode(prediction = predictions._2, impurity = 0.0, + impurityStats = impurityCalculator) + (leftChild, rightChild) + } else { + val leftChild = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity, + impurityCalculator, labelGenerator, usedFeatures + feature, rng) + val rightChild = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity, + impurityCalculator, labelGenerator, usedFeatures + feature, rng) + (leftChild, rightChild) + } + new InternalNode(prediction = 0.0, impurity = 0.0, gain = 0.0, leftChild = leftChild, + rightChild = rightChild, split = split, impurityStats = impurityCalculator) + } +} From ecf2eedbb8fe6a7b4fda4b971adb3d84cccee335 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 30 Jun 2016 10:38:24 -0700 Subject: [PATCH 2/5] Added decision tree, forest, GBT tests --- .../sql/perf/mllib/BenchmarkAlgorithm.scala | 12 +-- .../mllib/MLTransformerBenchmarkable.scala | 11 ++- .../DecisionTreeClassification.scala | 49 +++++------ .../classification/GBTClassification.scala | 48 +++++++++++ .../classification/LogisticRegression.scala | 15 ++-- .../RandomForestClassification.scala | 20 +++++ .../spark/sql/perf/mllib/clustering/LDA.scala | 19 ++--- .../sql/perf/mllib/data/data_generation.scala | 83 +++++++------------ .../perf/mllib/regression/GLMRegression.scala | 14 ++-- .../org/apache/spark/ml/ModelBuilder.scala | 2 +- 10 files changed, 157 insertions(+), 116 deletions(-) create mode 100644 src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala create mode 100644 src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala index ad714e5..52415b4 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala @@ -2,8 +2,8 @@ package com.databricks.spark.sql.perf.mllib import com.typesafe.scalalogging.slf4j.Logging -import org.apache.spark.ml.Transformer -import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, Evaluator} +import org.apache.spark.ml.{Estimator, Transformer} +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -25,10 +25,10 @@ trait BenchmarkAlgorithm extends Logging { def testDataSet(ctx: MLBenchContext): DataFrame - @throws[Exception]("if training fails") - def train( - ctx: MLBenchContext, - trainingSet: DataFrame): Transformer + /** + * Create an [[Estimator]] with params set from the given [[MLBenchContext]]. + */ + def getEstimator(ctx: MLBenchContext): Estimator[_] /** * The unnormalized score of the training procedure on a dataset. The normalization is diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala index 8794c07..61886fe 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala @@ -2,10 +2,12 @@ package com.databricks.spark.sql.perf.mllib import com.databricks.spark.sql.perf._ import com.typesafe.scalalogging.slf4j.Logging -import org.apache.spark.sql._ +import org.apache.spark.sql._ import scala.collection.mutable.ArrayBuffer +import org.apache.spark.ml.Transformer + class MLTransformerBenchmarkable( params: MLParams, test: BenchmarkAlgorithm, @@ -44,7 +46,12 @@ class MLTransformerBenchmarkable( description: String, messages: ArrayBuffer[String]): BenchmarkResult = { try { - val (trainingTime, model) = measureTime(test.train(param, trainingData)) + val (trainingTime, model: Transformer) = measureTime { + logger.info(s"$this: train: trainingSet=${trainingData.schema}") + val estimator = test.getEstimator(param) + estimator.fit(trainingData) + //test.train(param, trainingData) + } logger.info(s"model: $model") val (_, scoreTraining) = measureTime { test.score(param, trainingData, model) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala index 4eebd74..2543284 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala @@ -1,45 +1,46 @@ package com.databricks.spark.sql.perf.mllib.classification -import org.apache.spark.ml.ModelBuilder -import org.apache.spark.ml.Transformer +import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} import org.apache.spark.ml.classification.DecisionTreeClassifier -import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, Evaluator} -import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator} import com.databricks.spark.sql.perf.mllib.OptionImplicits._ import com.databricks.spark.sql.perf.mllib._ import com.databricks.spark.sql.perf.mllib.data.DataGenerator -object DecisionTreeClassification extends BenchmarkAlgorithm +abstract class TreeOrForestClassification extends BenchmarkAlgorithm with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { + def getFeatureArity(ctx: MLBenchContext): Array[Int] = { + val numFeatures = ctx.params.numFeatures + val fourthFeatures = numFeatures / 4 + Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical + Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical + Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous + } + override protected def initialData(ctx: MLBenchContext) = { import ctx.params._ - DataGenerator.generateFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, - numFeatures) + DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, + getFeatureArity(ctx)) } override protected def trueModel(ctx: MLBenchContext): Transformer = { - //val rng = ctx.newGenerator() - val numFeatures = ctx.params.numFeatures.get - val fourthFeatures = numFeatures / 4 - val featureArity: Array[Int] = - Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical - Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical - Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous - ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth.get, ctx.params.numClasses.get, - featureArity, ctx.seed()) - } - - override def train(ctx: MLBenchContext, trainingSet: DataFrame): Transformer = { - logger.info(s"$this: train: trainingSet=${trainingSet.schema}") - import ctx.params._ - new DecisionTreeClassifier() - .setMaxDepth(depth.get) - .fit(trainingSet) + ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses, + getFeatureArity(ctx), ctx.seed()) } override protected def evaluator(ctx: MLBenchContext): Evaluator = new MulticlassClassificationEvaluator() } + +object DecisionTreeClassification extends TreeOrForestClassification { + + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { + import ctx.params._ + new DecisionTreeClassifier() + .setMaxDepth(depth) + .setSeed(ctx.seed()) + } +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala new file mode 100644 index 0000000..3e62587 --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala @@ -0,0 +1,48 @@ +package com.databricks.spark.sql.perf.mllib.classification + +import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} +import org.apache.spark.ml.classification.GBTClassifier +import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator} + +import com.databricks.spark.sql.perf.mllib._ +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ +import com.databricks.spark.sql.perf.mllib.data.DataGenerator + + +object GBTClassification extends BenchmarkAlgorithm + with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { + + def getFeatureArity(ctx: MLBenchContext): Array[Int] = { + val numFeatures = ctx.params.numFeatures + val fourthFeatures = numFeatures / 4 + Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical + Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical + Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous + } + + override protected def initialData(ctx: MLBenchContext) = { + import ctx.params._ + DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, + getFeatureArity(ctx)) + } + + override protected def trueModel(ctx: MLBenchContext): Transformer = { + // We add +1 to the depth to make it more likely that many iterations of boosting are needed + // to model the true tree. + ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth + 1, ctx.params.numClasses, + getFeatureArity(ctx), ctx.seed()) + } + + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { + import ctx.params._ + // TODO: subsamplingRate, featureSubsetStrategy + // TODO: cacheNodeIds, checkpoint? + new GBTClassifier() + .setMaxDepth(depth) + .setMaxIter(maxIter) + .setSeed(ctx.seed()) + } + + override protected def evaluator(ctx: MLBenchContext): Evaluator = + new MulticlassClassificationEvaluator() +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala index 284905d..edb3a68 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/LogisticRegression.scala @@ -3,19 +3,19 @@ package com.databricks.spark.sql.perf.mllib.classification import com.databricks.spark.sql.perf.mllib.OptionImplicits._ import com.databricks.spark.sql.perf.mllib._ import com.databricks.spark.sql.perf.mllib.data.DataGenerator -import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, Evaluator} -import org.apache.spark.ml.{Transformer, ModelBuilder} +import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator} +import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} import org.apache.spark.ml import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.sql.DataFrame + object LogisticRegression extends BenchmarkAlgorithm with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { override protected def initialData(ctx: MLBenchContext) = { import ctx.params._ - DataGenerator.generateFeatures( + DataGenerator.generateContinuousFeatures( ctx.sqlContext, numExamples, ctx.seed(), @@ -32,15 +32,12 @@ object LogisticRegression extends BenchmarkAlgorithm ModelBuilder.newLogisticRegressionModel(coefficients, intercept) } - override def train(ctx: MLBenchContext, - trainingSet: DataFrame): Transformer = { - logger.info(s"$this: train: trainingSet=${trainingSet.schema}") + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { import ctx.params._ - val lr = new ml.classification.LogisticRegression() + new ml.classification.LogisticRegression() .setTol(tol) .setMaxIter(maxIter) .setRegParam(regParam) - lr.fit(trainingSet) } override protected def evaluator(ctx: MLBenchContext): Evaluator = diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala new file mode 100644 index 0000000..b4352fe --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala @@ -0,0 +1,20 @@ +package com.databricks.spark.sql.perf.mllib.classification + +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.classification.RandomForestClassifier + +import com.databricks.spark.sql.perf.mllib._ + + +object RandomForestClassification extends TreeOrForestClassification { + + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { + import ctx.params._ + // TODO: subsamplingRate, featureSubsetStrategy + // TODO: cacheNodeIds, checkpoint? + new RandomForestClassifier() + .setMaxDepth(depth.get) + .setNumTrees(maxIter.get) + .setSeed(ctx.seed()) + } +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala index 3fe5688..dbccf3f 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala @@ -1,13 +1,14 @@ package com.databricks.spark.sql.perf.mllib.clustering -import com.databricks.spark.sql.perf.mllib.{MLBenchContext, TestFromTraining, BenchmarkAlgorithm} +import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining} import com.databricks.spark.sql.perf.mllib.OptionImplicits._ import org.apache.commons.math3.random.Well19937c -import org.apache.spark.ml.Transformer + +import org.apache.spark.ml.Estimator import org.apache.spark.ml import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.ml.linalg.{Vectors, Vector} +import org.apache.spark.ml.linalg.{Vector, Vectors} import scala.collection.mutable.{HashMap => MHashMap} object LDA extends BenchmarkAlgorithm with TestFromTraining { @@ -40,15 +41,13 @@ object LDA extends BenchmarkAlgorithm with TestFromTraining { ctx.sqlContext.createDataFrame(data).toDF("docIndex", "features") } - override def train(ctx: MLBenchContext, - trainingSet: DataFrame): Transformer = { + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { import ctx.params._ new ml.clustering.LDA() - .setK(k) - .setSeed(randomSeed.toLong) - .setMaxIter(maxIter) - .setOptimizer(optimizer) - .fit(trainingSet) + .setK(k) + .setSeed(randomSeed.toLong) + .setMaxIter(maxIter) + .setOptimizer(optimizer) } // TODO(?) add a scoring method here. diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/data_generation.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/data_generation.scala index ca36704..d34a828 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/data_generation.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/data_generation.scala @@ -3,82 +3,65 @@ package com.databricks.spark.sql.perf.mllib.data import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.linalg.Vector import org.apache.spark.mllib.random._ -import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.sql.{SQLContext, DataFrame} + object DataGenerator { - def generateFeatures( + def generateContinuousFeatures( sql: SQLContext, numExamples: Long, seed: Long, numPartitions: Int, numFeatures: Int): DataFrame = { - val categoricalArities = Array.empty[Int] + val featureArity = Array.fill[Int](numFeatures)(0) val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext, - new FeaturesGenerator(categoricalArities, numFeatures), - numExamples, numPartitions, seed) + new FeaturesGenerator(featureArity), numExamples, numPartitions, seed) sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features") } -} -class BinaryLabeledDataGenerator( - private val numFeatures: Int, - private val threshold: Double) extends RandomDataGenerator[LabeledPoint] { - - private val rng = new java.util.Random() - - override def nextValue(): LabeledPoint = { - val y = if (rng.nextDouble() < threshold) 0.0 else 1.0 - val x = Array.fill[Double](numFeatures) { - if (rng.nextDouble() < threshold) 0.0 else 1.0 - } - ??? -// LabeledPoint(y, Vectors.dense(x)) + def generateMixedFeatures( + sql: SQLContext, + numExamples: Long, + seed: Long, + numPartitions: Int, + featureArity: Array[Int]): DataFrame = { + val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext, + new FeaturesGenerator(featureArity), numExamples, numPartitions, seed) + sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features") } - - override def setSeed(seed: Long) { - rng.setSeed(seed) - } - - override def copy(): BinaryLabeledDataGenerator = - new BinaryLabeledDataGenerator(numFeatures, threshold) - } /** * Generator for a feature vector which can include a mix of categorical and continuous features. - * @param categoricalArities Specifies the number of categories for each categorical feature. - * @param numContinuous Number of continuous features. Feature values are in range [0,1]. + * @param featureArity Length numFeatures, where 0 indicates continuous feature and > 0 + * indicates a categorical feature of that arity. */ -class FeaturesGenerator(val categoricalArities: Array[Int], val numContinuous: Int) +class FeaturesGenerator(val featureArity: Array[Int]) extends RandomDataGenerator[Vector] { - categoricalArities.foreach { arity => - require(arity >= 2, s"FeaturesGenerator given categorical arity = $arity, " + - s"but arity should be >= 2.") + featureArity.foreach { arity => + require(arity >= 0, s"FeaturesGenerator given categorical arity = $arity, " + + s"but arity should be >= 0.") } - val numFeatures = categoricalArities.length + numContinuous + val numFeatures = featureArity.length private val rng = new java.util.Random() /** - * Generates vector with categorical features first, and continuous features in [0,1] second. + * Generates vector with features in the order given by [[featureArity]] */ override def nextValue(): Vector = { - // Feature ordering matches getCategoricalFeaturesInfo. val arr = new Array[Double](numFeatures) var j = 0 - while (j < categoricalArities.length) { - arr(j) = rng.nextInt(categoricalArities(j)) - j += 1 - } - while (j < numFeatures) { - // Generating some centered data - arr(j) = 2 * rng.nextDouble() - 1 + while (j < featureArity.length) { + if (featureArity(j) == 0) + arr(j) = 2 * rng.nextDouble() - 1 // centered uniform data + else + arr(j) = rng.nextInt(featureArity(j)) j += 1 } Vectors.dense(arr) @@ -88,15 +71,5 @@ class FeaturesGenerator(val categoricalArities: Array[Int], val numContinuous: I rng.setSeed(seed) } - override def copy(): FeaturesGenerator = new FeaturesGenerator(categoricalArities, numContinuous) - - /** - * @return categoricalFeaturesInfo Map storing arity of categorical features. - * E.g., an entry (n -> k) indicates that feature n is categorical - * with k categories indexed from 0: {0, 1, ..., k-1}. - */ - def getCategoricalFeaturesInfo: Map[Int, Int] = { - // Categorical features are indexed from 0 because of the implementation of nextValue(). - categoricalArities.zipWithIndex.map(_.swap).toMap - } -} \ No newline at end of file + override def copy(): FeaturesGenerator = new FeaturesGenerator(featureArity) +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala index 6797e50..36d0e36 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala @@ -3,11 +3,11 @@ package com.databricks.spark.sql.perf.mllib.regression import com.databricks.spark.sql.perf.mllib.OptionImplicits._ import com.databricks.spark.sql.perf.mllib._ import com.databricks.spark.sql.perf.mllib.data.DataGenerator + import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator} import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.regression.GeneralizedLinearRegression -import org.apache.spark.ml.{ModelBuilder, Transformer} -import org.apache.spark.sql._ +import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with @@ -15,7 +15,7 @@ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with override protected def initialData(ctx: MLBenchContext) = { import ctx.params._ - DataGenerator.generateFeatures( + DataGenerator.generateContinuousFeatures( ctx.sqlContext, numExamples, ctx.seed(), @@ -36,18 +36,14 @@ object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with m } - override def train( - ctx: MLBenchContext, - trainingSet: DataFrame): Transformer = { - logger.info(s"$this: train: trainingSet=${trainingSet.schema}") + override def getEstimator(ctx: MLBenchContext): Estimator[_] = { import ctx.params._ - val glr = new GeneralizedLinearRegression() + new GeneralizedLinearRegression() .setLink(link) .setFamily(family) .setRegParam(regParam) .setMaxIter(maxIter) .setTol(tol) - glr.fit(trainingSet) } override protected def evaluator(ctx: MLBenchContext): Evaluator = diff --git a/src/main/scala/org/apache/spark/ml/ModelBuilder.scala b/src/main/scala/org/apache/spark/ml/ModelBuilder.scala index 0466ba6..bbae1ff 100644 --- a/src/main/scala/org/apache/spark/ml/ModelBuilder.scala +++ b/src/main/scala/org/apache/spark/ml/ModelBuilder.scala @@ -190,7 +190,7 @@ object TreeBuilder { // nCatsSplit is in {1,...,arity-1}. val nCatsSplit = rng.nextInt(featureArity(feature) - 1) + 1 val splitCategories: Array[Double] = - rng.shuffle(Range(0,featureArity(feature))).map(_.toDouble).toArray.take(nCatsSplit) + rng.shuffle(Range(0,featureArity(feature)).toList).toArray.map(_.toDouble).take(nCatsSplit) new CategoricalSplit(featureIndex = feature, _leftCategories = splitCategories, numCategories = featureArity(feature)) } From c15d083fe7e7172ac2fb1e2d5b6b7e6ed3be877b Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 30 Jun 2016 10:45:15 -0700 Subject: [PATCH 3/5] cleanups --- .../sql/perf/mllib/MLTransformerBenchmarkable.scala | 1 - .../spark/sql/perf/mllib/clustering/LDA.scala | 11 +++++++---- .../{data_generation.scala => dataGeneration.scala} | 5 +++++ .../sql/perf/mllib/regression/GLMRegression.scala | 8 ++++---- 4 files changed, 16 insertions(+), 9 deletions(-) rename src/main/scala/com/databricks/spark/sql/perf/mllib/data/{data_generation.scala => dataGeneration.scala} (89%) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala index 61886fe..57c051b 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLTransformerBenchmarkable.scala @@ -50,7 +50,6 @@ class MLTransformerBenchmarkable( logger.info(s"$this: train: trainingSet=${trainingData.schema}") val estimator = test.getEstimator(param) estimator.fit(trainingData) - //test.train(param, trainingData) } logger.info(s"model: $model") val (_, scoreTraining) = measureTime { diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala index dbccf3f..a6daf4b 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/clustering/LDA.scala @@ -1,7 +1,7 @@ package com.databricks.spark.sql.perf.mllib.clustering -import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining} -import com.databricks.spark.sql.perf.mllib.OptionImplicits._ +import scala.collection.mutable.{HashMap => MHashMap} + import org.apache.commons.math3.random.Well19937c import org.apache.spark.ml.Estimator @@ -9,7 +9,10 @@ import org.apache.spark.ml import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.ml.linalg.{Vector, Vectors} -import scala.collection.mutable.{HashMap => MHashMap} + +import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining} +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ + object LDA extends BenchmarkAlgorithm with TestFromTraining { // The LDA model is package private, no need to expose it. @@ -51,4 +54,4 @@ object LDA extends BenchmarkAlgorithm with TestFromTraining { } // TODO(?) add a scoring method here. -} \ No newline at end of file +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/data_generation.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/dataGeneration.scala similarity index 89% rename from src/main/scala/com/databricks/spark/sql/perf/mllib/data/data_generation.scala rename to src/main/scala/com/databricks/spark/sql/perf/mllib/data/dataGeneration.scala index d34a828..2461fbc 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/data/data_generation.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/data/dataGeneration.scala @@ -21,6 +21,11 @@ object DataGenerator { sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features") } + /** + * Generate a mix of continuous and categorical features. + * @param featureArity Array of length numFeatures, where 0 indicates a continuous feature and + * a value > 0 indicates a categorical feature with that arity. + */ def generateMixedFeatures( sql: SQLContext, numExamples: Long, diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala index 36d0e36..aea75fd 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GLMRegression.scala @@ -1,14 +1,14 @@ package com.databricks.spark.sql.perf.mllib.regression -import com.databricks.spark.sql.perf.mllib.OptionImplicits._ -import com.databricks.spark.sql.perf.mllib._ -import com.databricks.spark.sql.perf.mllib.data.DataGenerator - import org.apache.spark.ml.evaluation.{Evaluator, RegressionEvaluator} import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.regression.GeneralizedLinearRegression import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ +import com.databricks.spark.sql.perf.mllib._ +import com.databricks.spark.sql.perf.mllib.data.DataGenerator + object GLMRegression extends BenchmarkAlgorithm with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { From 813bd8ad59e17d65a33ed4d37612dcd7626325eb Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Fri, 1 Jul 2016 10:34:42 -0700 Subject: [PATCH 4/5] adding more experiments --- src/main/scala/configs/mllib-small.yaml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/main/scala/configs/mllib-small.yaml b/src/main/scala/configs/mllib-small.yaml index 5d4fac7..e600554 100644 --- a/src/main/scala/configs/mllib-small.yaml +++ b/src/main/scala/configs/mllib-small.yaml @@ -33,3 +33,25 @@ benchmarks: tol: 0.0 maxIter: 10 regParam: 0.1 + - name: classification.DecisionTreeClassification + params: + numExamples: 100 + numTestExamples: 10 + depth: 3 + numClasses: 4 + numFeatures: 5 + - name: classification.RandomForestClassification + params: + numExamples: 100 + numTestExamples: 10 + depth: 3 + numClasses: 4 + numFeatures: 5 + - name: classification.GBTClassification + params: + numExamples: 100 + numTestExamples: 10 + depth: 3 + numClasses: 4 + numFeatures: 5 + maxIter: 3 From 495e2716c42ef07afb643ea1bd1be2ea76b38ef6 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 1 Jul 2016 17:39:28 -0700 Subject: [PATCH 5/5] updated per code review. works in local tests --- .../DecisionTreeClassification.scala | 32 +++++++++++++------ .../classification/GBTClassification.scala | 12 ++----- .../RandomForestClassification.scala | 5 +-- src/main/scala/configs/mllib-small.yaml | 1 + .../org/apache/spark/ml/ModelBuilder.scala | 5 +-- 5 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala index 2543284..7ba333a 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala @@ -12,23 +12,15 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator abstract class TreeOrForestClassification extends BenchmarkAlgorithm with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { - def getFeatureArity(ctx: MLBenchContext): Array[Int] = { - val numFeatures = ctx.params.numFeatures - val fourthFeatures = numFeatures / 4 - Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical - Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical - Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous - } - override protected def initialData(ctx: MLBenchContext) = { import ctx.params._ DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, - getFeatureArity(ctx)) + TreeOrForestClassification.getFeatureArity(ctx)) } override protected def trueModel(ctx: MLBenchContext): Transformer = { ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses, - getFeatureArity(ctx), ctx.seed()) + TreeOrForestClassification.getFeatureArity(ctx), ctx.seed()) } override protected def evaluator(ctx: MLBenchContext): Evaluator = @@ -44,3 +36,23 @@ object DecisionTreeClassification extends TreeOrForestClassification { .setSeed(ctx.seed()) } } + +object TreeOrForestClassification { + + /** + * Get feature arity for tree and tree ensemble tests. + * Currently, this is hard-coded as: + * - 1/2 binary features + * - 1/2 high-arity (20-category) features + * - 1/2 continuous features + * @return Array of length numFeatures, where 0 indicates continuous feature and + * value > 0 indicates a categorical feature of that arity. + */ + def getFeatureArity(ctx: MLBenchContext): Array[Int] = { + val numFeatures = ctx.params.numFeatures + val fourthFeatures = numFeatures / 4 + Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical + Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical + Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous + } +} diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala index 3e62587..acdb105 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala @@ -12,25 +12,17 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator object GBTClassification extends BenchmarkAlgorithm with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { - def getFeatureArity(ctx: MLBenchContext): Array[Int] = { - val numFeatures = ctx.params.numFeatures - val fourthFeatures = numFeatures / 4 - Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical - Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical - Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous - } - override protected def initialData(ctx: MLBenchContext) = { import ctx.params._ DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, - getFeatureArity(ctx)) + TreeOrForestClassification.getFeatureArity(ctx)) } override protected def trueModel(ctx: MLBenchContext): Transformer = { // We add +1 to the depth to make it more likely that many iterations of boosting are needed // to model the true tree. ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth + 1, ctx.params.numClasses, - getFeatureArity(ctx), ctx.seed()) + TreeOrForestClassification.getFeatureArity(ctx), ctx.seed()) } override def getEstimator(ctx: MLBenchContext): Estimator[_] = { diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala index b4352fe..3aff023 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/RandomForestClassification.scala @@ -4,6 +4,7 @@ import org.apache.spark.ml.Estimator import org.apache.spark.ml.classification.RandomForestClassifier import com.databricks.spark.sql.perf.mllib._ +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ object RandomForestClassification extends TreeOrForestClassification { @@ -13,8 +14,8 @@ object RandomForestClassification extends TreeOrForestClassification { // TODO: subsamplingRate, featureSubsetStrategy // TODO: cacheNodeIds, checkpoint? new RandomForestClassifier() - .setMaxDepth(depth.get) - .setNumTrees(maxIter.get) + .setMaxDepth(depth) + .setNumTrees(maxIter) .setSeed(ctx.seed()) } } diff --git a/src/main/scala/configs/mllib-small.yaml b/src/main/scala/configs/mllib-small.yaml index e600554..2fb1e6a 100644 --- a/src/main/scala/configs/mllib-small.yaml +++ b/src/main/scala/configs/mllib-small.yaml @@ -47,6 +47,7 @@ benchmarks: depth: 3 numClasses: 4 numFeatures: 5 + maxIter: 3 - name: classification.GBTClassification params: numExamples: 100 diff --git a/src/main/scala/org/apache/spark/ml/ModelBuilder.scala b/src/main/scala/org/apache/spark/ml/ModelBuilder.scala index bbae1ff..7d0143c 100644 --- a/src/main/scala/org/apache/spark/ml/ModelBuilder.scala +++ b/src/main/scala/org/apache/spark/ml/ModelBuilder.scala @@ -75,7 +75,7 @@ object TreeBuilder { new Pair[Double, Double](left, right) } - override def setSeed(seed: Long) { + override def setSeed(seed: Long): Unit = { rng.setSeed(seed) } @@ -94,7 +94,7 @@ object TreeBuilder { override def nextValue(): Pair[Double, Double] = new Pair[Double, Double](rng.nextDouble(), rng.nextDouble()) - override def setSeed(seed: Long) { + override def setSeed(seed: Long): Unit = { rng.setSeed(seed) } @@ -135,6 +135,7 @@ object TreeBuilder { } else { new ClassLabelPairGenerator(labelType) } + labelGenerator.setSeed(rng.nextLong) // We use a dummy impurityCalculator for all nodes. val impurityCalculator = if (isRegression) { ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0))