Updated per code review; verified in local tests.

This commit is contained in:
Joseph K. Bradley 2016-07-01 17:39:28 -07:00
parent c2f0a35db4
commit 495e2716c4
5 changed files with 31 additions and 24 deletions

View File

@ -12,23 +12,15 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator
abstract class TreeOrForestClassification extends BenchmarkAlgorithm
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
val numFeatures = ctx.params.numFeatures
val fourthFeatures = numFeatures / 4
Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
}
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions,
getFeatureArity(ctx))
TreeOrForestClassification.getFeatureArity(ctx))
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses,
getFeatureArity(ctx), ctx.seed())
TreeOrForestClassification.getFeatureArity(ctx), ctx.seed())
}
override protected def evaluator(ctx: MLBenchContext): Evaluator =
@ -44,3 +36,23 @@ object DecisionTreeClassification extends TreeOrForestClassification {
.setSeed(ctx.seed())
}
}
object TreeOrForestClassification {
/**
* Get feature arity for tree and tree ensemble tests.
* Currently, this is hard-coded as:
* - 1/4 binary (low-arity categorical) features
* - 1/4 high-arity (20-category) categorical features
* - remaining ~1/2 continuous features (the remainder when numFeatures is not divisible by 4)
* @return Array of length numFeatures, where 0 indicates continuous feature and
* value > 0 indicates a categorical feature of that arity.
*/
def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
val numFeatures = ctx.params.numFeatures
// One quarter of the features for each categorical kind; integer division floors.
val fourthFeatures = numFeatures / 4
Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
}
}

View File

@ -12,25 +12,17 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator
object GBTClassification extends BenchmarkAlgorithm
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
val numFeatures = ctx.params.numFeatures
val fourthFeatures = numFeatures / 4
Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
}
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions,
getFeatureArity(ctx))
TreeOrForestClassification.getFeatureArity(ctx))
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
// We add +1 to the depth to make it more likely that many iterations of boosting are needed
// to model the true tree.
ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth + 1, ctx.params.numClasses,
getFeatureArity(ctx), ctx.seed())
TreeOrForestClassification.getFeatureArity(ctx), ctx.seed())
}
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {

View File

@ -4,6 +4,7 @@ import org.apache.spark.ml.Estimator
import org.apache.spark.ml.classification.RandomForestClassifier
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
object RandomForestClassification extends TreeOrForestClassification {
@ -13,8 +14,8 @@ object RandomForestClassification extends TreeOrForestClassification {
// TODO: subsamplingRate, featureSubsetStrategy
// TODO: cacheNodeIds, checkpoint?
new RandomForestClassifier()
.setMaxDepth(depth.get)
.setNumTrees(maxIter.get)
.setMaxDepth(depth)
.setNumTrees(maxIter)
.setSeed(ctx.seed())
}
}

View File

@ -47,6 +47,7 @@ benchmarks:
depth: 3
numClasses: 4
numFeatures: 5
maxIter: 3
- name: classification.GBTClassification
params:
numExamples: 100

View File

@ -75,7 +75,7 @@ object TreeBuilder {
new Pair[Double, Double](left, right)
}
override def setSeed(seed: Long) {
override def setSeed(seed: Long): Unit = {
rng.setSeed(seed)
}
@ -94,7 +94,7 @@ object TreeBuilder {
override def nextValue(): Pair[Double, Double] =
new Pair[Double, Double](rng.nextDouble(), rng.nextDouble())
override def setSeed(seed: Long) {
override def setSeed(seed: Long): Unit = {
rng.setSeed(seed)
}
@ -135,6 +135,7 @@ object TreeBuilder {
} else {
new ClassLabelPairGenerator(labelType)
}
labelGenerator.setSeed(rng.nextLong)
// We use a dummy impurityCalculator for all nodes.
val impurityCalculator = if (isRegression) {
ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0))