diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala index 7ba333a..45ce7f8 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala @@ -1,8 +1,9 @@ package com.databricks.spark.sql.perf.mllib.classification -import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} +import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer, TreeUtils} import org.apache.spark.ml.classification.DecisionTreeClassifier import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator} +import org.apache.spark.sql.DataFrame import com.databricks.spark.sql.perf.mllib.OptionImplicits._ import com.databricks.spark.sql.perf.mllib._ @@ -12,15 +13,19 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator abstract class TreeOrForestClassification extends BenchmarkAlgorithm with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { + import TreeOrForestClassification.getFeatureArity + override protected def initialData(ctx: MLBenchContext) = { import ctx.params._ - DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, - TreeOrForestClassification.getFeatureArity(ctx)) + val featureArity: Array[Int] = getFeatureArity(ctx) + val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, + ctx.seed(), numPartitions, featureArity) + TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity) } override protected def trueModel(ctx: MLBenchContext): Transformer = { ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses, - TreeOrForestClassification.getFeatureArity(ctx), ctx.seed()) + getFeatureArity(ctx), 
ctx.seed()) } override protected def evaluator(ctx: MLBenchContext): Evaluator = @@ -42,11 +47,12 @@ object TreeOrForestClassification { /** * Get feature arity for tree and tree ensemble tests. * Currently, this is hard-coded as: - * - 1/2 binary features - * - 1/2 high-arity (20-category) features - * - 1/2 continuous features - * @return Array of length numFeatures, where 0 indicates continuous feature and - * value > 0 indicates a categorical feature of that arity. + * - 1/2 binary features + * - 1/2 high-arity (20-category) features + * - 1/2 continuous features + * + * @return Array of length numFeatures, where 0 indicates continuous feature and + * value > 0 indicates a categorical feature of that arity. */ def getFeatureArity(ctx: MLBenchContext): Array[Int] = { val numFeatures = ctx.params.numFeatures diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala index acdb105..dfd172d 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala @@ -1,8 +1,9 @@ package com.databricks.spark.sql.perf.mllib.classification -import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer} +import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer, TreeUtils} import org.apache.spark.ml.classification.GBTClassifier import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator} +import org.apache.spark.sql._ import com.databricks.spark.sql.perf.mllib._ import com.databricks.spark.sql.perf.mllib.OptionImplicits._ @@ -12,17 +13,22 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator object GBTClassification extends BenchmarkAlgorithm with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator { + import 
TreeOrForestClassification.getFeatureArity + override protected def initialData(ctx: MLBenchContext) = { import ctx.params._ - DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, - TreeOrForestClassification.getFeatureArity(ctx)) + val featureArity: Array[Int] = getFeatureArity(ctx) + val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, + ctx.seed(), numPartitions, featureArity) + TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity) } override protected def trueModel(ctx: MLBenchContext): Transformer = { + import ctx.params._ // We add +1 to the depth to make it more likely that many iterations of boosting are needed // to model the true tree. - ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth + 1, ctx.params.numClasses, - TreeOrForestClassification.getFeatureArity(ctx), ctx.seed()) + ModelBuilder.newDecisionTreeClassificationModel(depth + 1, numClasses, getFeatureArity(ctx), + ctx.seed()) } override def getEstimator(ctx: MLBenchContext): Estimator[_] = { diff --git a/src/main/scala/org/apache/spark/ml/TreeUtils.scala b/src/main/scala/org/apache/spark/ml/TreeUtils.scala new file mode 100644 index 0000000..1bd3c12 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/TreeUtils.scala @@ -0,0 +1,42 @@ +package org.apache.spark.ml + +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.sql.DataFrame + +object TreeUtils { + /** + * Set label metadata (particularly the number of classes) on a DataFrame. + * + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param labelColName Name of the label column on which to set the metadata. + * @param numClasses Number of classes label can take. If 0, mark as continuous. 
+ * @param featuresColName Name of the features column + * @param featureArity Array of length numFeatures, where 0 indicates continuous feature and + * value > 0 indicates a categorical feature of that arity. + * @return DataFrame with metadata + */ + def setMetadata( + data: DataFrame, + labelColName: String, + numClasses: Int, + featuresColName: String, + featureArity: Array[Int]): DataFrame = { + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName(labelColName) + } else { + NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + val featuresAttributes = featureArity.zipWithIndex.map { case (arity: Int, feature: Int) => + if (arity > 0) { + NominalAttribute.defaultAttr.withIndex(feature).withNumValues(arity) + } else { + NumericAttribute.defaultAttr.withIndex(feature) + } + } + val featuresMetadata = new AttributeGroup(featuresColName, featuresAttributes).toMetadata() + data.select(data(featuresColName).as(featuresColName, featuresMetadata), + data(labelColName).as(labelColName, labelMetadata)) + }