diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLHelper.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLHelper.scala index 2d880a344..25b3e376a 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLHelper.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLHelper.scala @@ -18,9 +18,10 @@ package org.apache.kyuubi.service.authentication import java.security.Security +import java.util import java.util.Collections import javax.security.auth.callback.{Callback, CallbackHandler, NameCallback, PasswordCallback, UnsupportedCallbackException} -import javax.security.sasl.AuthorizeCallback +import javax.security.sasl.{AuthorizeCallback, Sasl} import org.apache.kyuubi.config.KyuubiConf import org.apache.kyuubi.service.authentication.AuthMethods.AuthMethod @@ -79,17 +80,32 @@ object PlainSASLHelper { conf: KyuubiConf, transportFactory: Option[TSaslServerTransport.Factory] = None, isServer: Boolean = true): TTransportFactory = { - val saslFactory = transportFactory.getOrElse(new TSaslServerTransport.Factory()) - try { - val handler = new PlainServerCallbackHandler(authTypeStr, conf, isServer) - val props = new java.util.HashMap[String, String] - saslFactory.addServerDefinition("PLAIN", authTypeStr, null, props, handler) - } catch { - case e: NoSuchElementException => - throw new IllegalArgumentException( - s"Illegal authentication type $authTypeStr for plain transport", - e) + val handler = + try { + new PlainServerCallbackHandler(authTypeStr, conf, isServer) + } catch { + case _: NoSuchElementException => + throw new IllegalArgumentException( + s"Illegal authentication type $authTypeStr for plain transport") + } + val saslFactory = transportFactory.getOrElse { + val _factory = new TSaslServerTransport.Factory() + _factory.setSaslServerFactory { d => + if (d.mechanism == "PLAIN") { + // [KYUUBI #7142]: There may be multiple SaslServer classes registered for PLAIN + // mechanism, we should not use JDK Sasl.createSaslServer to avoid picking up the + // unexpected SaslServer implementation. + val kyuubiFactory = new PlainSASLServer.SaslPlainServerFactory() + kyuubiFactory.createSaslServer(d.mechanism, d.protocol, d.serverName, d.props, d.cbh) + } else { + Sasl.createSaslServer(d.mechanism, d.protocol, d.serverName, d.props, d.cbh) + } + } + _factory } + val props = new util.HashMap[String, String] + props.put("org.apache.kyuubi.service.name", if (isServer) "SERVER" else "ENGINE") + saslFactory.addServerDefinition("PLAIN", authTypeStr, null, props, handler) saslFactory } diff --git a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLServer.scala b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLServer.scala index 737a6d8cd..068b5b1d8 100644 --- a/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLServer.scala +++ b/kyuubi-common/src/main/scala/org/apache/kyuubi/service/authentication/PlainSASLServer.scala @@ -19,6 +19,7 @@ package org.apache.kyuubi.service.authentication import java.io.IOException import java.security.Provider +import java.util import javax.security.auth.callback.{Callback, CallbackHandler, NameCallback, PasswordCallback, UnsupportedCallbackException} import javax.security.sasl.{AuthorizeCallback, SaslException, SaslServer, SaslServerFactory} @@ -37,7 +38,7 @@ class PlainSASLServer( try { // parse the response // message = [authzid] UTF8NUL authcid UTF8NUL passwd' - val tokenList = new java.util.ArrayDeque[String] + val tokenList = new util.ArrayDeque[String] val messageToken: StringBuilder = new StringBuilder response.foreach { case 0 => @@ -109,9 +110,9 @@ object PlainSASLServer { mechanism: String, protocol: String, serverName: String, - props: java.util.Map[String, _], + props: util.Map[String, _], cbh: CallbackHandler): SaslServer = mechanism match { - case PLAIN_METHOD => + case PLAIN_METHOD if props.containsKey("org.apache.kyuubi.service.name") => try { new PlainSASLServer(cbh, AuthMethods.withName(protocol)) } catch { @@ -121,7 +122,7 @@ object PlainSASLServer { case _ => null } - override def getMechanismNames(props: java.util.Map[String, _]): Array[String] = { + override def getMechanismNames(props: util.Map[String, _]): Array[String] = { Array(PLAIN_METHOD) } } diff --git a/kyuubi-common/src/test/scala/org/apache/kyuubi/service/authentication/PlainSASLServerSuite.scala b/kyuubi-common/src/test/scala/org/apache/kyuubi/service/authentication/PlainSASLServerSuite.scala index a7f4b9535..6e1a86a51 100644 --- a/kyuubi-common/src/test/scala/org/apache/kyuubi/service/authentication/PlainSASLServerSuite.scala +++ b/kyuubi-common/src/test/scala/org/apache/kyuubi/service/authentication/PlainSASLServerSuite.scala @@ -17,6 +17,7 @@ package org.apache.kyuubi.service.authentication +import java.util import java.util.Collections import javax.security.auth.callback.{Callback, CallbackHandler} import javax.security.sasl.{AuthorizeCallback, SaslException} @@ -28,8 +29,10 @@ class PlainSASLServerSuite extends KyuubiFunSuite { test("SaslPlainServerFactory") { val plainServerFactory = new SaslPlainServerFactory() - val map = Collections.emptyMap[String, String]() - assert(plainServerFactory.getMechanismNames(map) === + val invalidProps = Collections.emptyMap[String, String]() + val props = new util.HashMap[String, String]() + props.put("org.apache.kyuubi.service.name", "TEST") + assert(plainServerFactory.getMechanismNames(props) === Array(PlainSASLServer.PLAIN_METHOD)) val ch = new CallbackHandler { override def handle(callbacks: Array[Callback]): Unit = callbacks.foreach { @@ -38,16 +41,23 @@ class PlainSASLServerSuite extends KyuubiFunSuite { } } - val s1 = plainServerFactory.createSaslServer("INVALID", "", "", map, ch) + val s1 = plainServerFactory.createSaslServer("INVALID", "", "", props, ch) assert(s1 === null) - val s2 = plainServerFactory.createSaslServer(PlainSASLServer.PLAIN_METHOD, "", "", map, ch) + val s2 = plainServerFactory.createSaslServer(PlainSASLServer.PLAIN_METHOD, "", "", props, ch) assert(s2 === null) + val s3 = plainServerFactory.createSaslServer( + PlainSASLServer.PLAIN_METHOD, + AuthMethods.NONE.toString, + "", + invalidProps, + ch) + assert(s3 === null) val server = plainServerFactory.createSaslServer( PlainSASLServer.PLAIN_METHOD, AuthMethods.NONE.toString, "KYUUBI", - map, + props, ch) assert(server.getMechanismName === PlainSASLServer.PLAIN_METHOD) assert(!server.isComplete) @@ -78,7 +88,7 @@ class PlainSASLServerSuite extends KyuubiFunSuite { "PLAIN", "NONE", "KYUUBI", - map, + props, _ => {}) val e6 = intercept[SaslException](server2.evaluateResponse(res4.map(_.toByte))) assert(e6.getMessage === "Error validating the login") diff --git a/pom.xml b/pom.xml index e3696c848..a1715243e 100644 --- a/pom.xml +++ b/pom.xml @@ -174,7 +174,7 @@ 4.13.2 3.5.2 6.13.5 - 0.5.0 + 0.6.0 kyuubi-relocated-zookeeper-34 6.0.5 2.24.3