diff --git a/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionCacheManager.scala b/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionCacheManager.scala
index b9a1b1ed2..92247994b 100644
--- a/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionCacheManager.scala
+++ b/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionCacheManager.scala
@@ -50,7 +50,7 @@ class SparkSessionCacheManager(conf: SparkConf) extends Logging {
   def getAndIncrease(user: String): Option[SparkSession] = {
     Some(userToSparkSession.get(user)) match {
       case Some((ss, times)) if !ss.sparkContext.isStopped =>
-        info(s"SparkSession for [$user] is reused for ${times.incrementAndGet()} times.")
+        info(s"SparkSession for [$user] is reused for ${times.incrementAndGet()} time(s) after + 1")
         Some(ss)
       case _ =>
         info(s"SparkSession for [$user] isn't cached, will create a new one.")
@@ -62,7 +62,7 @@ class SparkSessionCacheManager(conf: SparkConf) extends Logging {
     Some(userToSparkSession.get(user)) match {
       case Some((ss, times)) if !ss.sparkContext.isStopped =>
         userLatestLogout.put(user, System.currentTimeMillis())
-        info(s"SparkSession for [$user] is reused for ${times.decrementAndGet()} times.")
+        info(s"SparkSession for [$user] is reused for ${times.decrementAndGet()} time(s) after -1")
       case _ =>
         warn(s"SparkSession for [$user] was not found in the cache.")
     }
diff --git a/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGI.scala b/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGI.scala
index d53b37f1f..87c0502bf 100644
--- a/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGI.scala
+++ b/src/main/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGI.scala
@@ -22,7 +22,7 @@ import java.security.PrivilegedExceptionAction
 import java.util.concurrent.TimeUnit
 
 import scala.collection.mutable.{HashSet => MHSet}
-import scala.concurrent.{Await, Promise, TimeoutException}
+import scala.concurrent.{Await, Promise}
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.duration.Duration
 import scala.util.control.NonFatal
@@ -43,6 +43,7 @@ class SparkSessionWithUGI(user: UserGroupInformation, conf: SparkConf) extends L
   def sparkSession: SparkSession = _sparkSession
   private[this] val promisedSparkContext = Promise[SparkContext]()
   private[this] var initialDatabase: Option[String] = None
+  private[this] var sparkException: Option[Throwable] = None
 
   private[this] def newContext(): Thread = {
     new Thread(s"Start-SparkContext-$userName") {
@@ -52,7 +53,7 @@ class SparkSessionWithUGI(user: UserGroupInformation, conf: SparkConf) extends L
             new SparkContext(conf)
           }
         } catch {
-          case NonFatal(e) => throw e
+          case NonFatal(e) => sparkException = Some(e)
         }
       }
     }
@@ -64,13 +65,18 @@ class SparkSessionWithUGI(user: UserGroupInformation, conf: SparkConf) extends L
   private[this] def stopContext(): Unit = {
     promisedSparkContext.future.map { sc =>
       warn(s"Error occurred during initializing SparkContext for $userName, stopping")
-      sc.stop
-      System.setProperty("SPARK_YARN_MODE", "true")
+      try {
+        sc.stop
+      } catch {
+        case NonFatal(e) => error(s"Error Stopping $userName's SparkContext", e)
+      } finally {
+        System.setProperty("SPARK_YARN_MODE", "true")
+      }
     }
   }
 
   /**
-   * Setting configuration from connection strings before SparkConext init.
+   * Setting configuration from connection strings before SparkContext init.
    * @param sessionConf configurations for user connection string
    */
   private[this] def configureSparkConf(sessionConf: Map[String, String]): Unit = {
@@ -113,7 +119,8 @@ class SparkSessionWithUGI(user: UserGroupInformation, conf: SparkConf) extends L
   }
 
   private[this] def getOrCreate(sessionConf: Map[String, String]): Unit = synchronized {
-    var checkRound = math.max(conf.get(BACKEND_SESSION_WAIT_OTHER_TIMES.key).toInt, 15)
+    val totalRounds = math.max(conf.get(BACKEND_SESSION_WAIT_OTHER_TIMES.key).toInt, 15)
+    var checkRound = totalRounds
     val interval = conf.getTimeAsMs(BACKEND_SESSION_WAIT_OTHER_INTERVAL.key)
     // if user's sc is being constructed by another
     while (SparkSessionWithUGI.isPartiallyConstructed(userName)) {
@@ -121,7 +128,7 @@ class SparkSessionWithUGI(user: UserGroupInformation, conf: SparkConf) extends L
       checkRound -= 1
       if (checkRound <= 0) {
         throw new KyuubiSQLException(s"A partially constructed SparkContext for [$userName] " +
-          s"has last more than ${checkRound * interval} seconds")
+          s"has last more than ${totalRounds * interval / 1000} seconds")
       }
       info(s"A partially constructed SparkContext for [$userName], $checkRound times countdown.")
     }
@@ -158,15 +165,11 @@ class SparkSessionWithUGI(user: UserGroupInformation, conf: SparkConf) extends L
         SparkSessionCacheManager.get.set(userName, _sparkSession)
       } catch {
         case ute: UndeclaredThrowableException =>
-          ute.getCause match {
-            case te: TimeoutException =>
-              stopContext()
-              throw new KyuubiSQLException(
-                s"Get SparkSession for [$userName] failed: " + te, "08S01", 1001, te)
-            case _ =>
-              stopContext()
-              throw new KyuubiSQLException(ute.toString, "08S01", ute.getCause)
-          }
+          stopContext()
+          val ke = new KyuubiSQLException(
+            s"Get SparkSession for [$userName] failed: " + ute.getCause, "08S01", 1001, ute.getCause)
+          sparkException.foreach(ke.addSuppressed)
+          throw ke
         case e: Exception =>
           stopContext()
           throw new KyuubiSQLException(
@@ -187,16 +190,17 @@ class SparkSessionWithUGI(user: UserGroupInformation, conf: SparkConf) extends L
 
   @throws[KyuubiSQLException]
   def init(sessionConf: Map[String, String]): Unit = {
+    getOrCreate(sessionConf)
     try {
-      getOrCreate(sessionConf)
       initialDatabase.foreach { db =>
         user.doAs(new PrivilegedExceptionAction[Unit] {
           override def run(): Unit = _sparkSession.sql(db)
         })
       }
     } catch {
-      case ute: UndeclaredThrowableException => throw ute.getCause
-      case e: Exception => throw e
+      case ute: UndeclaredThrowableException =>
+        SparkSessionCacheManager.get.decrease(userName)
+        throw ute.getCause
     }
   }
 }
diff --git a/src/test/scala/yaooqinn/kyuubi/operation/KyuubiOperationSuite.scala b/src/test/scala/yaooqinn/kyuubi/operation/KyuubiOperationSuite.scala
index cf8c2512a..a3defd608 100644
--- a/src/test/scala/yaooqinn/kyuubi/operation/KyuubiOperationSuite.scala
+++ b/src/test/scala/yaooqinn/kyuubi/operation/KyuubiOperationSuite.scala
@@ -32,7 +32,7 @@ import yaooqinn.kyuubi.session.{KyuubiSession, SessionManager}
 import yaooqinn.kyuubi.spark.SparkSessionWithUGI
 import yaooqinn.kyuubi.utils.ReflectUtils
 
-class KyuubiOperationSuite extends SparkFunSuite with BeforeAndAfterEach {
+class KyuubiOperationSuite extends SparkFunSuite {
 
   val conf = new SparkConf(loadDefaults = true).setAppName("operation test")
   KyuubiServer.setupCommonConfig(conf)
diff --git a/src/test/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGISuite.scala b/src/test/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGISuite.scala
new file mode 100644
index 000000000..c30f12a5e
--- /dev/null
+++ b/src/test/scala/yaooqinn/kyuubi/spark/SparkSessionWithUGISuite.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package yaooqinn.kyuubi.spark
+
+import scala.concurrent.{Promise, TimeoutException}
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.apache.hadoop.security.UserGroupInformation
+import org.apache.spark._
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException
+
+import yaooqinn.kyuubi.KyuubiSQLException
+import yaooqinn.kyuubi.server.KyuubiServer
+import yaooqinn.kyuubi.ui.KyuubiServerMonitor
+import yaooqinn.kyuubi.utils.ReflectUtils
+
+class SparkSessionWithUGISuite extends SparkFunSuite {
+
+  val user = UserGroupInformation.getCurrentUser
+  val conf = new SparkConf(loadDefaults = true).setAppName("spark session test")
+  KyuubiServer.setupCommonConfig(conf)
+  conf.remove(KyuubiSparkUtil.CATALOG_IMPL)
+  conf.setMaster("local")
+  val userName = user.getShortUserName
+  var spark: SparkSession = _
+
+  override protected def beforeAll(): Unit = {
+    val sc = ReflectUtils
+      .newInstance(classOf[SparkContext].getName, Seq(classOf[SparkConf]), Seq(conf))
+      .asInstanceOf[SparkContext]
+    spark = ReflectUtils.newInstance(
+      classOf[SparkSession].getName,
+      Seq(classOf[SparkContext]),
+      Seq(sc)).asInstanceOf[SparkSession]
+    SparkSessionCacheManager.startCacheManager(conf)
+    SparkSessionCacheManager.get.set(userName, spark)
+  }
+
+  protected override def afterAll(): Unit = {
+    SparkSessionCacheManager.get.stop()
+    spark.stop()
+  }
+
+  test("test init failed with sc init failing") {
+    assert(!spark.sparkContext.isStopped)
+    val confClone = conf.clone().remove(KyuubiSparkUtil.MULTIPLE_CONTEXTS)
+      .set(KyuubiConf.BACKEND_SESSTION_INIT_TIMEOUT.key, "3")
+    val userName1 = "test1"
+    val ru = UserGroupInformation.createRemoteUser(userName1)
+    val sparkSessionWithUGI = new SparkSessionWithUGI(ru, confClone)
+    assert(!SparkSessionWithUGI.isPartiallyConstructed(userName1))
+    val e = intercept[KyuubiSQLException](sparkSessionWithUGI.init(Map.empty))
+    assert(e.getCause.isInstanceOf[TimeoutException])
+    val se = e.getSuppressed.head
+    assert(se.isInstanceOf[SparkException])
+    assert(se.getMessage.startsWith("Only one SparkContext"))
+    assert(sparkSessionWithUGI.sparkSession === null)
+    assert(System.getProperty("SPARK_YARN_MODE") === null)
+    assert(SparkSessionCacheManager.get.getAndIncrease(userName1).isEmpty)
+  }
+
+  test("test init failed with no such database") {
+    val sparkSessionWithUGI = new SparkSessionWithUGI(user, conf)
+    intercept[NoSuchDatabaseException](sparkSessionWithUGI.init(Map("use:database" -> "fakedb")))
+    assert(ReflectUtils.getFieldValue(sparkSessionWithUGI,
+      "yaooqinn$kyuubi$spark$SparkSessionWithUGI$$initialDatabase") === Some("use fakedb"))
+    assert(SparkSessionCacheManager.get.getAndIncrease(userName).nonEmpty)
+  }
+
+  test("test init success with empty session conf") {
+    val sparkSessionWithUGI = new SparkSessionWithUGI(user, conf)
+    sparkSessionWithUGI.init(Map.empty)
+    assert(sparkSessionWithUGI.sparkSession.sparkContext.sparkUser === userName)
+    assert(sparkSessionWithUGI.userName === userName)
+  }
+
+  test("test init success with spark properties") {
+    val sessionConf = Map("set:hivevar:spark.foo" -> "bar")
+    val sparkSessionWithUGI = new SparkSessionWithUGI(user, conf)
+    sparkSessionWithUGI.init(sessionConf)
+    assert(sparkSessionWithUGI.sparkSession.conf.get("spark.foo") === "bar")
+  }
+
+  test("test init success with hive/hadoop/extra properties") {
+    val sessionConf = Map("set:hivevar:foo" -> "bar")
+    val sparkSessionWithUGI = new SparkSessionWithUGI(user, conf)
+    sparkSessionWithUGI.init(sessionConf)
+    assert(sparkSessionWithUGI.sparkSession.conf.get("spark.hadoop.foo") === "bar")
+  }
+
+  test("test init with new spark context") {
+    val userName1 = "test"
+    val ru = UserGroupInformation.createRemoteUser(userName1)
+    val sessionConf = Map("set:hivevar:spark.foo" -> "bar", "set:hivevar:foo" -> "bar")
+    val sparkSessionWithUGI = new SparkSessionWithUGI(ru, conf)
+    sparkSessionWithUGI.init(sessionConf)
+    assert(sparkSessionWithUGI.sparkSession.conf.get("spark.foo") === "bar")
+    assert(sparkSessionWithUGI.sparkSession.conf.get("spark.hadoop.foo") === "bar")
+    assert(!sparkSessionWithUGI.sparkSession.sparkContext.getConf.contains(KyuubiSparkUtil.KEYTAB))
+    assert(KyuubiServerMonitor.getListener(userName1).nonEmpty)
+    sparkSessionWithUGI.sparkSession.stop()
+  }
+
+  test("testSetPartiallyConstructed") {
+    val confClone = conf.clone().set(KyuubiConf.BACKEND_SESSION_WAIT_OTHER_TIMES.key, "3")
+    SparkSessionWithUGI.setPartiallyConstructed(userName)
+    val sparkSessionWithUGI = new SparkSessionWithUGI(user, confClone)
+    val e = intercept[KyuubiSQLException](sparkSessionWithUGI.init(Map.empty))
+    assert(e.getMessage.startsWith("A partially constructed SparkContext for"))
+    assert(e.getMessage.contains(userName))
+    assert(e.getMessage.contains("has last more than 15 seconds"))
+    assert(SparkSessionWithUGI.isPartiallyConstructed(userName))
+    assert(!SparkSessionWithUGI.isPartiallyConstructed("Kent Yao"))
+    SparkSessionWithUGI.setFullyConstructed(userName)
+  }
+
+  test("test init failed with time out exception") {
+    // point to an non-exist cluster manager
+    val confClone = conf.clone().setMaster("spark://localhost:7077")
+      .set(KyuubiConf.BACKEND_SESSTION_INIT_TIMEOUT.key, "3")
+    val userName1 = "test"
+    val ru = UserGroupInformation.createRemoteUser(userName1)
+    val sparkSessionWithUGI = new SparkSessionWithUGI(ru, confClone)
+    assert(!SparkSessionWithUGI.isPartiallyConstructed(userName1))
+    val e = intercept[KyuubiSQLException](sparkSessionWithUGI.init(Map.empty))
+    assert(e.getCause.isInstanceOf[TimeoutException])
+    assert(e.getMessage.startsWith("Get SparkSession"))
+  }
+
+  test("testSetFullyConstructed") {
+    SparkSessionWithUGI.setPartiallyConstructed("Kent")
+    assert(SparkSessionWithUGI.isPartiallyConstructed("Kent"))
+    SparkSessionWithUGI.setFullyConstructed("Kent")
+    assert(!SparkSessionWithUGI.isPartiallyConstructed("Kent"))
+  }
+
+  test("testIsPartiallyConstructed") {
+    assert(!SparkSessionWithUGI.isPartiallyConstructed(userName))
+  }
+
+  test("stop sparkcontext") {
+    val sparkSessionWithUGI = new SparkSessionWithUGI(user, conf)
+    sparkSessionWithUGI.init(Map.empty)
+    val promise = ReflectUtils.getFieldValue(sparkSessionWithUGI,
+      "yaooqinn$kyuubi$spark$SparkSessionWithUGI$$promisedSparkContext")
+      .asInstanceOf[Promise[SparkContext]]
+    val future = promise.future
+    ReflectUtils.invokeMethod(sparkSessionWithUGI, "stopContext")
+    future.foreach { sc =>
+      assert(sc.isStopped)
+    }
+  }
+}