diff --git a/README.md b/README.md
index 3b10fafe3..03f58aa85 100644
--- a/README.md
+++ b/README.md
@@ -39,16 +39,16 @@ RSS Worker's slot count is decided by `rss.worker.numSlots` or`rss.worker.flush.
RSS worker's slot count decreases when a partition is allocated and increments when a partition is freed.
## Build
-RSS supports Spark2.x(>=2.4.0) and Spark3.x(>=3.1.0).
+RSS supports Spark2.x(>=2.4.0) and Spark3.x(>=3.0.1).
Build for Spark 2
`
-./dev/make-distribution.sh -Pspark-2 -Dspark.version=[spark.version default 2.4.5]
+./dev/make-distribution.sh -Pspark-2
`
Build for Spark 3
`
-./dev/make-distribution.sh -Pspark-3 -Dspark.version=[spark.version default 3.1.2]
+./dev/make-distribution.sh -Pspark-3
`
package rss-${project.version}-bin-release.tgz will be generated.
diff --git a/client-spark/shuffle-manager-3/src/main/scala/org/apache/spark/shuffle/rss/RssShuffleManager.scala b/client-spark/shuffle-manager-3/src/main/scala/org/apache/spark/shuffle/rss/RssShuffleManager.scala
index 5c6fbe3c2..4b1a4c24a 100644
--- a/client-spark/shuffle-manager-3/src/main/scala/org/apache/spark/shuffle/rss/RssShuffleManager.scala
+++ b/client-spark/shuffle-manager-3/src/main/scala/org/apache/spark/shuffle/rss/RssShuffleManager.scala
@@ -87,6 +87,29 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
}
}
+ override def unregisterShuffle(shuffleId: Int): Boolean = {
+ if (sortShuffleIds.contains(shuffleId)) {
+ sortShuffleManager.unregisterShuffle(shuffleId)
+ } else {
+ newAppId match {
+ case Some(id) => rssShuffleClient.exists(_.unregisterShuffle(id, shuffleId, isDriver))
+ case None => true
+ }
+ }
+ }
+
+ override def shuffleBlockResolver: ShuffleBlockResolver = {
+ sortShuffleManager.shuffleBlockResolver
+ }
+
+ override def stop(): Unit = {
+ rssShuffleClient.foreach(_.shutDown())
+ lifecycleManager.foreach(_.stop())
+ if (sortShuffleManager != null) {
+ sortShuffleManager.stop()
+ }
+ }
+
override def getWriter[K, V](
handle: ShuffleHandle,
mapId: Long,
@@ -112,8 +135,10 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
}
}
- // remove override for compatibility
- override def getReader[K, C](
+ /**
+ * Interface for Spark3.1 and higher
+ */
+ def getReader[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
@@ -129,32 +154,63 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
endPartition,
context,
essConf)
- case _ => sortShuffleManager.getReader(handle, startMapIndex, endMapIndex,
- startPartition, endPartition, context, metrics)
+ case _ =>
+ RssShuffleManager.invokeGetReaderMethod(
+ sortShuffleManagerName,
+ "getReader",
+ sortShuffleManager,
+ handle,
+ startMapIndex,
+ endMapIndex,
+ startPartition,
+ endPartition,
+ context,
+ metrics)
}
}
- override def unregisterShuffle(shuffleId: Int): Boolean = {
- if (sortShuffleIds.contains(shuffleId)) {
- sortShuffleManager.unregisterShuffle(shuffleId)
- } else {
- newAppId match {
- case Some(id) => rssShuffleClient.exists(_.unregisterShuffle(id, shuffleId, isDriver))
- case None => true
- }
+ /**
+ * Interface for Spark3.0 and higher
+ */
+ def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ handle match {
+ case _: RssShuffleHandle[K@unchecked, C@unchecked, _] =>
+ new RssShuffleReader(
+ handle.asInstanceOf[RssShuffleHandle[K, _, C]],
+ startPartition,
+ endPartition,
+ context,
+ essConf)
+ case _ =>
+ RssShuffleManager.invokeGetReaderMethod(
+ sortShuffleManagerName,
+ "getReader",
+ sortShuffleManager,
+ handle,
+ -1,
+ -1,
+ startPartition,
+ endPartition,
+ context,
+ metrics)
}
}
- override def shuffleBlockResolver: ShuffleBlockResolver = {
- sortShuffleManager.shuffleBlockResolver
- }
-
- override def stop(): Unit = {
- rssShuffleClient.foreach(_.shutDown())
- lifecycleManager.foreach(_.stop())
- if (sortShuffleManager != null) {
- sortShuffleManager.stop()
- }
+ def getReaderForRange[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ throw new UnsupportedOperationException("Currently RSS do NOT support skew join Optimization," +
+ "Please set spark.sql.adaptive.skewJoin.enabled to false")
}
}
@@ -201,6 +257,50 @@ object RssShuffleManager {
}
}
}
+
+ // Invoke and return getReader method of SortShuffleManager
+ def invokeGetReaderMethod[K, C](
+ className: String,
+ methodName: String,
+ sortShuffleManager: SortShuffleManager,
+ handle: ShuffleHandle,
+ startMapIndex: Int = 0,
+ endMapIndex: Int = Int.MaxValue,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ val cls = Utils.classForName(className)
+ try {
+ val method = cls.getMethod(methodName, classOf[ShuffleHandle], Integer.TYPE, Integer.TYPE,
+ Integer.TYPE, Integer.TYPE, classOf[TaskContext], classOf[ShuffleReadMetricsReporter])
+ method.invoke(
+ sortShuffleManager,
+ handle,
+ Integer.valueOf(startMapIndex),
+ Integer.valueOf(endMapIndex),
+ Integer.valueOf(startPartition),
+ Integer.valueOf(endPartition),
+ context,
+ metrics).asInstanceOf[ShuffleReader[K, C]]
+ } catch {
+ case _: NoSuchMethodException =>
+ try {
+ val method = cls.getMethod(methodName, classOf[ShuffleHandle], Integer.TYPE, Integer.TYPE,
+ classOf[TaskContext], classOf[ShuffleReadMetricsReporter])
+ method.invoke(
+ sortShuffleManager,
+ handle,
+ Integer.valueOf(startPartition),
+ Integer.valueOf(endPartition),
+ context,
+ metrics).asInstanceOf[ShuffleReader[K, C]]
+ } catch {
+ case e: NoSuchMethodException =>
+ throw new Exception("Get getReader method failed.", e)
+ }
+ }
+ }
}
class RssShuffleHandle[K, V, C](
diff --git a/pom.xml b/pom.xml
index 4e30a2540..a74e99340 100644
--- a/pom.xml
+++ b/pom.xml
@@ -510,7 +510,7 @@
2.12.10
2.12
- 3.1.2
+ 3.0.1
shuffle-manager-3