From cdcd30dc2fb3eb5affaa170a1ab45e9d29d7a2ee Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 16 Oct 2023 14:24:49 +0800 Subject: [PATCH] [CELEBORN-1041] Improve the implementation for get the PartitionIdPassthrough class ### What changes were proposed in this pull request? Currently, the code of get the contractor of `PartitionIdPassthrough` is very redundant. We should improve the implementation. ### Why are the changes needed? Improve the implementation for get the `PartitionIdPassthrough` class ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New test cases. Closes #1989 from beliefer/CELEBORN-1041. Authored-by: Jiaan Geng Signed-off-by: zky.zhoukeyong --- .../org/apache/spark/SparkVersionUtil.scala | 32 +++++++++++++++++++ .../CelebornShuffleWriterSuiteBase.java | 14 ++++---- 2 files changed, 40 insertions(+), 6 deletions(-) create mode 100644 client-spark/spark-3/src/main/scala/org/apache/spark/SparkVersionUtil.scala diff --git a/client-spark/spark-3/src/main/scala/org/apache/spark/SparkVersionUtil.scala b/client-spark/spark-3/src/main/scala/org/apache/spark/SparkVersionUtil.scala new file mode 100644 index 000000000..8699c2bff --- /dev/null +++ b/client-spark/spark-3/src/main/scala/org/apache/spark/SparkVersionUtil.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.util.VersionUtils + +object SparkVersionUtil { + private val sparkMajorMinorVersion = VersionUtils.majorMinorVersion(SPARK_VERSION) + + def isGreaterThan(major: Int, minor: Int): Boolean = { + sparkMajorMinorVersion match { + case (ma, _) if ma > major => true + case (ma, mi) if ma == major && mi > minor => true + case other => false + } + } +} diff --git a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java index 95b4257a8..d8a3d4986 100644 --- a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java +++ b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java @@ -42,6 +42,7 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; +import org.apache.spark.SparkVersionUtil; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; @@ -217,13 +218,14 @@ public abstract class CelebornShuffleWriterSuiteBase { throws Exception { final boolean useUnsafe = serializer instanceof UnsafeRowSerializer; + String PartitionIdPassthroughClazz; + if (SparkVersionUtil.isGreaterThan(3, 3)) { + PartitionIdPassthroughClazz = "org.apache.spark.PartitionIdPassthrough"; + } else { + PartitionIdPassthroughClazz = "org.apache.spark.sql.execution.PartitionIdPassthrough"; + } DynConstructors.Ctor partitionIdPassthroughCtor = - DynConstructors.builder() - // for Spark 3.3 and previous - .impl("org.apache.spark.sql.execution.PartitionIdPassthrough", int.class) - // for Spark 3.4 - .impl("org.apache.spark.PartitionIdPassthrough", int.class) - .build(); + DynConstructors.builder().impl(PartitionIdPassthroughClazz, int.class).build(); final Partitioner partitioner = useUnsafe ? partitionIdPassthroughCtor.newInstance(numPartitions)