From e9ef9788c2094aeb40c0f7d883b8c1cb0f852b74 Mon Sep 17 00:00:00 2001 From: ludatabricks <38018689+ludatabricks@users.noreply.github.com> Date: Wed, 27 Jun 2018 09:17:38 -0700 Subject: [PATCH] [ML-3844] Add GBTRegression benchmark (#156) * add GBTRegression benchmark * add GBTRegression benchmark --- .../sql/perf/mllib/config/mllib-large.yaml | 5 +++++ .../sql/perf/mllib/config/mllib-small.yaml | 6 ++++++ .../perf/mllib/regression/GBTRegression.scala | 18 ++++++++++++++++++ 3 files changed, 29 insertions(+) create mode 100644 src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala diff --git a/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-large.yaml b/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-large.yaml index 1f4c3ee..3c40539 100644 --- a/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-large.yaml +++ b/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-large.yaml @@ -111,6 +111,11 @@ benchmarks: - name: regression.DecisionTreeRegression params: depth: [5, 10] + - name: regression.GBTRegression + params: + numFeatures: 2000 + depth: 5 + maxIter: 5 - name: regression.GLMRegression params: numExamples: 500000 diff --git a/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-small.yaml b/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-small.yaml index 8de1235..153d471 100644 --- a/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-small.yaml +++ b/src/main/resources/com/databricks/spark/sql/perf/mllib/config/mllib-small.yaml @@ -142,6 +142,12 @@ benchmarks: depth: 3 numClasses: 4 numFeatures: 5 + - name: regression.GBTRegression + params: + numExamples: 100 + numTestExamples: 10 + depth: 3 + maxIter: 3 - name: regression.GLMRegression params: numExamples: 100 diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala new file mode 100644 index 0000000..e78d2eb --- /dev/null +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/regression/GBTRegression.scala @@ -0,0 +1,18 @@ +package com.databricks.spark.sql.perf.mllib.regression + +import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.regression.GBTRegressor + +import com.databricks.spark.sql.perf.mllib.OptionImplicits._ +import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, + TreeOrForestRegressor} + +object GBTRegression extends BenchmarkAlgorithm with TreeOrForestRegressor { + override def getPipelineStage(ctx: MLBenchContext): PipelineStage = { + import ctx.params._ + new GBTRegressor() + .setMaxDepth(depth) + .setMaxIter(maxIter) + .setSeed(ctx.seed()) + } +} \ No newline at end of file