Compare commits

...

10 Commits

Author SHA1 Message Date
Your Name
304fdaf81a 增加单表参数 2026-01-07 12:47:26 +08:00
Frank Luan
28d88190f6
Update Spark repository for sbt (#206)
https://dl.bintray.com/spark-packages/maven is a dead link now. According to https://github.com/databricks/sbt-spark-package/issues/50, this is the new link: https://repos.spark-packages.org
2021-12-13 19:20:13 +09:00
Yuming Wang
ca4ccea3dd
Add a convenient class to generate TPC-DS data (#196)
How to use it:
```
build/sbt "test:runMain com.databricks.spark.sql.perf.tpcds.GenTPCDSData -d /root/tmp/tpcds-kit/tools -s 5 -l /root/tmp/tpcds5g -f parquet"
```

```
[root@spark-3267648 spark-sql-perf]# build/sbt "test:runMain com.databricks.spark.sql.perf.tpcds.GenTPCDSData --help"
[info] Running com.databricks.spark.sql.perf.tpcds.GenTPCDSData --help
[info] Usage: Gen-TPC-DS-data [options]
[info]
[info]   -m, --master <value>     the Spark master to use, default to local[*]
[info]   -d, --dsdgenDir <value>  location of dsdgen
[info]   -s, --scaleFactor <value>
[info]                            scaleFactor defines the size of the dataset to generate (in GB)
[info]   -l, --location <value>   root directory of location to create data in
[info]   -f, --format <value>     valid spark format, Parquet, ORC ...
[info]   -i, --useDoubleForDecimal <value>
[info]                            true to replace DecimalType with DoubleType
[info]   -e, --useStringForDate <value>
[info]                            true to replace DateType with StringType
[info]   -o, --overwrite <value>  overwrite the data that is already there
[info]   -p, --partitionTables <value>
[info]                            create the partitioned fact tables
[info]   -c, --clusterByPartitionColumns <value>
[info]                            shuffle to get partitions coalesced into single files
[info]   -v, --filterOutNullPartitionValues <value>
[info]                            true to filter out the partition with NULL key value
[info]   -t, --tableFilter <value>
[info]                            "" means generate all tables
[info]   -n, --numPartitions <value>
[info]                            how many dsdgen partitions to run - number of input tasks.
[info]   --help                   prints this usage text
```
2021-03-30 21:19:36 +09:00
Yuming Wang
65785a8a04
Fix Travis CI JDK installation (#195)
* Replace oraclejdk8 with openjdk8
* Update .travis.yml
2021-01-28 17:28:46 +01:00
Nico Poggi
d85f75bb38
Update for Spark 3.0.0 compatibility (#191)
* Updating the build file to spark 3.0.0 and scala 2.12.10
* Fixing incompatibilities
* Adding default parameters to newer required functions
* Removing HiveTest
2020-11-03 15:27:34 +01:00
Guo Chenzhao
6b2bf9f9ad Fix files truncating according to maxRecordPerFile (#180)
* Fix files truncating according to maxRecordPerFile

* toDouble
2019-05-29 23:20:01 +08:00
Nico Poggi
3f92a094cc
Bumping version to 0.5.1-SNAPSHOT (spark 3, scala 2.12, log4j ) (#168) 2019-01-29 10:00:54 +01:00
Luca Canali
e1e1365a87 Updates for Spark 3.0 and Scala 2.12 compatibility (#176)
* Refactor deprecated `getOrCreate()` in spark 3
* Compile with scala 2.12
* Updated usage related to obsolete/deprecated features
* remove use of scala-logging replaced by using slf4j directly
2019-01-29 09:58:52 +01:00
Bago Amirbekian
85bbfd4ca2 [ML-5437] Build with spark-2.4.0 and resolve build issues (#174)
We made some changes to related to new APIs in spark2.4. These APIs were reverted because they were breaking changes so we need to revert our changes.
2018-11-09 16:21:22 -08:00
Nico Poggi
d44caec277
Revert "Update Scala Logging to officially supported one " (#172)
Reverts #157 due to library errors when the previous was is in the classpath already (i.e., in databricks) and not bringing any noted improvements or needed fixes. Exception:
java.lang.InstantiationError: com.typesafe.scalalogging.Logger
This reverts commit 56f7348.
2018-10-19 17:33:34 +02:00
23 changed files with 208 additions and 130 deletions

View File

@ -1,11 +1,12 @@
language: scala
scala:
- 2.11.8
- 2.12.10
sudo: false
dist: trusty
jdk:
oraclejdk8
cache:
directories:
- $HOME/.ivy2
env:
- DBC_USERNAME="" DBC_PASSWORD="" DBC_URL=""
- DBC_USERNAME="" DBC_PASSWORD="" DBC_URL=""

View File

@ -34,7 +34,8 @@ The first run of `bin/run` will build the library.
## Build
Use `sbt package` or `sbt assembly` to build the library jar.
Use `sbt package` or `sbt assembly` to build the library jar.
Use `sbt +package` to build for scala 2.11 and 2.12.
## Local performance tests
The framework contains twelve benchmarks that can be executed in local mode. They are organized into three classes and target different components and functions of Spark:
@ -66,31 +67,11 @@ TPCDS kit needs to be installed on all cluster executor nodes under the same pat
It can be found [here](https://github.com/databricks/tpcds-kit).
```
import com.databricks.spark.sql.perf.tpcds.TPCDSTables
// Set:
val rootDir = ... // root directory of location to create data in.
val databaseName = ... // name of database to create.
val scaleFactor = ... // scaleFactor defines the size of the dataset to generate (in GB).
val format = ... // valid spark format like parquet "parquet".
// Run:
val tables = new TPCDSTables(sqlContext,
dsdgenDir = "/tmp/tpcds-kit/tools", // location of dsdgen
scaleFactor = scaleFactor,
useDoubleForDecimal = false, // true to replace DecimalType with DoubleType
useStringForDate = false) // true to replace DateType with StringType
tables.genData(
location = rootDir,
format = format,
overwrite = true, // overwrite the data that is already there
partitionTables = true, // create the partitioned fact tables
clusterByPartitionColumns = true, // shuffle to get partitions coalesced into single files.
filterOutNullPartitionValues = false, // true to filter out the partition with NULL key value
tableFilter = "", // "" means generate all tables
numPartitions = 100) // how many dsdgen partitions to run - number of input tasks.
// Generate the data
build/sbt "test:runMain com.databricks.spark.sql.perf.tpcds.GenTPCDSData -d <dsdgenDir> -s <scaleFactor> -l <location> -f <format>"
```
```
// Create the specified database
sql(s"create database $databaseName")
// Create metastore tables in a specified database for your data.

View File

@ -5,16 +5,16 @@ name := "spark-sql-perf"
organization := "com.databricks"
scalaVersion := "2.11.8"
scalaVersion := "2.12.10"
crossScalaVersions := Seq("2.11.8")
crossScalaVersions := Seq("2.12.10")
sparkPackageName := "databricks/spark-sql-perf"
// All Spark Packages need a license
licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0"))
sparkVersion := "2.4.0-SNAPSHOT"
sparkVersion := "3.0.0"
sparkComponents ++= Seq("sql", "hive", "mllib")
@ -32,19 +32,13 @@ initialCommands in console :=
|import sqlContext.implicits._
""".stripMargin
libraryDependencies += "org.slf4j" % "slf4j-api" % "1.7.5"
libraryDependencies += "com.github.scopt" %% "scopt" % "3.7.1"
libraryDependencies += "com.github.scopt" %% "scopt" % "3.3.0"
libraryDependencies += "com.twitter" %% "util-jvm" % "6.45.0" % "provided"
libraryDependencies += "com.twitter" %% "util-jvm" % "6.23.0" % "provided"
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.5" % "test"
libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.1" % "test"
libraryDependencies += "org.yaml" % "snakeyaml" % "1.17"
libraryDependencies += "com.typesafe.scala-logging" %% "scala-logging" % "3.9.0"
resolvers += "Apache Development Snapshot Repository" at "https://repository.apache.org/content/repositories/snapshots"
libraryDependencies += "org.yaml" % "snakeyaml" % "1.23"
fork := true

View File

@ -1,2 +1,2 @@
// This file should only contain the version of sbt to use.
sbt.version=0.13.8
sbt.version=0.13.18

View File

@ -1,6 +1,6 @@
// You may use this file to add plugin dependencies for sbt.
resolvers += "Spark Packages repo" at "https://dl.bintray.com/spark-packages/maven/"
resolvers += "Spark Packages repo" at "https://repos.spark-packages.org/"
resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"
@ -14,4 +14,4 @@ addSbtPlugin("com.databricks" %% "sbt-databricks" % "0.1.3")
addSbtPlugin("me.lessis" % "bintray-sbt" % "0.3.0")
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")

View File

@ -25,7 +25,7 @@ import scala.util.{Success, Try, Failure => SFailure}
import scala.util.control.NonFatal
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, DataFrame, SQLContext}
import org.apache.spark.sql.{Dataset, DataFrame, SQLContext, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.SparkContext
@ -42,7 +42,7 @@ abstract class Benchmark(
import Benchmark._
def this() = this(SQLContext.getOrCreate(SparkContext.getOrCreate()))
def this() = this(SparkSession.builder.getOrCreate().sqlContext)
val resultsLocation =
sqlContext.getAllConfs.getOrElse(
@ -335,7 +335,7 @@ object Benchmark {
.flatMap { query =>
try {
query.newDataFrame().queryExecution.logical.collect {
case UnresolvedRelation(t) => t.table
case r: UnresolvedRelation => r.tableName
}
} catch {
// ignore the queries that can't be parsed
@ -476,14 +476,14 @@ object Benchmark {
/** Returns results from an actively running experiment. */
def getCurrentResults() = {
val tbl = sqlContext.createDataFrame(currentResults)
tbl.registerTempTable("currentResults")
tbl.createOrReplaceTempView("currentResults")
tbl
}
/** Returns full iterations from an actively running experiment. */
def getCurrentRuns() = {
val tbl = sqlContext.createDataFrame(currentRuns)
tbl.registerTempTable("currentRuns")
tbl.createOrReplaceTempView("currentRuns")
tbl
}

View File

@ -18,23 +18,25 @@ package com.databricks.spark.sql.perf
import java.util.UUID
import com.typesafe.scalalogging.{LazyLogging => Logging}
import org.slf4j.LoggerFactory
import scala.concurrent.duration._
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.{SQLContext,SparkSession}
import org.apache.spark.{SparkEnv, SparkContext}
/** A trait to describe things that can be benchmarked. */
trait Benchmarkable extends Logging {
@transient protected[this] val sqlContext = SQLContext.getOrCreate(SparkContext.getOrCreate())
@transient protected[this] val sparkContext = sqlContext.sparkContext
trait Benchmarkable {
@transient protected[this] val sqlSession = SparkSession.builder.getOrCreate()
@transient protected[this] val sqlContext = sqlSession.sqlContext
@transient protected[this] val sparkContext = sqlSession.sparkContext
val name: String
protected val executionMode: ExecutionMode
lazy val logger = LoggerFactory.getLogger(this.getClass.getName)
final def benchmark(
includeBreakdown: Boolean,
@ -75,7 +77,7 @@ trait Benchmarkable extends Logging {
try {
result = doBenchmark(includeBreakdown, description, messages)
} catch {
case NonFatal(e) =>
case e: Throwable =>
logger.info(s"$that: failure in runBenchmark: $e")
println(s"$that: failure in runBenchmark: $e")
result = BenchmarkResult(

View File

@ -104,7 +104,7 @@ package object cpu {
}
val counts = cpuLogs.groupBy($"stack").agg(count($"*")).collect().flatMap {
case Row(stackLines: Seq[String], count: Long) => stackLines.map(toStackElement) -> count :: Nil
case Row(stackLines: Array[String], count: Long) => stackLines.toSeq.map(toStackElement) -> count :: Nil
case other => println(s"Failed to parse $other"); Nil
}.toMap
val profile = new com.twitter.jvm.CpuProfile(counts, com.twitter.util.Duration.fromSeconds(10), cpuLogs.count().toInt, 0)

View File

@ -54,10 +54,7 @@ class Query(
}
lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect {
case UnresolvedRelation(tableIdentifier) => {
// We are ignoring the database name.
tableIdentifier.table
}
case r: UnresolvedRelation => r.tableName
}
def newDataFrame() = buildDataFrame
@ -88,10 +85,11 @@ class Query(
val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan.p(i)))
val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap
val timeMap = new mutable.HashMap[Int, Double]
val maxFields = 999 // Maximum number of fields that will be converted to strings
physicalOperators.reverse.map {
case (index, node) =>
messages += s"Breakdown: ${node.simpleString}"
messages += s"Breakdown: ${node.simpleString(maxFields)}"
val newNode = buildDataFrame.queryExecution.executedPlan.p(index)
val executionTime = measureTimeMs {
newNode.execute().foreach((row: Any) => Unit)
@ -104,7 +102,7 @@ class Query(
BreakdownResult(
node.nodeName,
node.simpleString.replaceAll("#\\d+", ""),
node.simpleString(maxFields).replaceAll("#\\d+", ""),
index,
childIndexes,
executionTime,
@ -122,7 +120,7 @@ class Query(
val executionTime = measureTimeMs {
executionMode match {
case ExecutionMode.CollectResults => dataFrame.collect()
case ExecutionMode.ForeachResults => dataFrame.foreach { row => Unit }
case ExecutionMode.ForeachResults => dataFrame.foreach { _ => ():Unit }
case ExecutionMode.WriteParquet(location) =>
dataFrame.write.parquet(s"$location/$name.parquet")
case ExecutionMode.HashResults =>

View File

@ -18,7 +18,7 @@ package com.databricks.spark.sql.perf
import java.net.InetAddress
import java.io.File
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.{SparkContext, SparkConf}
import scala.util.Try
@ -70,8 +70,9 @@ object RunBenchmark {
.setMaster(config.master)
.setAppName(getClass.getName)
val sc = SparkContext.getOrCreate(conf)
val sqlContext = SQLContext.getOrCreate(sc)
val sparkSession = SparkSession.builder.config(conf).getOrCreate()
val sc = sparkSession.sparkContext
val sqlContext = sparkSession.sqlContext
import sqlContext.implicits._
sqlContext.setConf("spark.sql.perf.results",

View File

@ -222,7 +222,7 @@ abstract class Tables(sqlContext: SQLContext, scaleFactor: String,
log.info(s"Data has $numRows rows clustered $clusterByPartitionColumns for $maxRecordPerFile")
if (maxRecordPerFile > 0 && numRows > maxRecordPerFile) {
val numFiles = ((numRows)/maxRecordPerFile).ceil.toInt
val numFiles = (numRows.toDouble/maxRecordPerFile).ceil.toInt
println(s"Coalescing into $numFiles files")
log.info(s"Coalescing into $numFiles files")
data.coalesce(numFiles).write

View File

@ -1,6 +1,5 @@
package com.databricks.spark.sql.perf.mllib
import com.typesafe.scalalogging.{LazyLogging => Logging}
import org.apache.spark.ml.attribute.{NominalAttribute, NumericAttribute}
import org.apache.spark.ml.{Estimator, PipelineStage, Transformer}
import org.apache.spark.ml.evaluation.Evaluator
@ -21,7 +20,7 @@ import com.databricks.spark.sql.perf._
*
* It is assumed that the implementation is going to be an object.
*/
trait BenchmarkAlgorithm extends Logging {
trait BenchmarkAlgorithm {
def trainingDataSet(ctx: MLBenchContext): DataFrame

View File

@ -2,7 +2,7 @@ package com.databricks.spark.sql.perf.mllib
import com.databricks.spark.sql.perf.mllib.classification.LogisticRegression
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.{SQLContext,SparkSession}
import com.databricks.spark.sql.perf.{MLParams}
import OptionImplicits._
@ -27,8 +27,9 @@ object MLBenchmarks {
)
)
val context = SparkContext.getOrCreate()
val sqlContext: SQLContext = SQLContext.getOrCreate(context)
val sparkSession = SparkSession.builder.getOrCreate()
val sqlContext: SQLContext = sparkSession.sqlContext
val context = sqlContext.sparkContext
def benchmarkObjects: Seq[MLPipelineStageBenchmarkable] = benchmarks.map { mlb =>
new MLPipelineStageBenchmarkable(mlb.params, mlb.benchmark, sqlContext)

View File

@ -4,7 +4,7 @@ package com.databricks.spark.sql.perf.mllib
import scala.io.Source
import scala.language.implicitConversions
import com.typesafe.scalalogging.{LazyLogging => Logging}
import org.slf4j.LoggerFactory
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
@ -18,7 +18,7 @@ class MLLib(sqlContext: SQLContext)
def this() = this(SQLContext.getOrCreate(SparkContext.getOrCreate()))
}
object MLLib extends Logging {
object MLLib {
/**
* Runs a set of preprogrammed experiments and blocks on completion.
@ -26,6 +26,9 @@ object MLLib extends Logging {
* @param runConfig a configuration that is av
* @return
*/
lazy val logger = LoggerFactory.getLogger(this.getClass.getName)
def runDefault(runConfig: RunConfig): DataFrame = {
val ml = new MLLib()
val benchmarks = MLBenchmarks.benchmarkObjects

View File

@ -3,8 +3,6 @@ package com.databricks.spark.sql.perf.mllib
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import com.typesafe.scalalogging.{LazyLogging => Logging}
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.sql._
import org.apache.spark.{SparkContext, SparkEnv}
@ -15,7 +13,7 @@ class MLPipelineStageBenchmarkable(
params: MLParams,
test: BenchmarkAlgorithm,
sqlContext: SQLContext)
extends Benchmarkable with Serializable with Logging {
extends Benchmarkable with Serializable {
import MLPipelineStageBenchmarkable._

View File

@ -51,6 +51,7 @@ object NaiveBayes extends BenchmarkAlgorithm
// Initialize new Naive Bayes model
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(numClasses, numFeatures, thetaArray.flatten, true)
ModelBuilderSSP.newNaiveBayesModel(pi, theta)
}

View File

@ -0,0 +1,121 @@
/*
* Copyright 2015 Databricks Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.databricks.spark.sql.perf.tpcds
import org.apache.spark.sql.SparkSession
case class GenTPCDSDataConfig(
master: String = "local[*]",
dsdgenDir: String = null,
scaleFactor: String = null,
location: String = null,
format: String = null,
useDoubleForDecimal: Boolean = false,
useStringForDate: Boolean = false,
overwrite: Boolean = false,
partitionTables: Boolean = true,
clusterByPartitionColumns: Boolean = true,
filterOutNullPartitionValues: Boolean = true,
tableFilter: String = "",
numPartitions: Int = 100)
/**
* Gen TPCDS data.
* To run this:
* {{{
* build/sbt "test:runMain <this class> -d <dsdgenDir> -s <scaleFactor> -l <location> -f <format>"
* }}}
*/
object GenTPCDSData {
def main(args: Array[String]): Unit = {
val parser = new scopt.OptionParser[GenTPCDSDataConfig]("Gen-TPC-DS-data") {
opt[String]('m', "master")
.action { (x, c) => c.copy(master = x) }
.text("the Spark master to use, default to local[*]")
opt[String]('d', "dsdgenDir")
.action { (x, c) => c.copy(dsdgenDir = x) }
.text("location of dsdgen")
.required()
opt[String]('s', "scaleFactor")
.action((x, c) => c.copy(scaleFactor = x))
.text("scaleFactor defines the size of the dataset to generate (in GB)")
opt[String]('l', "location")
.action((x, c) => c.copy(location = x))
.text("root directory of location to create data in")
opt[String]('f', "format")
.action((x, c) => c.copy(format = x))
.text("valid spark format, Parquet, ORC ...")
opt[Boolean]('i', "useDoubleForDecimal")
.action((x, c) => c.copy(useDoubleForDecimal = x))
.text("true to replace DecimalType with DoubleType")
opt[Boolean]('e', "useStringForDate")
.action((x, c) => c.copy(useStringForDate = x))
.text("true to replace DateType with StringType")
opt[Boolean]('o', "overwrite")
.action((x, c) => c.copy(overwrite = x))
.text("overwrite the data that is already there")
opt[Boolean]('p', "partitionTables")
.action((x, c) => c.copy(partitionTables = x))
.text("create the partitioned fact tables")
opt[Boolean]('c', "clusterByPartitionColumns")
.action((x, c) => c.copy(clusterByPartitionColumns = x))
.text("shuffle to get partitions coalesced into single files")
opt[Boolean]('v', "filterOutNullPartitionValues")
.action((x, c) => c.copy(filterOutNullPartitionValues = x))
.text("true to filter out the partition with NULL key value")
opt[String]('t', "tableFilter")
.action((x, c) => c.copy(tableFilter = x))
.text("\"\" means generate all tables")
opt[Int]('n', "numPartitions")
.action((x, c) => c.copy(numPartitions = x))
.text("how many dsdgen partitions to run - number of input tasks.")
help("help")
.text("prints this usage text")
}
parser.parse(args, GenTPCDSDataConfig()) match {
case Some(config) =>
run(config)
case None =>
System.exit(1)
}
}
private def run(config: GenTPCDSDataConfig) {
val spark = SparkSession
.builder()
.appName(getClass.getName)
.master(config.master)
.getOrCreate()
val tables = new TPCDSTables(spark.sqlContext,
dsdgenDir = config.dsdgenDir,
scaleFactor = config.scaleFactor,
useDoubleForDecimal = config.useDoubleForDecimal,
useStringForDate = config.useStringForDate)
tables.genData(
location = config.location,
format = config.format,
overwrite = config.overwrite,
partitionTables = config.partitionTables,
clusterByPartitionColumns = config.clusterByPartitionColumns,
filterOutNullPartitionValues = config.filterOutNullPartitionValues,
tableFilter = config.tableFilter,
numPartitions = config.numPartitions)
}
}

View File

@ -17,10 +17,9 @@
package com.databricks.spark.sql.perf.tpcds
import scala.collection.mutable
import com.databricks.spark.sql.perf._
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.{SQLContext, SparkSession}
/**
* TPC-DS benchmark's dataset.
@ -35,7 +34,7 @@ class TPCDS(@transient sqlContext: SQLContext)
with Tpcds_2_4_Queries
with Serializable {
def this() = this(SQLContext.getOrCreate(SparkContext.getOrCreate()))
def this() = this(SparkSession.builder.getOrCreate().sqlContext)
/*
def setupBroadcast(skipTables: Seq[String] = Seq("store_sales", "customer")) = {

View File

@ -539,4 +539,6 @@ class TPCDSTables(
'web_gmt_offset .decimal(5,2),
'web_tax_percentage .decimal(5,2))
).map(_.convertTypes())
}

View File

@ -45,7 +45,6 @@ object ModelBuilderSSP {
s" but was given $numClasses")
val rootNode = TreeBuilder.randomBalancedDecisionTree(depth = depth, labelType = numClasses,
featureArity = featureArity, seed = seed)
.asInstanceOf[ClassificationNode]
new DecisionTreeClassificationModel(rootNode, numFeatures = featureArity.length,
numClasses = numClasses)
}
@ -56,12 +55,11 @@ object ModelBuilderSSP {
seed: Long): DecisionTreeRegressionModel = {
val rootNode = TreeBuilder.randomBalancedDecisionTree(depth = depth, labelType = 0,
featureArity = featureArity, seed = seed)
.asInstanceOf[RegressionNode]
new DecisionTreeRegressionModel(rootNode, numFeatures = featureArity.length)
}
def newNaiveBayesModel(pi: Vector, theta: Matrix): NaiveBayesModel = {
val model = new NaiveBayesModel("naivebayes-uid", pi, theta)
val model = new NaiveBayesModel("naivebayes-uid", pi, theta, null)
model.set(model.modelType, "multinomial")
}
@ -162,50 +160,17 @@ object TreeBuilder {
labelGenerator.setSeed(rng.nextLong)
// We use a dummy impurityCalculator for all nodes.
val impurityCalculator = if (isRegression) {
ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0))
ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0), 0L)
} else {
ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0))
ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0), 0L)
}
randomBalancedDecisionTreeHelper(isRegression, depth, featureArity, impurityCalculator,
randomBalancedDecisionTreeHelper(depth, featureArity, impurityCalculator,
labelGenerator, Set.empty, rng)
}
private def createLeafNode(
isRegression: Boolean,
prediction: Double,
impurity: Double,
impurityStats: ImpurityCalculator): LeafNode = {
if (isRegression) {
new RegressionLeafNode(prediction, impurity, impurityStats)
} else {
new ClassificationLeafNode(prediction, impurity, impurityStats)
}
}
private def createInternalNode(
isRegression: Boolean,
prediction: Double,
impurity: Double,
gain: Double,
leftChild: Node,
rightChild: Node,
split: Split,
impurityStats: ImpurityCalculator): InternalNode = {
if (isRegression) {
new RegressionInternalNode(prediction, impurity, gain,
leftChild.asInstanceOf[RegressionNode], rightChild.asInstanceOf[RegressionNode],
split, impurityStats)
} else {
new ClassificationInternalNode(prediction, impurity, gain,
leftChild.asInstanceOf[ClassificationNode], rightChild.asInstanceOf[ClassificationNode],
split, impurityStats)
}
}
/**
* Create an internal node. Either create the leaf nodes beneath it, or recurse as needed.
* @param isRegression Whether the tree is a regressor or not (classifier)
* @param subtreeDepth Depth of subtree to build. Depth 0 means this is a leaf node.
* @param featureArity Indicates feature type. Value 0 indicates continuous feature.
* Other values >= 2 indicate a categorical feature,
@ -217,7 +182,6 @@ object TreeBuilder {
* @return
*/
private def randomBalancedDecisionTreeHelper(
isRegression: Boolean,
subtreeDepth: Int,
featureArity: Array[Int],
impurityCalculator: ImpurityCalculator,
@ -227,7 +191,7 @@ object TreeBuilder {
if (subtreeDepth == 0) {
// This case only happens for a depth 0 tree.
createLeafNode(isRegression, prediction = 0.0, impurity = 0.0, impurityCalculator)
return new LeafNode(prediction = 0.0, impurity = 0.0, impurityStats = impurityCalculator)
}
val numFeatures = featureArity.length
@ -257,20 +221,19 @@ object TreeBuilder {
val (leftChild: Node, rightChild: Node) = if (subtreeDepth == 1) {
// Add leaf nodes. Assign these jointly so they make different predictions.
val predictions = labelGenerator.nextValue()
val leftChild = createLeafNode(isRegression, prediction = predictions._1, impurity = 0.0,
val leftChild = new LeafNode(prediction = predictions._1, impurity = 0.0,
impurityStats = impurityCalculator)
val rightChild = createLeafNode(isRegression, prediction = predictions._2, impurity = 0.0,
val rightChild = new LeafNode(prediction = predictions._2, impurity = 0.0,
impurityStats = impurityCalculator)
(leftChild, rightChild)
} else {
val leftChild = randomBalancedDecisionTreeHelper(isRegression, subtreeDepth - 1, featureArity,
val leftChild = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity,
impurityCalculator, labelGenerator, usedFeatures + feature, rng)
val rightChild = randomBalancedDecisionTreeHelper(isRegression, subtreeDepth - 1, featureArity,
val rightChild = randomBalancedDecisionTreeHelper(subtreeDepth - 1, featureArity,
impurityCalculator, labelGenerator, usedFeatures + feature, rng)
(leftChild, rightChild)
}
createInternalNode(isRegression, prediction = 0.0, impurity = 0.0, gain = 0.0,
leftChild = leftChild, rightChild = rightChild, split = split,
impurityStats = impurityCalculator)
new InternalNode(prediction = 0.0, impurity = 0.0, gain = 0.0, leftChild = leftChild,
rightChild = rightChild, split = split, impurityStats = impurityCalculator)
}
}

View File

@ -1,11 +1,9 @@
package com.databricks.spark.sql.perf
import org.apache.spark.sql.hive.test.TestHive
import org.scalatest.FunSuite
class DatasetPerformanceSuite extends FunSuite {
ignore("run benchmark") {
TestHive // Init HiveContext
val benchmark = new DatasetPerformance() {
override val numLongs = 100
}

View File

@ -2,18 +2,34 @@ package com.databricks.spark.sql.perf.mllib
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
class MLLibSuite extends FunSuite with BeforeAndAfterAll {
private var sparkSession: SparkSession = _
var savedLevels: Map[String, Level] = _
override def beforeAll(): Unit = {
super.beforeAll()
sparkSession = SparkSession.builder.master("local[2]").appName("MLlib QA").getOrCreate()
// Travis limits the size of the log file produced by a build. Because we do run a small
// version of all the ML benchmarks in this suite, we produce a ton of logs. Here we set the
// log level to ERROR, just for this suite, to avoid displeasing travis.
savedLevels = Seq("akka", "org", "com.databricks").map { name =>
val logger = Logger.getLogger(name)
val curLevel = logger.getLevel
logger.setLevel(Level.ERROR)
name -> curLevel
}.toMap
}
override def afterAll(): Unit = {
savedLevels.foreach { case (name, level) =>
Logger.getLogger(name).setLevel(level)
}
try {
if (sparkSession != null) {
sparkSession.stop()

View File

@ -1 +1 @@
version in ThisBuild := "0.5.0-SNAPSHOT"
version in ThisBuild := "0.5.1-SNAPSHOT"