diff --git a/build.sbt b/build.sbt index 609635d..332bc7c 100644 --- a/build.sbt +++ b/build.sbt @@ -14,7 +14,7 @@ sparkPackageName := "databricks/spark-sql-perf" // All Spark Packages need a license licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")) -sparkVersion := "2.0.0" +sparkVersion := "2.0.1" sparkComponents ++= Seq("sql", "hive", "mllib") diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala index a17e193..858f911 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala @@ -2,6 +2,7 @@ package com.databricks.spark.sql.perf.mllib import com.typesafe.scalalogging.slf4j.{LazyLogging => Logging} +import org.apache.spark.ml.attribute.{NominalAttribute, NumericAttribute} import org.apache.spark.ml.{Estimator, Transformer} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.sql._ @@ -76,7 +77,22 @@ trait TrainingSetFromTransformer { final override def trainingDataSet(ctx: MLBenchContext): DataFrame = { val initial = initialData(ctx) val model = trueModel(ctx) - model.transform(initial).select(col("features"), col("prediction").as("label")) + val fCol = col("features") + // Special case for the trees: we need to set the number of labels. + // If numClasses is set, attach the class count to the label column's metadata.
+ val lCol = ctx.params.numClasses match { + case Some(numClasses) => + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName("label") + } else { + NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + col("prediction").as("label", labelMetadata) + case None => + col("prediction").as("label") + } + model.transform(initial).select(fCol, lCol) } } diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala index 45ce7f8..47cf4c4 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/DecisionTreeClassification.scala @@ -20,7 +20,7 @@ abstract class TreeOrForestClassification extends BenchmarkAlgorithm val featureArity: Array[Int] = getFeatureArity(ctx) val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples, ctx.seed(), numPartitions, featureArity) - TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity) + TreeUtils.setMetadata(data, "features", featureArity) } override protected def trueModel(ctx: MLBenchContext): Transformer = { diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala index dfd172d..547a050 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/GBTClassification.scala @@ -20,7 +20,7 @@ object GBTClassification extends BenchmarkAlgorithm val featureArity: Array[Int] = getFeatureArity(ctx) val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, 
numExamples, ctx.seed(), numPartitions, featureArity) - TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity) + TreeUtils.setMetadata(data, "features", featureArity) } override protected def trueModel(ctx: MLBenchContext): Transformer = { diff --git a/src/main/scala/configs/mllib-small.yaml b/src/main/scala/configs/mllib-small.yaml index 6beeedd..3e574b4 100644 --- a/src/main/scala/configs/mllib-small.yaml +++ b/src/main/scala/configs/mllib-small.yaml @@ -60,7 +60,7 @@ benchmarks: numExamples: 100 numTestExamples: 10 depth: 3 - numClasses: 4 + numClasses: 2 numFeatures: 5 maxIter: 3 - name: regression.LinearRegression diff --git a/src/main/scala/org/apache/spark/ml/TreeUtils.scala b/src/main/scala/org/apache/spark/ml/TreeUtils.scala index 1bd3c12..badef4f 100644 --- a/src/main/scala/org/apache/spark/ml/TreeUtils.scala +++ b/src/main/scala/org/apache/spark/ml/TreeUtils.scala @@ -9,8 +9,6 @@ object TreeUtils { * * @param data Dataset. Categorical features and labels must already have 0-based indices. * This must be non-empty. - * @param labelColName Name of the label column on which to set the metadata. - * @param numClasses Number of classes label can take. If 0, mark as continuous. * @param featuresColName Name of the features column * @param featureArity Array of length numFeatures, where 0 indicates continuous feature and * value > 0 indicates a categorical feature of that arity. 
@@ -18,16 +16,8 @@ object TreeUtils { */ def setMetadata( data: DataFrame, - labelColName: String, - numClasses: Int, featuresColName: String, featureArity: Array[Int]): DataFrame = { - val labelAttribute = if (numClasses == 0) { - NumericAttribute.defaultAttr.withName(labelColName) - } else { - NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses) - } - val labelMetadata = labelAttribute.toMetadata() val featuresAttributes = featureArity.zipWithIndex.map { case (arity: Int, feature: Int) => if (arity > 0) { NominalAttribute.defaultAttr.withIndex(feature).withNumValues(arity) @@ -36,7 +26,6 @@ object TreeUtils { } } val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata() - data.select(data(featuresColName).as(featuresColName, featuresMetadata), - data(labelColName).as(labelColName, labelMetadata)) + data.select(data(featuresColName).as(featuresColName, featuresMetadata)) } }