diff --git a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/KyuubiExtensionSuite.scala b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/KyuubiExtensionSuite.scala
index 00d47cabe..85d17dfeb 100644
--- a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/KyuubiExtensionSuite.scala
+++ b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/KyuubiExtensionSuite.scala
@@ -21,67 +21,16 @@ import scala.collection.mutable.Set
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Multiply}
 import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression
-import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, CustomShuffleReaderExec, QueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{CustomShuffleReaderExec, QueryStageExec}
 import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeLike}
 import org.apache.spark.sql.hive.HiveUtils
 import org.apache.spark.sql.hive.execution.OptimizedCreateHiveTableAsSelectCommand
-import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.sql.test.SQLTestData.TestData
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.kyuubi.sql.{FinalStageConfigIsolation, KyuubiSQLConf}
 import org.apache.kyuubi.sql.watchdog.MaxHivePartitionExceedException
 
-class KyuubiExtensionSuite extends QueryTest with SQLTestUtils with AdaptiveSparkPlanHelper {
-
-  var _spark: SparkSession = _
-  override def spark: SparkSession = _spark
-
-  protected override def beforeAll(): Unit = {
-    _spark = SparkSession.builder()
-      .master("local[1]")
-      .config(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key,
-        "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
-      .config(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
-      .config("spark.hadoop.hive.exec.dynamic.partition.mode", "nonstrict")
-      .config("spark.hadoop.hive.metastore.client.capability.check", "false")
-      .config("spark.ui.enabled", "false")
-      .enableHiveSupport()
-      .getOrCreate()
-    setupData()
-    super.beforeAll()
-  }
-
-  protected override def afterAll(): Unit = {
-    super.afterAll()
-    cleanupData()
-    if (_spark != null) {
-      _spark.stop()
-    }
-    Utils.deleteRecursively(new java.io.File("spark-warehouse"))
-    Utils.deleteRecursively(new java.io.File("metastore_db"))
-  }
-
-  private def setupData(): Unit = {
-    val self = _spark
-    import self.implicits._
-    spark.sparkContext.parallelize(
-      (1 to 100).map(i => TestData(i, i.toString)), 10)
-      .toDF("c1", "c2").createOrReplaceTempView("t1")
-    spark.sparkContext.parallelize(
-      (1 to 10).map(i => TestData(i, i.toString)), 5)
-      .toDF("c1", "c2").createOrReplaceTempView("t2")
-    spark.sparkContext.parallelize(
-      (1 to 50).map(i => TestData(i, i.toString)), 2)
-      .toDF("c1", "c2").createOrReplaceTempView("t3")
-  }
-
-  private def cleanupData(): Unit = {
-    spark.sql("DROP VIEW IF EXISTS t1")
-    spark.sql("DROP VIEW IF EXISTS t2")
-    spark.sql("DROP VIEW IF EXISTS t3")
-  }
+class KyuubiExtensionSuite extends KyuubiSparkSQLExtensionTest {
 
   test("check repartition exists") {
     def check(df: DataFrame): Unit = {
diff --git a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala
new file mode 100644
index 000000000..d24e3fe62
--- /dev/null
+++ b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/KyuubiSparkSQLExtensionTest.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.test.SQLTestData.TestData
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
+
+trait KyuubiSparkSQLExtensionTest extends QueryTest
+  with SQLTestUtils
+  with AdaptiveSparkPlanHelper {
+  private var _spark: Option[SparkSession] = None
+  protected def spark: SparkSession = _spark.getOrElse{
+    throw new RuntimeException("test spark session was not initialized before using it.")
+  }
+
+  protected override def beforeAll(): Unit = {
+    if (_spark.isEmpty) {
+      _spark = Option(SparkSession.builder()
+        .master("local[1]")
+        .config(sparkConf)
+        .enableHiveSupport()
+        .getOrCreate())
+    }
+    setupData()
+    super.beforeAll()
+  }
+
+  protected override def afterAll(): Unit = {
+    super.afterAll()
+    cleanupData()
+    _spark.foreach(_.stop)
+  }
+
+  private def setupData(): Unit = {
+    val self = spark
+    import self.implicits._
+    spark.sparkContext.parallelize(
+      (1 to 100).map(i => TestData(i, i.toString)), 10)
+      .toDF("c1", "c2").createOrReplaceTempView("t1")
+    spark.sparkContext.parallelize(
+      (1 to 10).map(i => TestData(i, i.toString)), 5)
+      .toDF("c1", "c2").createOrReplaceTempView("t2")
+    spark.sparkContext.parallelize(
+      (1 to 50).map(i => TestData(i, i.toString)), 2)
+      .toDF("c1", "c2").createOrReplaceTempView("t3")
+  }
+
+  private def cleanupData(): Unit = {
+    spark.sql("DROP VIEW IF EXISTS t1")
+    spark.sql("DROP VIEW IF EXISTS t2")
+    spark.sql("DROP VIEW IF EXISTS t3")
+  }
+
+  def sparkConf(): SparkConf = {
+    val basePath = Utils.createTempDir() + "/" + getClass.getCanonicalName
+    val metastorePath = basePath + "/metastore_db"
+    val warehousePath = basePath + "/warehouse"
+    new SparkConf()
+      .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key,
+        "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
+      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+      .set("spark.hadoop.hive.exec.dynamic.partition.mode", "nonstrict")
+      .set("spark.hadoop.hive.metastore.client.capability.check", "false")
+      .set(ConfVars.METASTORECONNECTURLKEY.varname,
+        s"jdbc:derby:;databaseName=$metastorePath;create=true")
+      .set(StaticSQLConf.WAREHOUSE_PATH, warehousePath)
+      .set("spark.ui.enabled", "false")
+  }
+}
diff --git a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/ZorderSuite.scala b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/ZorderSuite.scala
index f7e72156f..e5ab0712c 100644
--- a/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/ZorderSuite.scala
+++ b/dev/kyuubi-extension-spark-3-1/src/test/scala/org/apache/spark/sql/ZorderSuite.scala
@@ -21,50 +21,17 @@ import org.apache.spark.SparkConf
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Expression, ExpressionEvalHelper, Literal, NullsLast, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project, Sort}
-import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
 import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.execution.{CreateHiveTableAsSelectCommand, InsertIntoHiveTable, OptimizedCreateHiveTableAsSelectCommand}
-import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 import org.apache.kyuubi.sql.{KyuubiSQLConf, KyuubiSQLExtensionException}
 import org.apache.kyuubi.sql.zorder.Zorder
 
-trait ZorderSuite extends QueryTest
-  with SQLTestUtils
-  with AdaptiveSparkPlanHelper
-  with ExpressionEvalHelper {
-
-  var _spark: SparkSession = _
-  override def spark: SparkSession = _spark
-
-  protected override def beforeAll(): Unit = {
-    _spark = SparkSession.builder()
-      .master("local[1]")
-      .config(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key,
-        "org.apache.kyuubi.sql.KyuubiSparkSQLExtension")
-      .config(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
-      .config("spark.hadoop.hive.exec.dynamic.partition.mode", "nonstrict")
-      .config("spark.hadoop.hive.metastore.client.capability.check", "false")
-      .config("spark.ui.enabled", "false")
-      .config(sparkConf)
-      .enableHiveSupport()
-      .getOrCreate()
-    super.beforeAll()
-  }
-
-  protected override def afterAll(): Unit = {
-    super.afterAll()
-    if (_spark != null) {
-      _spark.stop()
-    }
-    Utils.deleteRecursively(new java.io.File("spark-warehouse"))
-    Utils.deleteRecursively(new java.io.File("metastore_db"))
-  }
+trait ZorderSuite extends KyuubiSparkSQLExtensionTest with ExpressionEvalHelper {
 
   test("optimize unpartitioned table") {
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
@@ -431,13 +398,11 @@ trait ZorderSuite extends QueryTest
       Row(2, 1) :: Row(0, 2) :: Row(1, 2) :: Row(2, 2) :: Nil
     checkSort(input2, expected2)
   }
-
-  def sparkConf(): SparkConf
 }
 
 class ZorderWithCodegenEnabledSuite extends ZorderSuite {
   override def sparkConf(): SparkConf = {
-    val conf = new SparkConf()
+    val conf = super.sparkConf
     conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true")
     conf
   }
@@ -445,7 +410,7 @@ class ZorderWithCodegenEnabledSuite extends ZorderSuite {
 
 class ZorderWithCodegenDisabledSuite extends ZorderSuite {
   override def sparkConf(): SparkConf = {
-    val conf = new SparkConf()
+    val conf = super.sparkConf
     conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false")
     conf.set(SQLConf.CODEGEN_FACTORY_MODE.key, "NO_CODEGEN")
     conf