parent
15d9283473
commit
3786a8391e
@@ -97,6 +97,11 @@ benchmarks:
|
||||
params:
|
||||
numExamples: 100
|
||||
featureArity: 10
|
||||
- name: feature.QuantileDiscretizer
|
||||
params:
|
||||
numExamples: 100
|
||||
bucketizerNumBuckets: 2
|
||||
relativeError: 0.001
|
||||
- name: feature.StringIndexer
|
||||
params:
|
||||
numExamples: 100
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
package com.databricks.spark.sql.perf.mllib.feature
|
||||
|
||||
import org.apache.spark.ml
|
||||
import org.apache.spark.ml.linalg.Vector
|
||||
import org.apache.spark.ml.PipelineStage
|
||||
import org.apache.spark.sql._
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
|
||||
|
||||
/**
 * Benchmark definition for `org.apache.spark.ml.feature.QuantileDiscretizer`.
 *
 * The training set is a single continuous `Double` column; the stage under test
 * bins that column into `bucketizerNumBuckets` approximate-quantile buckets
 * with the configured `relativeError`.
 */
object QuantileDiscretizer extends BenchmarkAlgorithm with TestFromTraining with UnaryTransformer {

  override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
    import ctx.params._
    import ctx.sqlContext.implicits._

    // Generate rows each holding a length-1 feature vector, then unwrap that
    // vector into a bare Double column named `inputCol` for the discretizer.
    val vectorDF = DataGenerator.generateContinuousFeatures(
      ctx.sqlContext,
      numExamples,
      ctx.seed(),
      numPartitions,
      1)
    val scalars = vectorDF.rdd.map { case Row(v: Vector) => v(0) }
    scalars.toDF(inputCol)
  }

  override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
    import ctx.params._
    // `bucketizerNumBuckets` and `relativeError` come from the benchmark's
    // MLParams (implicitly unwrapped via OptionImplicits).
    val discretizer = new ml.feature.QuantileDiscretizer()
    discretizer
      .setInputCol(inputCol)
      .setNumBuckets(bucketizerNumBuckets)
      .setRelativeError(relativeError)
  }
}
|
||||
@@ -137,6 +137,7 @@ class MLParams(
|
||||
val numUsers: Option[Int] = None,
|
||||
val optimizer: Option[String] = None,
|
||||
val regParam: Option[Double] = None,
|
||||
val relativeError: Option[Double] = Some(0.001),
|
||||
val rank: Option[Int] = None,
|
||||
val smoothing: Option[Double] = None,
|
||||
val tol: Option[Double] = None,
|
||||
@@ -178,21 +179,42 @@ class MLParams(
|
||||
numInputCols: Option[Int] = numInputCols,
|
||||
numItems: Option[Int] = numItems,
|
||||
numUsers: Option[Int] = numUsers,
|
||||
vocabSize: Option[Int] = vocabSize,
|
||||
optimizer: Option[String] = optimizer,
|
||||
regParam: Option[Double] = regParam,
|
||||
relativeError: Option[Double] = relativeError,
|
||||
rank: Option[Int] = rank,
|
||||
smoothing: Option[Double] = smoothing,
|
||||
tol: Option[Double] = tol): MLParams = {
|
||||
new MLParams(randomSeed = randomSeed, numExamples = numExamples,
|
||||
numTestExamples = numTestExamples, numPartitions = numPartitions,
|
||||
bucketizerNumBuckets = bucketizerNumBuckets, depth = depth, docLength = docLength,
|
||||
elasticNetParam = elasticNetParam, family = family, featureArity = featureArity,
|
||||
itemSetSize = itemSetSize, k = k, link = link, maxIter = maxIter,
|
||||
numClasses = numClasses, numFeatures = numFeatures, numHashTables = numHashTables,
|
||||
numInputCols = numInputCols, numItems = numItems, numSynonymsToFind = numSynonymsToFind,
|
||||
numUsers = numUsers, optimizer = optimizer, regParam = regParam,
|
||||
rank = rank, smoothing = smoothing, tol = tol, vocabSize = vocabSize)
|
||||
tol: Option[Double] = tol,
|
||||
vocabSize: Option[Int] = vocabSize): MLParams = {
|
||||
new MLParams(
|
||||
randomSeed = randomSeed,
|
||||
numExamples = numExamples,
|
||||
numTestExamples = numTestExamples,
|
||||
numPartitions = numPartitions,
|
||||
bucketizerNumBuckets = bucketizerNumBuckets,
|
||||
depth = depth,
|
||||
docLength = docLength,
|
||||
elasticNetParam = elasticNetParam,
|
||||
family = family,
|
||||
featureArity = featureArity,
|
||||
itemSetSize = itemSetSize,
|
||||
k = k,
|
||||
link = link,
|
||||
maxIter = maxIter,
|
||||
numClasses = numClasses,
|
||||
numFeatures = numFeatures,
|
||||
numHashTables = numHashTables,
|
||||
numSynonymsToFind = numSynonymsToFind,
|
||||
numInputCols = numInputCols,
|
||||
numItems = numItems,
|
||||
numUsers = numUsers,
|
||||
optimizer = optimizer,
|
||||
regParam = regParam,
|
||||
relativeError = relativeError,
|
||||
rank = rank,
|
||||
smoothing = smoothing,
|
||||
tol = tol,
|
||||
vocabSize = vocabSize)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user