diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLHelper.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLHelper.scala
index 2d880a344..25b3e376a 100644
--- a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLHelper.scala
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLHelper.scala
@@ -18,9 +18,10 @@
package org.apache.kyuubi.service.authentication
import java.security.Security
+import java.util
import java.util.Collections
import javax.security.auth.callback.{Callback, CallbackHandler, NameCallback, PasswordCallback, UnsupportedCallbackException}
-import javax.security.sasl.AuthorizeCallback
+import javax.security.sasl.{AuthorizeCallback, Sasl}
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.service.authentication.AuthMethods.AuthMethod
@@ -79,17 +80,32 @@ object PlainSASLHelper {
conf: KyuubiConf,
transportFactory: Option[TSaslServerTransport.Factory] = None,
isServer: Boolean = true): TTransportFactory = {
- val saslFactory = transportFactory.getOrElse(new TSaslServerTransport.Factory())
- try {
- val handler = new PlainServerCallbackHandler(authTypeStr, conf, isServer)
- val props = new java.util.HashMap[String, String]
- saslFactory.addServerDefinition("PLAIN", authTypeStr, null, props, handler)
- } catch {
- case e: NoSuchElementException =>
- throw new IllegalArgumentException(
- s"Illegal authentication type $authTypeStr for plain transport",
- e)
+ val handler =
+ try {
+ new PlainServerCallbackHandler(authTypeStr, conf, isServer)
+ } catch {
+ case _: NoSuchElementException =>
+ throw new IllegalArgumentException(
+ s"Illegal authentication type $authTypeStr for plain transport")
+ }
+ val saslFactory = transportFactory.getOrElse {
+ val _factory = new TSaslServerTransport.Factory()
+ _factory.setSaslServerFactory { d =>
+ if (d.mechanism == "PLAIN") {
+ // [KYUUBI #7142]: There may be multiple SaslServer classes registered for PLAIN
+ // mechanism, we should not use JDK Sasl.createSaslServer to avoid picking up the
+ // unexpected SaslServer implementation.
+ val kyuubiFactory = new PlainSASLServer.SaslPlainServerFactory()
+ kyuubiFactory.createSaslServer(d.mechanism, d.protocol, d.serverName, d.props, d.cbh)
+ } else {
+ Sasl.createSaslServer(d.mechanism, d.protocol, d.serverName, d.props, d.cbh)
+ }
+ }
+ _factory
}
+ val props = new util.HashMap[String, String]
+ props.put("org.apache.kyuubi.service.name", if (isServer) "SERVER" else "ENGINE")
+ saslFactory.addServerDefinition("PLAIN", authTypeStr, null, props, handler)
saslFactory
}
diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLServer.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLServer.scala
index 737a6d8cd..068b5b1d8 100644
--- a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLServer.scala
+++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLServer.scala
@@ -19,6 +19,7 @@ package org.apache.kyuubi.service.authentication
import java.io.IOException
import java.security.Provider
+import java.util
import javax.security.auth.callback.{Callback, CallbackHandler, NameCallback, PasswordCallback, UnsupportedCallbackException}
import javax.security.sasl.{AuthorizeCallback, SaslException, SaslServer, SaslServerFactory}
@@ -37,7 +38,7 @@ class PlainSASLServer(
try {
// parse the response
// message = [authzid] UTF8NUL authcid UTF8NUL passwd'
- val tokenList = new java.util.ArrayDeque[String]
+ val tokenList = new util.ArrayDeque[String]
val messageToken: StringBuilder = new StringBuilder
response.foreach {
case 0 =>
@@ -109,9 +110,9 @@ object PlainSASLServer {
mechanism: String,
protocol: String,
serverName: String,
- props: java.util.Map[String, _],
+ props: util.Map[String, _],
cbh: CallbackHandler): SaslServer = mechanism match {
- case PLAIN_METHOD =>
+ case PLAIN_METHOD if props.containsKey("org.apache.kyuubi.service.name") =>
try {
new PlainSASLServer(cbh, AuthMethods.withName(protocol))
} catch {
@@ -121,7 +122,7 @@ object PlainSASLServer {
case _ => null
}
- override def getMechanismNames(props: java.util.Map[String, _]): Array[String] = {
+ override def getMechanismNames(props: util.Map[String, _]): Array[String] = {
Array(PLAIN_METHOD)
}
}
diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/service/authentication/PlainSASLServerSuite.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/service/authentication/PlainSASLServerSuite.scala
index a7f4b9535..6e1a86a51 100644
--- a/kyuubi-common/src/test/scala/org/apache/kyuubi/service/authentication/PlainSASLServerSuite.scala
+++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/service/authentication/PlainSASLServerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.kyuubi.service.authentication
+import java.util
import java.util.Collections
import javax.security.auth.callback.{Callback, CallbackHandler}
import javax.security.sasl.{AuthorizeCallback, SaslException}
@@ -28,8 +29,10 @@ class PlainSASLServerSuite extends KyuubiFunSuite {
test("SaslPlainServerFactory") {
val plainServerFactory = new SaslPlainServerFactory()
- val map = Collections.emptyMap[String, String]()
- assert(plainServerFactory.getMechanismNames(map) ===
+ val invalidProps = Collections.emptyMap[String, String]()
+ val props = new util.HashMap[String, String]()
+ props.put("org.apache.kyuubi.service.name", "TEST")
+ assert(plainServerFactory.getMechanismNames(props) ===
Array(PlainSASLServer.PLAIN_METHOD))
val ch = new CallbackHandler {
override def handle(callbacks: Array[Callback]): Unit = callbacks.foreach {
@@ -38,16 +41,23 @@ class PlainSASLServerSuite extends KyuubiFunSuite {
}
}
- val s1 = plainServerFactory.createSaslServer("INVALID", "", "", map, ch)
+ val s1 = plainServerFactory.createSaslServer("INVALID", "", "", props, ch)
assert(s1 === null)
- val s2 = plainServerFactory.createSaslServer(PlainSASLServer.PLAIN_METHOD, "", "", map, ch)
+ val s2 = plainServerFactory.createSaslServer(PlainSASLServer.PLAIN_METHOD, "", "", props, ch)
assert(s2 === null)
+ val s3 = plainServerFactory.createSaslServer(
+ PlainSASLServer.PLAIN_METHOD,
+ AuthMethods.NONE.toString,
+ "",
+ invalidProps,
+ ch)
+ assert(s3 === null)
val server = plainServerFactory.createSaslServer(
PlainSASLServer.PLAIN_METHOD,
AuthMethods.NONE.toString,
"KYUUBI",
- map,
+ props,
ch)
assert(server.getMechanismName === PlainSASLServer.PLAIN_METHOD)
assert(!server.isComplete)
@@ -78,7 +88,7 @@ class PlainSASLServerSuite extends KyuubiFunSuite {
"PLAIN",
"NONE",
"KYUUBI",
- map,
+ props,
_ => {})
val e6 = intercept[SaslException](server2.evaluateResponse(res4.map(_.toByte)))
assert(e6.getMessage === "Error validating the login")
diff --git a/pom.xml b/pom.xml
index e3696c848..a1715243e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -174,7 +174,7 @@
4.13.2
3.5.2
6.13.5
- 0.5.0
+ 0.6.0
kyuubi-relocated-zookeeper-34
6.0.5
2.24.3