Removes labels from tree data generation (#82)
* changes * removes labels * reset scala version * adding metadata * bumping spark release
This commit is contained in:
parent
685c50d9dc
commit
53091a1935
@ -14,7 +14,7 @@ sparkPackageName := "databricks/spark-sql-perf"
|
||||
// All Spark Packages need a license
|
||||
licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0"))
|
||||
|
||||
sparkVersion := "2.0.0"
|
||||
sparkVersion := "2.0.1"
|
||||
|
||||
sparkComponents ++= Seq("sql", "hive", "mllib")
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ package com.databricks.spark.sql.perf.mllib
|
||||
|
||||
import com.typesafe.scalalogging.slf4j.{LazyLogging => Logging}
|
||||
|
||||
import org.apache.spark.ml.attribute.{NominalAttribute, NumericAttribute}
|
||||
import org.apache.spark.ml.{Estimator, Transformer}
|
||||
import org.apache.spark.ml.evaluation.Evaluator
|
||||
import org.apache.spark.sql._
|
||||
@ -76,7 +77,22 @@ trait TrainingSetFromTransformer {
|
||||
final override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
|
||||
val initial = initialData(ctx)
|
||||
val model = trueModel(ctx)
|
||||
model.transform(initial).select(col("features"), col("prediction").as("label"))
|
||||
val fCol = col("features")
|
||||
// Special case for tree-based algorithms: the label column must carry the number of label classes.
|
||||
// If numClasses is set, attach label metadata: nominal with numClasses values, or numeric when numClasses is 0. Otherwise the label column gets no metadata.
|
||||
val lCol = ctx.params.numClasses match {
|
||||
case Some(numClasses) =>
|
||||
val labelAttribute = if (numClasses == 0) {
|
||||
NumericAttribute.defaultAttr.withName("label")
|
||||
} else {
|
||||
NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
|
||||
}
|
||||
val labelMetadata = labelAttribute.toMetadata()
|
||||
col("prediction").as("label", labelMetadata)
|
||||
case None =>
|
||||
col("prediction").as("label")
|
||||
}
|
||||
model.transform(initial).select(fCol, lCol)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ abstract class TreeOrForestClassification extends BenchmarkAlgorithm
|
||||
val featureArity: Array[Int] = getFeatureArity(ctx)
|
||||
val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
|
||||
ctx.seed(), numPartitions, featureArity)
|
||||
TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity)
|
||||
TreeUtils.setMetadata(data, "features", featureArity)
|
||||
}
|
||||
|
||||
override protected def trueModel(ctx: MLBenchContext): Transformer = {
|
||||
|
||||
@ -20,7 +20,7 @@ object GBTClassification extends BenchmarkAlgorithm
|
||||
val featureArity: Array[Int] = getFeatureArity(ctx)
|
||||
val data: DataFrame = DataGenerator.generateMixedFeatures(ctx.sqlContext, numExamples,
|
||||
ctx.seed(), numPartitions, featureArity)
|
||||
TreeUtils.setMetadata(data, "label", numClasses, "features", featureArity)
|
||||
TreeUtils.setMetadata(data, "features", featureArity)
|
||||
}
|
||||
|
||||
override protected def trueModel(ctx: MLBenchContext): Transformer = {
|
||||
|
||||
@ -60,7 +60,7 @@ benchmarks:
|
||||
numExamples: 100
|
||||
numTestExamples: 10
|
||||
depth: 3
|
||||
numClasses: 4
|
||||
numClasses: 2
|
||||
numFeatures: 5
|
||||
maxIter: 3
|
||||
- name: regression.LinearRegression
|
||||
|
||||
@ -9,8 +9,6 @@ object TreeUtils {
|
||||
*
|
||||
* @param data Dataset. Categorical features and labels must already have 0-based indices.
|
||||
* This must be non-empty.
|
||||
* @param labelColName Name of the label column on which to set the metadata.
|
||||
* @param numClasses Number of classes label can take. If 0, mark as continuous.
|
||||
* @param featuresColName Name of the features column
|
||||
* @param featureArity Array of length numFeatures, where 0 indicates continuous feature and
|
||||
* value > 0 indicates a categorical feature of that arity.
|
||||
@ -18,16 +16,8 @@ object TreeUtils {
|
||||
*/
|
||||
def setMetadata(
|
||||
data: DataFrame,
|
||||
labelColName: String,
|
||||
numClasses: Int,
|
||||
featuresColName: String,
|
||||
featureArity: Array[Int]): DataFrame = {
|
||||
val labelAttribute = if (numClasses == 0) {
|
||||
NumericAttribute.defaultAttr.withName(labelColName)
|
||||
} else {
|
||||
NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses)
|
||||
}
|
||||
val labelMetadata = labelAttribute.toMetadata()
|
||||
val featuresAttributes = featureArity.zipWithIndex.map { case (arity: Int, feature: Int) =>
|
||||
if (arity > 0) {
|
||||
NominalAttribute.defaultAttr.withIndex(feature).withNumValues(arity)
|
||||
@ -36,7 +26,6 @@ object TreeUtils {
|
||||
}
|
||||
}
|
||||
val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata()
|
||||
data.select(data(featuresColName).as(featuresColName, featuresMetadata),
|
||||
data(labelColName).as(labelColName, labelMetadata))
|
||||
data.select(data(featuresColName).as(featuresColName, featuresMetadata))
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user