[ML-3844] Add GBTRegression benchmark (#156)
* add GBTRegression benchmark * add GBTRegression benchmark
This commit is contained in:
parent
e8aa132bb8
commit
e9ef9788c2
@ -111,6 +111,11 @@ benchmarks:
|
||||
- name: regression.DecisionTreeRegression
|
||||
params:
|
||||
depth: [5, 10]
|
||||
- name: regression.GBTRegression
|
||||
params:
|
||||
numFeatures: 2000
|
||||
depth: 5
|
||||
maxIter: 5
|
||||
- name: regression.GLMRegression
|
||||
params:
|
||||
numExamples: 500000
|
||||
|
||||
@ -142,6 +142,12 @@ benchmarks:
|
||||
depth: 3
|
||||
numClasses: 4
|
||||
numFeatures: 5
|
||||
- name: regression.GBTRegression
|
||||
params:
|
||||
numExamples: 100
|
||||
numTestExamples: 10
|
||||
depth: 3
|
||||
maxIter: 3
|
||||
- name: regression.GLMRegression
|
||||
params:
|
||||
numExamples: 100
|
||||
|
||||
@ -0,0 +1,18 @@
|
||||
package com.databricks.spark.sql.perf.mllib.regression
|
||||
|
||||
import org.apache.spark.ml.PipelineStage
|
||||
import org.apache.spark.ml.regression.GBTRegressor
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext,
|
||||
TreeOrForestRegressor}
|
||||
|
||||
object GBTRegression extends BenchmarkAlgorithm with TreeOrForestRegressor {
|
||||
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
|
||||
import ctx.params._
|
||||
new GBTRegressor()
|
||||
.setMaxDepth(depth)
|
||||
.setMaxIter(maxIter)
|
||||
.setSeed(ctx.seed())
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user