updated per code review. works in local tests
This commit is contained in:
parent
c2f0a35db4
commit
495e2716c4
@ -12,23 +12,15 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
abstract class TreeOrForestClassification extends BenchmarkAlgorithm
|
||||
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
|
||||
|
||||
def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
|
||||
val numFeatures = ctx.params.numFeatures
|
||||
val fourthFeatures = numFeatures / 4
|
||||
Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
|
||||
Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
|
||||
Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
|
||||
}
|
||||
|
||||
override protected def initialData(ctx: MLBenchContext) = {
|
||||
import ctx.params._
|
||||
DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions,
|
||||
getFeatureArity(ctx))
|
||||
TreeOrForestClassification.getFeatureArity(ctx))
|
||||
}
|
||||
|
||||
override protected def trueModel(ctx: MLBenchContext): Transformer = {
|
||||
ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses,
|
||||
getFeatureArity(ctx), ctx.seed())
|
||||
TreeOrForestClassification.getFeatureArity(ctx), ctx.seed())
|
||||
}
|
||||
|
||||
override protected def evaluator(ctx: MLBenchContext): Evaluator =
|
||||
@ -44,3 +36,23 @@ object DecisionTreeClassification extends TreeOrForestClassification {
|
||||
.setSeed(ctx.seed())
|
||||
}
|
||||
}
|
||||
|
||||
object TreeOrForestClassification {
|
||||
|
||||
/**
|
||||
* Get feature arity for tree and tree ensemble tests.
|
||||
* Currently, this is hard-coded as:
|
||||
* - 1/4 binary features (numFeatures / 4, rounded down)
|
||||
* - 1/4 high-arity (20-category) features (numFeatures / 4, rounded down)
|
||||
* - the remaining features (about 1/2) are continuous
|
||||
* @return Array of length numFeatures, where 0 indicates continuous feature and
|
||||
* value > 0 indicates a categorical feature of that arity.
|
||||
*/
|
||||
def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
|
||||
val numFeatures = ctx.params.numFeatures
|
||||
val fourthFeatures = numFeatures / 4
|
||||
Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
|
||||
Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
|
||||
Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,25 +12,17 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
object GBTClassification extends BenchmarkAlgorithm
|
||||
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
|
||||
|
||||
def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
|
||||
val numFeatures = ctx.params.numFeatures
|
||||
val fourthFeatures = numFeatures / 4
|
||||
Array.fill[Int](fourthFeatures)(2) ++ // low-arity categorical
|
||||
Array.fill[Int](fourthFeatures)(20) ++ // high-arity categorical
|
||||
Array.fill[Int](numFeatures - 2 * fourthFeatures)(0) // continuous
|
||||
}
|
||||
|
||||
override protected def initialData(ctx: MLBenchContext) = {
|
||||
import ctx.params._
|
||||
DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions,
|
||||
getFeatureArity(ctx))
|
||||
TreeOrForestClassification.getFeatureArity(ctx))
|
||||
}
|
||||
|
||||
override protected def trueModel(ctx: MLBenchContext): Transformer = {
|
||||
// We add +1 to the depth to make it more likely that many iterations of boosting are needed
|
||||
// to model the true tree.
|
||||
ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth + 1, ctx.params.numClasses,
|
||||
getFeatureArity(ctx), ctx.seed())
|
||||
TreeOrForestClassification.getFeatureArity(ctx), ctx.seed())
|
||||
}
|
||||
|
||||
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
|
||||
|
||||
@ -4,6 +4,7 @@ import org.apache.spark.ml.Estimator
|
||||
import org.apache.spark.ml.classification.RandomForestClassifier
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib._
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
|
||||
|
||||
object RandomForestClassification extends TreeOrForestClassification {
|
||||
@ -13,8 +14,8 @@ object RandomForestClassification extends TreeOrForestClassification {
|
||||
// TODO: subsamplingRate, featureSubsetStrategy
|
||||
// TODO: cacheNodeIds, checkpoint?
|
||||
new RandomForestClassifier()
|
||||
.setMaxDepth(depth.get)
|
||||
.setNumTrees(maxIter.get)
|
||||
.setMaxDepth(depth)
|
||||
.setNumTrees(maxIter)
|
||||
.setSeed(ctx.seed())
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,6 +47,7 @@ benchmarks:
|
||||
depth: 3
|
||||
numClasses: 4
|
||||
numFeatures: 5
|
||||
maxIter: 3
|
||||
- name: classification.GBTClassification
|
||||
params:
|
||||
numExamples: 100
|
||||
|
||||
@ -75,7 +75,7 @@ object TreeBuilder {
|
||||
new Pair[Double, Double](left, right)
|
||||
}
|
||||
|
||||
override def setSeed(seed: Long) {
|
||||
override def setSeed(seed: Long): Unit = {
|
||||
rng.setSeed(seed)
|
||||
}
|
||||
|
||||
@ -94,7 +94,7 @@ object TreeBuilder {
|
||||
override def nextValue(): Pair[Double, Double] =
|
||||
new Pair[Double, Double](rng.nextDouble(), rng.nextDouble())
|
||||
|
||||
override def setSeed(seed: Long) {
|
||||
override def setSeed(seed: Long): Unit = {
|
||||
rng.setSeed(seed)
|
||||
}
|
||||
|
||||
@ -135,6 +135,7 @@ object TreeBuilder {
|
||||
} else {
|
||||
new ClassLabelPairGenerator(labelType)
|
||||
}
|
||||
labelGenerator.setSeed(rng.nextLong)
|
||||
// We use a dummy impurityCalculator for all nodes.
|
||||
val impurityCalculator = if (isRegression) {
|
||||
ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user