Update for Spark 3.0.0 compatibility (#191)

* Updating the build file to spark 3.0.0 and scala 2.12.10
* Fixing incompatibilities
* Adding default parameters to newer required functions
* Removing HiveTest
This commit is contained in:
Nico Poggi 2020-11-03 15:27:34 +01:00 committed by GitHub
parent 6b2bf9f9ad
commit d85f75bb38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 13 additions and 16 deletions

View File

@@ -5,16 +5,16 @@ name := "spark-sql-perf"
organization := "com.databricks"
scalaVersion := "2.11.12"
scalaVersion := "2.12.10"
crossScalaVersions := Seq("2.11.12","2.12.8")
crossScalaVersions := Seq("2.12.10", "2.11.12")
sparkPackageName := "databricks/spark-sql-perf"
// All Spark Packages need a license
licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0"))
sparkVersion := "2.4.0"
sparkVersion := "3.0.0"
sparkComponents ++= Seq("sql", "hive", "mllib")

View File

@@ -335,7 +335,7 @@ object Benchmark {
.flatMap { query =>
try {
query.newDataFrame().queryExecution.logical.collect {
case UnresolvedRelation(t) => t.table
case r: UnresolvedRelation => r.tableName
}
} catch {
// ignore the queries that can't be parsed

View File

@@ -77,7 +77,7 @@ trait Benchmarkable {
try {
result = doBenchmark(includeBreakdown, description, messages)
} catch {
case NonFatal(e) =>
case e: Throwable =>
logger.info(s"$that: failure in runBenchmark: $e")
println(s"$that: failure in runBenchmark: $e")
result = BenchmarkResult(

View File

@@ -54,10 +54,7 @@ class Query(
}
lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect {
case UnresolvedRelation(tableIdentifier) => {
// We are ignoring the database name.
tableIdentifier.table
}
case r: UnresolvedRelation => r.tableName
}
def newDataFrame() = buildDataFrame
@@ -88,10 +88,11 @@ class Query(
val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan.p(i)))
val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap
val timeMap = new mutable.HashMap[Int, Double]
val maxFields = 999 // Maximum number of fields that will be converted to strings
physicalOperators.reverse.map {
case (index, node) =>
messages += s"Breakdown: ${node.simpleString}"
messages += s"Breakdown: ${node.simpleString(maxFields)}"
val newNode = buildDataFrame.queryExecution.executedPlan.p(index)
val executionTime = measureTimeMs {
newNode.execute().foreach((row: Any) => Unit)
@@ -104,7 +104,7 @@ class Query(
BreakdownResult(
node.nodeName,
node.simpleString.replaceAll("#\\d+", ""),
node.simpleString(maxFields).replaceAll("#\\d+", ""),
index,
childIndexes,
executionTime,

View File

@@ -51,6 +51,7 @@ object NaiveBayes extends BenchmarkAlgorithm
// Initialize new Naive Bayes model
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(numClasses, numFeatures, thetaArray.flatten, true)
ModelBuilderSSP.newNaiveBayesModel(pi, theta)
}

View File

@@ -59,7 +59,7 @@ object ModelBuilderSSP {
}
def newNaiveBayesModel(pi: Vector, theta: Matrix): NaiveBayesModel = {
val model = new NaiveBayesModel("naivebayes-uid", pi, theta)
val model = new NaiveBayesModel("naivebayes-uid", pi, theta, null)
model.set(model.modelType, "multinomial")
}
@@ -160,9 +160,9 @@ object TreeBuilder {
labelGenerator.setSeed(rng.nextLong)
// We use a dummy impurityCalculator for all nodes.
val impurityCalculator = if (isRegression) {
ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0))
ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0), 0L)
} else {
ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0))
ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0), 0L)
}
randomBalancedDecisionTreeHelper(depth, featureArity, impurityCalculator,

View File

@@ -1,11 +1,9 @@
package com.databricks.spark.sql.perf
import org.apache.spark.sql.hive.test.TestHive
import org.scalatest.FunSuite
class DatasetPerformanceSuite extends FunSuite {
ignore("run benchmark") {
TestHive // Init HiveContext
val benchmark = new DatasetPerformance() {
override val numLongs = 100
}