Merge pull request #79 from jkbradley/tree-test-fix

Fixed tree, forest, GBT tests by adding metadata to DataFrames
This commit is contained in:
Timothy Hunter 2016-07-11 10:42:19 -07:00 committed by GitHub
commit 8830bffd46
3 changed files with 68 additions and 14 deletions

View File

@ -1,8 +1,9 @@
package com.databricks.spark.sql.perf.mllib.classification
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer, TreeUtils}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import org.apache.spark.sql.DataFrame
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib._
@ -12,15 +13,19 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator
abstract class TreeOrForestClassification extends BenchmarkAlgorithm
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
import TreeOrForestClassification.getFeatureArity
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions,
TreeOrForestClassification.getFeatureArity(ctx))
val featureArity: Array[Int] = getFeatureArity(ctx)
val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
ctx.seed(), numPartitions, featureArity)
TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity)
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses,
TreeOrForestClassification.getFeatureArity(ctx), ctx.seed())
getFeatureArity(ctx), ctx.seed())
}
override protected def evaluator(ctx: MLBenchContext): Evaluator =
@ -42,11 +47,12 @@ object TreeOrForestClassification {
/**
* Get feature arity for tree and tree ensemble tests.
* Currently, this is hard-coded as:
* - 1/2 binary features
* - 1/2 high-arity (20-category) features
* - 1/2 continuous features
* @return Array of length numFeatures, where 0 indicates continuous feature and
* value > 0 indicates a categorical feature of that arity.
* - 1/2 binary features
* - 1/2 high-arity (20-category) features
* - 1/2 continuous features
*
* @return Array of length numFeatures, where 0 indicates continuous feature and
* value > 0 indicates a categorical feature of that arity.
*/
def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
val numFeatures = ctx.params.numFeatures

View File

@ -1,8 +1,9 @@
package com.databricks.spark.sql.perf.mllib.classification
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer, TreeUtils}
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
import org.apache.spark.sql._
import com.databricks.spark.sql.perf.mllib._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
@ -12,17 +13,22 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator
object GBTClassification extends BenchmarkAlgorithm
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
import TreeOrForestClassification.getFeatureArity
override protected def initialData(ctx: MLBenchContext) = {
import ctx.params._
DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions,
TreeOrForestClassification.getFeatureArity(ctx))
val featureArity: Array[Int] = getFeatureArity(ctx)
val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
ctx.seed(), numPartitions, featureArity)
TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity)
}
override protected def trueModel(ctx: MLBenchContext): Transformer = {
import ctx.params._
// We add +1 to the depth to make it more likely that many iterations of boosting are needed
// to model the true tree.
ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth + 1, ctx.params.numClasses,
TreeOrForestClassification.getFeatureArity(ctx), ctx.seed())
ModelBuilder.newDecisionTreeClassificationModel(depth + 1, numClasses, getFeatureArity(ctx),
ctx.seed())
}
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {

View File

@ -0,0 +1,42 @@
package org.apache.spark.ml
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.sql.DataFrame
object TreeUtils {
  /**
   * Set label and feature metadata on a DataFrame so that tree-based ML algorithms
   * can recognize categorical features and the number of label classes.
   *
   * @param data Dataset. Categorical features and labels must already have 0-based indices.
   *             This must be non-empty.
   * @param labelColName Name of the label column on which to set the metadata.
   * @param numClasses Number of classes label can take. If 0, mark as continuous.
   * @param featuresColName Name of the features column
   * @param featureArity Array of length numFeatures, where 0 indicates continuous feature and
   *                     value > 0 indicates a categorical feature of that arity.
   * @return DataFrame with metadata attached. NOTE: the result contains ONLY the
   *         features and label columns; any other columns in `data` are dropped.
   */
  def setMetadata(
      data: DataFrame,
      labelColName: String,
      numClasses: Int,
      featuresColName: String,
      featureArity: Array[Int]): DataFrame = {
    // Label attribute: nominal (categorical) when numClasses > 0, numeric otherwise.
    val labelAttribute = if (numClasses == 0) {
      NumericAttribute.defaultAttr.withName(labelColName)
    } else {
      NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses)
    }
    val labelMetadata = labelAttribute.toMetadata()
    // Per-feature attributes: nominal with the given arity for categorical features,
    // numeric for continuous features (arity 0).
    val featuresAttributes = featureArity.zipWithIndex.map { case (arity: Int, feature: Int) =>
      if (arity > 0) {
        NominalAttribute.defaultAttr.withIndex(feature).withNumValues(arity)
      } else {
        NumericAttribute.defaultAttr.withIndex(feature)
      }
    }
    // Bug fix: use featuresColName rather than the hard-coded "features" so the
    // attribute group name matches the actual features column name passed by the caller.
    val featuresMetadata = new AttributeGroup(featuresColName, featuresAttributes).toMetadata()
    data.select(data(featuresColName).as(featuresColName, featuresMetadata),
      data(labelColName).as(labelColName, labelMetadata))
  }
}