added kmeans test
This commit is contained in:
parent
3d3443791c
commit
9d11a601c3
@ -0,0 +1,29 @@
|
||||
package com.databricks.spark.sql.perf.mllib.clustering
|
||||
|
||||
import org.apache.spark.ml
|
||||
import org.apache.spark.ml.Estimator
|
||||
import org.apache.spark.sql._
|
||||
|
||||
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
|
||||
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
|
||||
import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
|
||||
|
||||
|
||||
object KMeans extends BenchmarkAlgorithm with TestFromTraining {
|
||||
|
||||
override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
|
||||
import ctx.params._
|
||||
DataGenerator.generateGaussianMixtureData(ctx.sqlContext, k, numExamples, ctx.seed(),
|
||||
numPartitions, numFeatures)
|
||||
}
|
||||
|
||||
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
|
||||
import ctx.params._
|
||||
new ml.clustering.KMeans()
|
||||
.setK(k)
|
||||
.setSeed(randomSeed.toLong)
|
||||
.setMaxIter(maxIter)
|
||||
}
|
||||
|
||||
// TODO(?) add a scoring method here.
|
||||
}
|
||||
@ -36,6 +36,23 @@ object DataGenerator {
|
||||
new FeaturesGenerator(featureArity), numExamples, numPartitions, seed)
|
||||
sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate data from a Gaussian mixture model.
|
||||
* @param numCenters Number of clusters in mixture
|
||||
*/
|
||||
def generateGaussianMixtureData(
|
||||
sql: SQLContext,
|
||||
numCenters: Int,
|
||||
numExamples: Long,
|
||||
seed: Long,
|
||||
numPartitions: Int,
|
||||
numFeatures: Int): DataFrame = {
|
||||
val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext,
|
||||
new GaussianMixtureDataGenerator(numCenters, numFeatures, seed), numExamples, numPartitions,
|
||||
seed)
|
||||
sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -78,3 +95,46 @@ class FeaturesGenerator(val featureArity: Array[Int])
|
||||
|
||||
override def copy(): FeaturesGenerator = new FeaturesGenerator(featureArity)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generate data from a Gaussian mixture model.
|
||||
*/
|
||||
class GaussianMixtureDataGenerator(
|
||||
val numCenters: Int,
|
||||
val numFeatures: Int,
|
||||
val seed: Long) extends RandomDataGenerator[Vector] {
|
||||
|
||||
private val rng = new java.util.Random(seed)
|
||||
private val rng2 = new java.util.Random(seed + 24)
|
||||
private val scale_factors = Array.fill(numCenters)(rng.nextInt(20) - 10)
|
||||
|
||||
// Have a random number of points around a cluster
|
||||
private val concentrations: Seq[Double] = {
|
||||
val rand = Array.fill(numCenters)(rng.nextDouble())
|
||||
val randSum = rand.sum
|
||||
val scaled = rand.map(x => x / randSum)
|
||||
|
||||
(1 to numCenters).map{i =>
|
||||
scaled.slice(0, i).sum
|
||||
}
|
||||
}
|
||||
|
||||
private val centers = (0 until numCenters).map{i =>
|
||||
Array.fill(numFeatures)((2 * rng.nextDouble() - 1) * scale_factors(i))
|
||||
}
|
||||
|
||||
override def nextValue(): Vector = {
|
||||
val pick_center_rand = rng2.nextDouble()
|
||||
val center = centers(concentrations.indexWhere(p => pick_center_rand <= p))
|
||||
Vectors.dense(Array.tabulate(numFeatures)(i => center(i) + rng2.nextGaussian()))
|
||||
}
|
||||
|
||||
override def setSeed(seed: Long) {
|
||||
rng.setSeed(seed)
|
||||
rng2.setSeed(seed + 24)
|
||||
}
|
||||
|
||||
override def copy(): GaussianMixtureDataGenerator =
|
||||
new GaussianMixtureDataGenerator(numCenters, numFeatures, seed)
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@ benchmarks:
|
||||
numFeatures: 100
|
||||
regParam: 0.1
|
||||
tol: [0.2, 0.1]
|
||||
maxIter: 10
|
||||
- name: clustering.LDA
|
||||
params:
|
||||
numExamples: 10
|
||||
@ -23,6 +24,12 @@ benchmarks:
|
||||
optimizer:
|
||||
- em
|
||||
- online
|
||||
- name: clustering.KMeans
|
||||
params:
|
||||
numExamples: 10
|
||||
numTestExamples: 10
|
||||
k: 5
|
||||
maxIter: 10
|
||||
- name: regression.GLMRegression
|
||||
params:
|
||||
numExamples: 100
|
||||
|
||||
Loading…
Reference in New Issue
Block a user