From be4459fe417229457eb64b6e35c325519f961b81 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 3 May 2018 04:45:58 +0800 Subject: [PATCH] Additional method test for some ML algos (#139) Add additional method test for some ML algos. In this PR, I add `associationRules` in `FPGrowth` and `findSynonyms` in `Word2Vec`. After the design is accepted, I will add other methods later. Add an interface in `BenchmarkableAlgorithm`: ``` def testAdditionalMethods(ctx: MLBenchContext, model: Transformer): Map[String, () => _] ``` --- .../sql/perf/mllib/BenchmarkAlgorithm.scala | 12 ++++++++++- .../mllib/MLPipelineStageBenchmarkable.scala | 8 ++++++- .../sql/perf/mllib/feature/Word2Vec.scala | 21 ++++++++++++++++++- .../spark/sql/perf/mllib/fpm/FPGrowth.scala | 13 +++++++++++- .../databricks/spark/sql/perf/results.scala | 10 ++++++--- src/main/scala/configs/mllib-small.yaml | 1 + 6 files changed, 58 insertions(+), 7 deletions(-) diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala index 8b15f1a..dd48b9d 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/BenchmarkAlgorithm.scala @@ -1,7 +1,6 @@ package com.databricks.spark.sql.perf.mllib import com.typesafe.scalalogging.slf4j.{LazyLogging => Logging} - import org.apache.spark.ml.attribute.{NominalAttribute, NumericAttribute} import org.apache.spark.ml.{Estimator, PipelineStage, Transformer} import org.apache.spark.ml.evaluation.Evaluator @@ -44,6 +43,17 @@ trait BenchmarkAlgorithm extends Logging { def name: String = { this.getClass.getCanonicalName.replace("$", "") } + + /** + * Test additional methods for some algorithms. + * + * @param transformer The transformer that provides the additional methods. + * @return A map whose keys are the additional method names, and whose values are functions that run + * the corresponding methods. 
+ */ + def testAdditionalMethods( + ctx: MLBenchContext, + transformer: Transformer): Map[String, () => _] = Map.empty } /** diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala index a1712f0..d598186 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/MLPipelineStageBenchmarkable.scala @@ -80,12 +80,18 @@ class MLPipelineStageBenchmarkable( s" s, Scored training dataset in ${scoreTrainTime.toMillis / 1000.0} s," + s" test dataset in ${scoreTestTime.toMillis / 1000.0} s") + val additionalTests = test.testAdditionalMethods(param, model).map { + tuple => + val (additionalMethodTime, _) = measureTime { tuple._2() } + tuple._1 -> additionalMethodTime.toMillis.toDouble + } val ml = MLResult( trainingTime = Some(trainingTime.toMillis), trainingMetric = Some(scoreTraining), testTime = Some(scoreTestTime.toMillis), - testMetric = Some(scoreTest / testDataCount.get)) + testMetric = Some(scoreTest / testDataCount.get), + additionalTests = additionalTests) BenchmarkResult( name = name, diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala index ca30dcf..3e1b995 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/Word2Vec.scala @@ -1,10 +1,15 @@ package com.databricks.spark.sql.perf.mllib.feature +import scala.util.Random + import org.apache.spark.ml -import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.{PipelineStage, Transformer} +import org.apache.spark.ml.feature.Word2VecModel +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.{col, split} +import 
com.databricks.spark.sql.perf.MLResult import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining} import com.databricks.spark.sql.perf.mllib.OptionImplicits._ import com.databricks.spark.sql.perf.mllib.data.DataGenerator @@ -31,4 +36,18 @@ object Word2Vec extends BenchmarkAlgorithm with TestFromTraining { new ml.feature.Word2Vec().setInputCol("text") } + override def testAdditionalMethods( + ctx: MLBenchContext, + model: Transformer): Map[String, () => _] = { + import ctx.params._ + + val rng = new Random(ctx.seed()) + val word2vecModel = model.asInstanceOf[Word2VecModel] + val testWord = Vectors.dense(Array.fill(word2vecModel.getVectorSize)(rng.nextGaussian())) + + Map("findSynonyms" -> (() => { + word2vecModel.findSynonyms(testWord, numSynonymsToFind) + })) + } + } diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/fpm/FPGrowth.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/fpm/FPGrowth.scala index 35cf8a7..691bf5b 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/fpm/FPGrowth.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/fpm/FPGrowth.scala @@ -1,7 +1,8 @@ package com.databricks.spark.sql.perf.mllib.fpm import org.apache.spark.ml -import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.{PipelineStage, Transformer} +import org.apache.spark.ml.fpm.FPGrowthModel import org.apache.spark.sql.DataFrame import com.databricks.spark.sql.perf.mllib._ @@ -28,4 +29,14 @@ object FPGrowth extends BenchmarkAlgorithm with TestFromTraining { new ml.fpm.FPGrowth() .setItemsCol("items") } + + override def testAdditionalMethods( + ctx: MLBenchContext, + model: Transformer): Map[String, () => _] = { + + val fpModel = model.asInstanceOf[FPGrowthModel] + Map("associationRules" -> (() => { + fpModel.associationRules.count() + })) + } } diff --git a/src/main/scala/com/databricks/spark/sql/perf/results.scala b/src/main/scala/com/databricks/spark/sql/perf/results.scala index 
ab0fb24..487d7a6 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/results.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/results.scala @@ -131,6 +131,7 @@ class MLParams( val numClasses: Option[Int] = None, val numFeatures: Option[Int] = None, val numHashTables: Option[Int] = Some(1), + val numSynonymsToFind: Option[Int] = None, val numInputCols: Option[Int] = None, val numItems: Option[Int] = None, val numUsers: Option[Int] = None, @@ -173,6 +174,7 @@ class MLParams( numClasses: Option[Int] = numClasses, numFeatures: Option[Int] = numFeatures, numHashTables: Option[Int] = numHashTables, + numSynonymsToFind: Option[Int] = numSynonymsToFind, numInputCols: Option[Int] = numInputCols, numItems: Option[Int] = numItems, numUsers: Option[Int] = numUsers, @@ -188,8 +190,8 @@ class MLParams( elasticNetParam = elasticNetParam, family = family, featureArity = featureArity, itemSetSize = itemSetSize, k = k, link = link, maxIter = maxIter, numClasses = numClasses, numFeatures = numFeatures, numHashTables = numHashTables, - numInputCols = numInputCols, numItems = numItems, numUsers = numUsers, - optimizer = optimizer, regParam = regParam, + numInputCols = numInputCols, numItems = numItems, numSynonymsToFind = numSynonymsToFind, + numUsers = numUsers, optimizer = optimizer, regParam = regParam, rank = rank, smoothing = smoothing, tol = tol, vocabSize = vocabSize) } } @@ -207,9 +209,11 @@ object MLParams { * @param testTime (MLlib) Test time (for prediction on test set, or on training set if there * is no test set). * @param testMetric (MLlib) Test metric, such as accuracy + * @param additionalTests (MLlib) Additional methods test results. 
*/ case class MLResult( trainingTime: Option[Double] = None, trainingMetric: Option[Double] = None, testTime: Option[Double] = None, - testMetric: Option[Double] = None) + testMetric: Option[Double] = None, + additionalTests: Map[String, Double] = Map.empty) diff --git a/src/main/scala/configs/mllib-small.yaml b/src/main/scala/configs/mllib-small.yaml index b0392d5..efbef73 100644 --- a/src/main/scala/configs/mllib-small.yaml +++ b/src/main/scala/configs/mllib-small.yaml @@ -120,6 +120,7 @@ benchmarks: numExamples: 100 vocabSize: 100 docLength: 10 + numSynonymsToFind: 3 - name: recommendation.ALS params: numExamples: 100