Merge pull request #75 from jkbradley/kmeans

Added kmeans test
This commit is contained in:
Timothy Hunter 2016-07-05 10:14:11 -07:00 committed by GitHub
commit 979ebd5d0f
3 changed files with 96 additions and 0 deletions

View File

@ -0,0 +1,29 @@
package com.databricks.spark.sql.perf.mllib.clustering
import org.apache.spark.ml
import org.apache.spark.ml.Estimator
import org.apache.spark.sql._
import com.databricks.spark.sql.perf.mllib.OptionImplicits._
import com.databricks.spark.sql.perf.mllib.data.DataGenerator
import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
object KMeans extends BenchmarkAlgorithm with TestFromTraining {
override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
import ctx.params._
DataGenerator.generateGaussianMixtureData(ctx.sqlContext, k, numExamples, ctx.seed(),
numPartitions, numFeatures)
}
override def getEstimator(ctx: MLBenchContext): Estimator[_] = {
import ctx.params._
new ml.clustering.KMeans()
.setK(k)
.setSeed(randomSeed.toLong)
.setMaxIter(maxIter)
}
// TODO(?) add a scoring method here.
}

View File

@ -36,6 +36,23 @@ object DataGenerator {
new FeaturesGenerator(featureArity), numExamples, numPartitions, seed)
sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
}
/**
* Generate data from a Gaussian mixture model.
* @param numCenters Number of clusters in mixture
*/
def generateGaussianMixtureData(
sql: SQLContext,
numCenters: Int,
numExamples: Long,
seed: Long,
numPartitions: Int,
numFeatures: Int): DataFrame = {
val rdd: RDD[Vector] = RandomRDDs.randomRDD(sql.sparkContext,
new GaussianMixtureDataGenerator(numCenters, numFeatures, seed), numExamples, numPartitions,
seed)
sql.createDataFrame(rdd.map(Tuple1.apply)).toDF("features")
}
}
@ -78,3 +95,46 @@ class FeaturesGenerator(val featureArity: Array[Int])
override def copy(): FeaturesGenerator = new FeaturesGenerator(featureArity)
}
/**
* Generate data from a Gaussian mixture model.
*/
class GaussianMixtureDataGenerator(
val numCenters: Int,
val numFeatures: Int,
val seed: Long) extends RandomDataGenerator[Vector] {
private val rng = new java.util.Random(seed)
private val rng2 = new java.util.Random(seed + 24)
private val scale_factors = Array.fill(numCenters)(rng.nextInt(20) - 10)
// Have a random number of points around a cluster
private val concentrations: Seq[Double] = {
val rand = Array.fill(numCenters)(rng.nextDouble())
val randSum = rand.sum
val scaled = rand.map(x => x / randSum)
(1 to numCenters).map{i =>
scaled.slice(0, i).sum
}
}
private val centers = (0 until numCenters).map{i =>
Array.fill(numFeatures)((2 * rng.nextDouble() - 1) * scale_factors(i))
}
override def nextValue(): Vector = {
val pick_center_rand = rng2.nextDouble()
val center = centers(concentrations.indexWhere(p => pick_center_rand <= p))
Vectors.dense(Array.tabulate(numFeatures)(i => center(i) + rng2.nextGaussian()))
}
override def setSeed(seed: Long) {
rng.setSeed(seed)
rng2.setSeed(seed + 24)
}
override def copy(): GaussianMixtureDataGenerator =
new GaussianMixtureDataGenerator(numCenters, numFeatures, seed)
}

View File

@ -12,6 +12,7 @@ benchmarks:
numFeatures: 100
regParam: 0.1
tol: [0.2, 0.1]
maxIter: 10
- name: clustering.LDA
params:
numExamples: 10
@ -23,6 +24,12 @@ benchmarks:
optimizer:
- em
- online
- name: clustering.KMeans
params:
numExamples: 10
numTestExamples: 10
k: 5
maxIter: 10
- name: regression.GLMRegression
params:
numExamples: 100