diff --git a/extensions/spark/kyuubi-spark-connector-hive/src/main/scala/org/apache/kyuubi/spark/connector/hive/HiveConnectorUtils.scala b/extensions/spark/kyuubi-spark-connector-hive/src/main/scala/org/apache/kyuubi/spark/connector/hive/HiveConnectorUtils.scala index f56aa977b..371d79abe 100644 --- a/extensions/spark/kyuubi-spark-connector-hive/src/main/scala/org/apache/kyuubi/spark/connector/hive/HiveConnectorUtils.scala +++ b/extensions/spark/kyuubi-spark-connector-hive/src/main/scala/org/apache/kyuubi/spark/connector/hive/HiveConnectorUtils.scala @@ -18,6 +18,7 @@ package org.apache.kyuubi.spark.connector.hive import java.lang.{Boolean => JBoolean, Long => JLong} +import java.net.URI import scala.util.Try @@ -25,12 +26,11 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.command.CommandUtils -import org.apache.spark.sql.execution.command.CommandUtils.{calculateMultipleLocationSizes, calculateSingleLocationSize} import org.apache.spark.sql.execution.datasources.{PartitionDirectory, PartitionedFile} import org.apache.spark.sql.hive.execution.HiveFileFormat import org.apache.spark.sql.internal.SQLConf @@ -82,7 +82,28 @@ object HiveConnectorUtils extends Logging { isSplitable: JBoolean, maxSplitBytes: JLong, partitionValues: InternalRow): Seq[PartitionedFile] = - Try { // SPARK-42821: 4.0.0-preview2 + Try { // SPARK-42821, SPARK-51185: Spark 4.0 + val fileStatusWithMetadataClz = DynClasses.builder() + .impl("org.apache.spark.sql.execution.datasources.FileStatusWithMetadata") + .buildChecked() + DynMethods + .builder("splitFiles") + .impl( + "org.apache.spark.sql.execution.PartitionedFileUtil", + fileStatusWithMetadataClz, + classOf[Path], + classOf[Boolean], + classOf[Long], + classOf[InternalRow]) + .buildChecked() + .invokeChecked[Seq[PartitionedFile]]( + null, + file, + filePath, + isSplitable, + maxSplitBytes, + partitionValues) + }.recover { case _: Exception => // SPARK-42821: 4.0.0-preview2 val fileStatusWithMetadataClz = DynClasses.builder() .impl("org.apache.spark.sql.execution.datasources.FileStatusWithMetadata") .buildChecked() @@ -192,6 +213,29 @@ object HiveConnectorUtils extends Logging { file.asInstanceOf[FileStatus].getPath }.get + private def calculateMultipleLocationSizes( + sparkSession: SparkSession, + tid: TableIdentifier, + paths: Seq[Option[URI]]): Seq[Long] = { + + val sparkSessionClz = DynClasses.builder() + .impl("org.apache.spark.sql.classic.SparkSession") // SPARK-49700 (4.0.0) + .impl("org.apache.spark.sql.SparkSession") + .build() + + val calculateMultipleLocationSizesMethod = + DynMethods.builder("calculateMultipleLocationSizes") + .impl( + CommandUtils.getClass, + sparkSessionClz, + classOf[TableIdentifier], + classOf[Seq[Option[URI]]]) + .buildChecked(CommandUtils) + + calculateMultipleLocationSizesMethod + .invokeChecked[Seq[Long]](sparkSession, tid, paths) + } + def calculateTotalSize( spark: SparkSession, catalogTable: CatalogTable, @@ -199,12 +243,11 @@ object HiveConnectorUtils extends Logging { val sessionState = spark.sessionState val startTime = System.nanoTime() val (totalSize, newPartitions) = if (catalogTable.partitionColumnNames.isEmpty) { - ( - calculateSingleLocationSize( - sessionState, - catalogTable.identifier, - catalogTable.storage.locationUri), - Seq()) + val tableSize = CommandUtils.calculateSingleLocationSize( + sessionState, + catalogTable.identifier, + catalogTable.storage.locationUri) + (tableSize, Seq()) } else { // Calculate table size as a sum of the visible partitions. See SPARK-21079 val partitions = hiveTableCatalog.listPartitions(catalogTable.identifier) diff --git a/extensions/spark/kyuubi-spark-connector-hive/src/main/scala/org/apache/kyuubi/spark/connector/hive/HiveTableCatalog.scala b/extensions/spark/kyuubi-spark-connector-hive/src/main/scala/org/apache/kyuubi/spark/connector/hive/HiveTableCatalog.scala index 91088d787..f72881f92 100644 --- a/extensions/spark/kyuubi-spark-connector-hive/src/main/scala/org/apache/kyuubi/spark/connector/hive/HiveTableCatalog.scala +++ b/extensions/spark/kyuubi-spark-connector-hive/src/main/scala/org/apache/kyuubi/spark/connector/hive/HiveTableCatalog.scala @@ -17,17 +17,19 @@ package org.apache.kyuubi.spark.connector.hive +import java.lang.{Boolean => JBoolean, Long => JLong} import java.net.URI import java.util import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.{SQLConfHelper, TableIdentifier} +import org.apache.spark.sql.catalyst.{CurrentUserContext, SQLConfHelper, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -47,6 +49,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.kyuubi.spark.connector.hive.HiveConnectorUtils.withSparkSQLConf import org.apache.kyuubi.spark.connector.hive.HiveTableCatalog.{getStorageFormatAndProvider, toCatalogDatabase, CatalogDatabaseHelper, IdentifierHelper, NamespaceHelper} import org.apache.kyuubi.spark.connector.hive.KyuubiHiveConnectorDelegationTokenProvider.metastoreTokenSignature +import org.apache.kyuubi.util.reflect.{DynClasses, DynConstructors} /** * A [[TableCatalog]] that wrap HiveExternalCatalog to as V2 CatalogPlugin instance to access Hive. @@ -100,6 +103,20 @@ class HiveTableCatalog(sparkSession: SparkSession) catalogName } + private def newHiveMetastoreCatalog(sparkSession: SparkSession): HiveMetastoreCatalog = { + val sparkSessionClz = DynClasses.builder() + .impl("org.apache.spark.sql.classic.SparkSession") // SPARK-49700 (4.0.0) + .impl("org.apache.spark.sql.SparkSession") + .buildChecked() + + val hiveMetastoreCatalogCtor = + DynConstructors.builder() + .impl("org.apache.spark.sql.hive.HiveMetastoreCatalog", sparkSessionClz) + .buildChecked[HiveMetastoreCatalog]() + + hiveMetastoreCatalogCtor.newInstanceChecked(sparkSession) + } + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { assert(catalogName == null, "The Hive table catalog is already initialed.") assert( @@ -110,7 +127,7 @@ class HiveTableCatalog(sparkSession: SparkSession) catalog = new HiveSessionCatalog( externalCatalogBuilder = () => externalCatalog, globalTempViewManagerBuilder = () => globalTempViewManager, - metastoreCatalog = new HiveMetastoreCatalog(sparkSession), + metastoreCatalog = newHiveMetastoreCatalog(sparkSession), functionRegistry = sessionState.functionRegistry, tableFunctionRegistry = sessionState.tableFunctionRegistry, hadoopConf = hadoopConf, @@ -166,6 +183,129 @@ class HiveTableCatalog(sparkSession: SparkSession) HiveTable(sparkSession, catalog.getTableMetadata(ident.asTableIdentifier), this) } + // scalastyle:off + private def newCatalogTable( + identifier: TableIdentifier, + tableType: CatalogTableType, + storage: CatalogStorageFormat, + schema: StructType, + provider: Option[String] = None, + partitionColumnNames: Seq[String] = Seq.empty, + bucketSpec: Option[BucketSpec] = None, + owner: String = Option(CurrentUserContext.CURRENT_USER.get()).getOrElse(""), + createTime: JLong = System.currentTimeMillis, + lastAccessTime: JLong = -1, + createVersion: String = "", + properties: Map[String, String] = Map.empty, + stats: Option[CatalogStatistics] = None, + viewText: Option[String] = None, + comment: Option[String] = None, + collation: Option[String] = None, + unsupportedFeatures: Seq[String] = Seq.empty, + tracksPartitionsInCatalog: JBoolean = false, + schemaPreservesCase: JBoolean = true, + ignoredProperties: Map[String, String] = Map.empty, + viewOriginalText: Option[String] = None): CatalogTable = { + // scalastyle:on + Try { // SPARK-50675 (4.0.0) + DynConstructors.builder() + .impl( + classOf[CatalogTable], + classOf[TableIdentifier], + classOf[CatalogTableType], + classOf[CatalogStorageFormat], + classOf[StructType], + classOf[Option[String]], + classOf[Seq[String]], + classOf[Option[BucketSpec]], + classOf[String], + classOf[Long], + classOf[Long], + classOf[String], + classOf[Map[String, String]], + classOf[Option[CatalogStatistics]], + classOf[Option[String]], + classOf[Option[String]], + classOf[Option[String]], + classOf[Seq[String]], + classOf[Boolean], + classOf[Boolean], + classOf[Map[String, String]], + classOf[Option[String]]) + .buildChecked() + .invokeChecked[CatalogTable]( + null, + identifier, + tableType, + storage, + schema, + provider, + partitionColumnNames, + bucketSpec, + owner, + createTime, + lastAccessTime, + createVersion, + properties, + stats, + viewText, + comment, + collation, + unsupportedFeatures, + tracksPartitionsInCatalog, + schemaPreservesCase, + ignoredProperties, + viewOriginalText) + }.recover { case _: Exception => // Spark 3.5 and previous + DynConstructors.builder() + .impl( + classOf[CatalogTable], + classOf[TableIdentifier], + classOf[CatalogTableType], + classOf[CatalogStorageFormat], + classOf[StructType], + classOf[Option[String]], + classOf[Seq[String]], + classOf[Option[BucketSpec]], + classOf[String], + classOf[Long], + classOf[Long], + classOf[String], + classOf[Map[String, String]], + classOf[Option[CatalogStatistics]], + classOf[Option[String]], + classOf[Option[String]], + classOf[Seq[String]], + classOf[Boolean], + classOf[Boolean], + classOf[Map[String, String]], + classOf[Option[String]]) + .buildChecked() + .invokeChecked[CatalogTable]( + null, + identifier, + tableType, + storage, + schema, + provider, + partitionColumnNames, + bucketSpec, + owner, + createTime, + lastAccessTime, + createVersion, + properties, + stats, + viewText, + comment, + unsupportedFeatures, + tracksPartitionsInCatalog, + schemaPreservesCase, + ignoredProperties, + viewOriginalText) + }.get + } + override def createTable( ident: Identifier, schema: StructType, @@ -190,7 +330,7 @@ class HiveTableCatalog(sparkSession: SparkSession) CatalogTableType.MANAGED } - val tableDesc = CatalogTable( + val tableDesc = newCatalogTable( identifier = ident.asTableIdentifier, tableType = tableType, storage = storage, diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/SparkPlanHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/SparkPlanHelper.scala new file mode 100644 index 000000000..be0eb02a1 --- /dev/null +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/SparkPlanHelper.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.SparkSession + +import org.apache.kyuubi.util.reflect.DynMethods + +object SparkPlanHelper { + + private val sparkSessionMethod = DynMethods.builder("session") + .impl(classOf[SparkPlan]) + .buildChecked() + + def sparkSession(sparkPlan: SparkPlan): SparkSession = { + sparkSessionMethod.invokeChecked[SparkSession](sparkPlan) + } +} diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala index 04f4ede6c..e2fb55134 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/execution/arrow/KyuubiArrowConverters.scala @@ -32,7 +32,7 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.CollectLimitExec +import org.apache.spark.sql.execution.{CollectLimitExec, SparkPlanHelper} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils @@ -157,7 +157,7 @@ object KyuubiArrowConverters extends SQLConfHelper with Logging { val partsToScan = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts)) - val sc = collectLimitExec.session.sparkContext + val sc = SparkPlanHelper.sparkSession(collectLimitExec).sparkContext val res = sc.runJob( childRDD, (it: Iterator[InternalRow]) => { diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala index 85cf2971e..73e7f7799 100644 --- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala +++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala @@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils -import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, HiveResult, LocalTableScanExec, QueryExecution, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, HiveResult, LocalTableScanExec, QueryExecution, SparkPlan, SparkPlanHelper, SQLExecution} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -83,8 +83,9 @@ object SparkDatasetHelper extends Logging { */ def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = { val schemaCaptured = plan.schema - val maxRecordsPerBatch = plan.session.sessionState.conf.arrowMaxRecordsPerBatch - val timeZoneId = plan.session.sessionState.conf.sessionLocalTimeZone + val spark = SparkPlanHelper.sparkSession(plan) + val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch + val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone // note that, we can't pass the lazy variable `maxBatchSize` directly, this is because input // arguments are serialized and sent to the executor side for execution. val maxBatchSizePerBatch = maxBatchSize @@ -169,8 +170,9 @@ object SparkDatasetHelper extends Logging { } private def doCollectLimit(collectLimit: CollectLimitExec): Array[Array[Byte]] = { - val timeZoneId = collectLimit.session.sessionState.conf.sessionLocalTimeZone - val maxRecordsPerBatch = collectLimit.session.sessionState.conf.arrowMaxRecordsPerBatch + val spark = SparkPlanHelper.sparkSession(collectLimit) + val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone + val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch val batches = KyuubiArrowConverters.takeAsArrowBatches( collectLimit, @@ -199,7 +201,7 @@ object SparkDatasetHelper extends Logging { } private def doCommandResultExec(commandResult: CommandResultExec): Array[Array[Byte]] = { - val spark = commandResult.session + val spark = SparkPlanHelper.sparkSession(commandResult) commandResult.longMetric("numOutputRows").add(commandResult.rows.size) sendDriverMetrics(spark.sparkContext, commandResult.metrics) KyuubiArrowConverters.toBatchIterator( @@ -212,7 +214,7 @@ object SparkDatasetHelper extends Logging { } private def doLocalTableScan(localTableScan: LocalTableScanExec): Array[Array[Byte]] = { - val spark = localTableScan.session + val spark = SparkPlanHelper.sparkSession(localTableScan) localTableScan.longMetric("numOutputRows").add(localTableScan.rows.size) sendDriverMetrics(spark.sparkContext, localTableScan.metrics) KyuubiArrowConverters.toBatchIterator(