From d85f75bb38a11b10562d2fed7e3ee8e78ee2f97d Mon Sep 17 00:00:00 2001 From: Nico Poggi Date: Tue, 3 Nov 2020 15:27:34 +0100 Subject: [PATCH] Update for Spark 3.0.0 compatibility (#191) * Updating the build file to spark 3.0.0 and scala 2.12.10 * Fixing incompatibilities * Adding default parameters to newer required functions * Removing HiveTest --- build.sbt | 6 +++--- .../com/databricks/spark/sql/perf/Benchmark.scala | 2 +- .../com/databricks/spark/sql/perf/Benchmarkable.scala | 2 +- .../scala/com/databricks/spark/sql/perf/Query.scala | 10 ++++------ .../sql/perf/mllib/classification/NaiveBayes.scala | 1 + .../scala/org/apache/spark/ml/ModelBuilderSSP.scala | 6 +++--- .../spark/sql/perf/DatasetPerformanceSuite.scala | 2 -- 7 files changed, 13 insertions(+), 16 deletions(-) diff --git a/build.sbt b/build.sbt index 13a091c..3cc42a5 100644 --- a/build.sbt +++ b/build.sbt @@ -5,16 +5,16 @@ name := "spark-sql-perf" organization := "com.databricks" -scalaVersion := "2.11.12" +scalaVersion := "2.12.10" -crossScalaVersions := Seq("2.11.12","2.12.8") +crossScalaVersions := Seq("2.12.10", "2.11.12") sparkPackageName := "databricks/spark-sql-perf" // All Spark Packages need a license licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")) -sparkVersion := "2.4.0" +sparkVersion := "3.0.0" sparkComponents ++= Seq("sql", "hive", "mllib") diff --git a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala index 7d85595..ebb4935 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Benchmark.scala @@ -335,7 +335,7 @@ object Benchmark { .flatMap { query => try { query.newDataFrame().queryExecution.logical.collect { - case UnresolvedRelation(t) => t.table + case r: UnresolvedRelation => r.tableName } } catch { // ignore the queries that can't be parsed diff --git a/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala b/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala index 9df2a1d..24efef7 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Benchmarkable.scala @@ -77,7 +77,7 @@ trait Benchmarkable { try { result = doBenchmark(includeBreakdown, description, messages) } catch { - case NonFatal(e) => + case e: Throwable => logger.info(s"$that: failure in runBenchmark: $e") println(s"$that: failure in runBenchmark: $e") result = BenchmarkResult( diff --git a/src/main/scala/com/databricks/spark/sql/perf/Query.scala b/src/main/scala/com/databricks/spark/sql/perf/Query.scala index 16cd907..babc63f 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/Query.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/Query.scala @@ -54,10 +54,7 @@ class Query( } lazy val tablesInvolved = buildDataFrame.queryExecution.logical collect { - case UnresolvedRelation(tableIdentifier) => { - // We are ignoring the database name. - tableIdentifier.table - } + case r: UnresolvedRelation => r.tableName } def newDataFrame() = buildDataFrame @@ -88,10 +85,11 @@ class Query( val physicalOperators = (0 until depth).map(i => (i, queryExecution.executedPlan.p(i))) val indexMap = physicalOperators.map { case (index, op) => (op, index) }.toMap val timeMap = new mutable.HashMap[Int, Double] + val maxFields = 999 // Maximum number of fields that will be converted to strings physicalOperators.reverse.map { case (index, node) => - messages += s"Breakdown: ${node.simpleString}" + messages += s"Breakdown: ${node.simpleString(maxFields)}" val newNode = buildDataFrame.queryExecution.executedPlan.p(index) val executionTime = measureTimeMs { newNode.execute().foreach((row: Any) => Unit) @@ -104,7 +102,7 @@ class Query( BreakdownResult( node.nodeName, - node.simpleString.replaceAll("#\\d+", ""), + node.simpleString(maxFields).replaceAll("#\\d+", ""), index, childIndexes, executionTime, diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/NaiveBayes.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/NaiveBayes.scala index cb527ce..6d648f5 100644 --- a/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/NaiveBayes.scala +++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/classification/NaiveBayes.scala @@ -51,6 +51,7 @@ object NaiveBayes extends BenchmarkAlgorithm // Initialize new Naive Bayes model val pi = Vectors.dense(piArray) val theta = new DenseMatrix(numClasses, numFeatures, thetaArray.flatten, true) + ModelBuilderSSP.newNaiveBayesModel(pi, theta) } diff --git a/src/main/scala/org/apache/spark/ml/ModelBuilderSSP.scala b/src/main/scala/org/apache/spark/ml/ModelBuilderSSP.scala index 511b645..fa66e00 100644 --- a/src/main/scala/org/apache/spark/ml/ModelBuilderSSP.scala +++ b/src/main/scala/org/apache/spark/ml/ModelBuilderSSP.scala @@ -59,7 +59,7 @@ object ModelBuilderSSP { } def newNaiveBayesModel(pi: Vector, theta: Matrix): NaiveBayesModel = { - val model = new NaiveBayesModel("naivebayes-uid", pi, theta) + val model = new NaiveBayesModel("naivebayes-uid", pi, theta, null) model.set(model.modelType, "multinomial") } @@ -160,9 +160,9 @@ object TreeBuilder { labelGenerator.setSeed(rng.nextLong) // We use a dummy impurityCalculator for all nodes. val impurityCalculator = if (isRegression) { - ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0)) + ImpurityCalculator.getCalculator("variance", Array.fill[Double](3)(0.0), 0L) } else { - ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0)) + ImpurityCalculator.getCalculator("gini", Array.fill[Double](labelType)(0.0), 0L) } randomBalancedDecisionTreeHelper(depth, featureArity, impurityCalculator, diff --git a/src/test/scala/com/databricks/spark/sql/perf/DatasetPerformanceSuite.scala b/src/test/scala/com/databricks/spark/sql/perf/DatasetPerformanceSuite.scala index 3ea68ea..f0b936d 100644 --- a/src/test/scala/com/databricks/spark/sql/perf/DatasetPerformanceSuite.scala +++ b/src/test/scala/com/databricks/spark/sql/perf/DatasetPerformanceSuite.scala @@ -1,11 +1,9 @@ package com.databricks.spark.sql.perf -import org.apache.spark.sql.hive.test.TestHive import org.scalatest.FunSuite class DatasetPerformanceSuite extends FunSuite { ignore("run benchmark") { - TestHive // Init HiveContext val benchmark = new DatasetPerformance() { override val numLongs = 100 }