[CELEBORN-1041] Improve the implementation for get the PartitionIdPassthrough class

### What changes were proposed in this pull request?
Currently, the code of get the contractor of `PartitionIdPassthrough` is very redundant.
We should improve the implementation.

### Why are the changes needed?
Improve the implementation for get the `PartitionIdPassthrough` class

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
New test cases.

Closes #1989 from beliefer/CELEBORN-1041.

Authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: zky.zhoukeyong <zky.zhoukeyong@alibaba-inc.com>
This commit is contained in:
Jiaan Geng 2023-10-16 14:24:49 +08:00 committed by zky.zhoukeyong
parent 640007100a
commit cdcd30dc2f
2 changed files with 40 additions and 6 deletions

View File

@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import org.apache.spark.util.VersionUtils
object SparkVersionUtil {
private val sparkMajorMinorVersion = VersionUtils.majorMinorVersion(SPARK_VERSION)
def isGreaterThan(major: Int, minor: Int): Boolean = {
sparkMajorMinorVersion match {
case (ma, _) if ma > major => true
case (ma, mi) if ma == major && mi > minor => true
case other => false
}
}
}

View File

@ -42,6 +42,7 @@ import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkVersionUtil;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
@ -217,13 +218,14 @@ public abstract class CelebornShuffleWriterSuiteBase {
throws Exception {
final boolean useUnsafe = serializer instanceof UnsafeRowSerializer;
String PartitionIdPassthroughClazz;
if (SparkVersionUtil.isGreaterThan(3, 3)) {
PartitionIdPassthroughClazz = "org.apache.spark.PartitionIdPassthrough";
} else {
PartitionIdPassthroughClazz = "org.apache.spark.sql.execution.PartitionIdPassthrough";
}
DynConstructors.Ctor<Partitioner> partitionIdPassthroughCtor =
DynConstructors.builder()
// for Spark 3.3 and previous
.impl("org.apache.spark.sql.execution.PartitionIdPassthrough", int.class)
// for Spark 3.4
.impl("org.apache.spark.PartitionIdPassthrough", int.class)
.build();
DynConstructors.builder().impl(PartitionIdPassthroughClazz, int.class).build();
final Partitioner partitioner =
useUnsafe
? partitionIdPassthroughCtor.newInstance(numPartitions)