parent
a8acd53fdd
commit
5af9f6dfc2
@ -0,0 +1,34 @@
|
||||
package com.databricks.spark.sql.perf.mllib.feature
|
||||
|
||||
import org.apache.spark.ml
|
||||
import org.apache.spark.ml.PipelineStage
|
||||
import org.apache.spark.sql.DataFrame
|
||||
import org.apache.spark.sql.functions.{col, split}
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
|
||||
/** Object for testing Word2Vec performance */
|
||||
object Word2Vec extends BenchmarkAlgorithm with TestFromTraining {
|
||||
|
||||
override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
|
||||
import ctx.params._
|
||||
|
||||
val df = DataGenerator.generateDoc(
|
||||
ctx.sqlContext,
|
||||
numExamples,
|
||||
ctx.seed(),
|
||||
numPartitions,
|
||||
vocabSize,
|
||||
docLength,
|
||||
"text"
|
||||
)
|
||||
df.select(split(col("text"), " ").as("text"))
|
||||
}
|
||||
|
||||
override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
|
||||
new ml.feature.Word2Vec().setInputCol("text")
|
||||
}
|
||||
|
||||
}
|
||||
@ -115,6 +115,11 @@ benchmarks:
|
||||
params:
|
||||
numExamples: 100
|
||||
numFeatures: 10
|
||||
- name: feature.Word2Vec
|
||||
params:
|
||||
numExamples: 100
|
||||
vocabSize: 100
|
||||
docLength: 10
|
||||
- name: recommendation.ALS
|
||||
params:
|
||||
numExamples: 100
|
||||
|
||||
Loading…
Reference in New Issue
Block a user