Fixed tree, forest, GBT tests by adding metadata to DataFrames
This commit is contained in:
parent
1fcc366cec
commit
51469a34d6
@ -1,8 +1,9 @@
|
||||
package com.databricks.spark.sql.perf.mllib.classification
|
||||
|
||||
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
|
||||
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer, TreeUtils}
|
||||
import org.apache.spark.ml.classification.DecisionTreeClassifier
|
||||
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
|
||||
import org.apache.spark.sql.DataFrame
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import com.databricks.spark.sql.perf.mllib._
|
||||
@ -12,15 +13,19 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
abstract class TreeOrForestClassification extends BenchmarkAlgorithm
|
||||
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
|
||||
|
||||
import TreeOrForestClassification.getFeatureArity
|
||||
|
||||
override protected def initialData(ctx: MLBenchContext) = {
|
||||
import ctx.params._
|
||||
DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions,
|
||||
TreeOrForestClassification.getFeatureArity(ctx))
|
||||
val featureArity: Array[Int] = getFeatureArity(ctx)
|
||||
val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
|
||||
ctx.seed(), numPartitions, featureArity)
|
||||
TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity)
|
||||
}
|
||||
|
||||
override protected def trueModel(ctx: MLBenchContext): Transformer = {
|
||||
ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth, ctx.params.numClasses,
|
||||
TreeOrForestClassification.getFeatureArity(ctx), ctx.seed())
|
||||
getFeatureArity(ctx), ctx.seed())
|
||||
}
|
||||
|
||||
override protected def evaluator(ctx: MLBenchContext): Evaluator =
|
||||
@ -42,11 +47,12 @@ object TreeOrForestClassification {
|
||||
/**
|
||||
* Get feature arity for tree and tree ensemble tests.
|
||||
* Currently, this is hard-coded as:
|
||||
* - 1/2 binary features
|
||||
* - 1/2 high-arity (20-category) features
|
||||
* - 1/2 continuous features
|
||||
* @return Array of length numFeatures, where 0 indicates continuous feature and
|
||||
* value > 0 indicates a categorical feature of that arity.
|
||||
* - 1/2 binary features
|
||||
* - 1/2 high-arity (20-category) features
|
||||
* - 1/2 continuous features
|
||||
*
|
||||
* @return Array of length numFeatures, where 0 indicates continuous feature and
|
||||
* value > 0 indicates a categorical feature of that arity.
|
||||
*/
|
||||
def getFeatureArity(ctx: MLBenchContext): Array[Int] = {
|
||||
val numFeatures = ctx.params.numFeatures
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
package com.databricks.spark.sql.perf.mllib.classification
|
||||
|
||||
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer}
|
||||
import org.apache.spark.ml.{Estimator, ModelBuilder, Transformer, TreeUtils}
|
||||
import org.apache.spark.ml.classification.GBTClassifier
|
||||
import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
|
||||
import org.apache.spark.sql._
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib._
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
@ -12,17 +13,22 @@ import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
object GBTClassification extends BenchmarkAlgorithm
|
||||
with TestFromTraining with TrainingSetFromTransformer with ScoringWithEvaluator {
|
||||
|
||||
import TreeOrForestClassification.getFeatureArity
|
||||
|
||||
override protected def initialData(ctx: MLBenchContext) = {
|
||||
import ctx.params._
|
||||
DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions,
|
||||
TreeOrForestClassification.getFeatureArity(ctx))
|
||||
val featureArity: Array[Int] = getFeatureArity(ctx)
|
||||
val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
|
||||
ctx.seed(), numPartitions, featureArity)
|
||||
TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity)
|
||||
}
|
||||
|
||||
override protected def trueModel(ctx: MLBenchContext): Transformer = {
|
||||
import ctx.params._
|
||||
// We add +1 to the depth to make it more likely that many iterations of boosting are needed
|
||||
// to model the true tree.
|
||||
ModelBuilder.newDecisionTreeClassificationModel(ctx.params.depth + 1, ctx.params.numClasses,
|
||||
TreeOrForestClassification.getFeatureArity(ctx), ctx.seed())
|
||||
ModelBuilder.newDecisionTreeClassificationModel(depth + 1, numClasses, getFeatureArity(ctx),
|
||||
ctx.seed())
|
||||
}
|
||||
|
||||
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
|
||||
|
||||
42
src/main/scala/org/apache/spark/ml/TreeUtils.scala
Normal file
42
src/main/scala/org/apache/spark/ml/TreeUtils.scala
Normal file
@ -0,0 +1,42 @@
|
||||
package org.apache.spark.ml
|
||||
|
||||
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
|
||||
import org.apache.spark.sql.DataFrame
|
||||
|
||||
object TreeUtils {
  /**
   * Set ML attribute metadata on the label and features columns of a DataFrame, so that
   * tree-based algorithms can read the number of classes and per-feature arity directly
   * from the schema instead of scanning the data.
   *
   * @param data Dataset. Categorical features and labels must already have 0-based indices.
   *             This must be non-empty.
   * @param labelColName Name of the label column on which to set the metadata.
   * @param numClasses Number of classes the label can take. If 0, mark the label as continuous.
   * @param featuresColName Name of the features column.
   * @param featureArity Array of length numFeatures, where 0 indicates continuous feature and
   *                     value > 0 indicates a categorical feature of that arity.
   * @return DataFrame with metadata attached to the label and features columns.
   */
  def setMetadata(
      data: DataFrame,
      labelColName: String,
      numClasses: Int,
      featuresColName: String,
      featureArity: Array[Int]): DataFrame = {
    // numClasses == 0 is the sentinel for a continuous (regression) label.
    val labelAttribute = if (numClasses == 0) {
      NumericAttribute.defaultAttr.withName(labelColName)
    } else {
      NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses)
    }
    val labelMetadata = labelAttribute.toMetadata()
    // Per-feature attribute: nominal with the given arity, or numeric for continuous (arity 0).
    val featuresAttributes = featureArity.zipWithIndex.map { case (arity: Int, feature: Int) =>
      if (arity > 0) {
        NominalAttribute.defaultAttr.withIndex(feature).withNumValues(arity)
      } else {
        NumericAttribute.defaultAttr.withIndex(feature)
      }
    }
    // BUGFIX: name the attribute group after the actual features column rather than the
    // hard-coded literal "features", so the metadata stays consistent for any column name.
    val featuresMetadata = new AttributeGroup(featuresColName, featuresAttributes).toMetadata()
    // Re-select the two columns with their metadata attached via Column.as(alias, metadata).
    data.select(data(featuresColName).as(featuresColName, featuresMetadata),
      data(labelColName).as(labelColName, labelMetadata))
  }
}
|
||||
Loading…
Reference in New Issue
Block a user