diff --git a/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-large.yaml b/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-large.yaml
index 1f4c3ee..3c40539 100644
--- a/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-large.yaml
+++ b/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-large.yaml
@@ -111,6 +111,11 @@ benchmarks:
   - name: regression.DecisionTreeRegression
     params:
       depth: [5, 10]
+  - name: regression.GBTRegression
+    params:
+      numFeatures: 2000
+      depth: 5
+      maxIter: 5
   - name: regression.GLMRegression
     params:
       numExamples: 500000
diff --git a/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-small.yaml b/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-small.yaml
index 8de1235..153d471 100644
--- a/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-small.yaml
+++ b/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-small.yaml
@@ -142,6 +142,12 @@ benchmarks:
       depth: 3
       numClasses: 4
       numFeatures: 5
+  - name: regression.GBTRegression
+    params:
+      numExamples: 100
+      numTestExamples: 10
+      depth: 3
+      maxIter: 3
   - name: regression.GLMRegression
     params:
       numExamples: 100
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala
new file mode 100644
index 0000000..e78d2eb
--- /dev/null
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala
@@ -0,0 +1,31 @@
+package com.databricks.spark.sql.perf.mllib.regression
+
+import org.apache.spark.ml.PipelineStage
+import org.apache.spark.ml.regression.GBTRegressor
+
+import com.databricks.spark.sql.perf.mllib.OptionImplicits._
+import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext,
+  TreeOrForestRegressor}
+
+/**
+ * Benchmark entry for gradient-boosted tree regression, backed by Spark ML's
+ * [[org.apache.spark.ml.regression.GBTRegressor]].
+ */
+object GBTRegression extends BenchmarkAlgorithm with TreeOrForestRegressor {
+  /**
+   * Builds the [[GBTRegressor]] pipeline stage to benchmark.
+   *
+   * @param ctx benchmark context; provides the run parameters (`ctx.params`) and
+   *            the random seed (`ctx.seed()`)
+   * @return a `GBTRegressor` with max depth, max iterations and seed set
+   */
+  override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
+    // `depth` and `maxIter` are presumably members of `ctx.params` brought into
+    // scope here, adapted to plain values via OptionImplicits -- confirm.
+    import ctx.params._
+    new GBTRegressor()
+      .setMaxDepth(depth)
+      .setMaxIter(maxIter)
+      .setSeed(ctx.seed())
+  }
+}