fix #107: kill YARN application fast

This commit is contained in:
Kent Yao 2018-10-26 17:38:30 +08:00
parent 8bbc3d53ac
commit d8a4509405
5 changed files with 101 additions and 32 deletions

View File

@ -22,6 +22,7 @@ import java.util.HashMap
import java.util.concurrent.TimeUnit
import scala.collection.JavaConverters._
import scala.language.implicitConversions
import org.apache.spark.internal.config.{ConfigBuilder, ConfigEntry}
@ -365,4 +366,9 @@ object KyuubiConf {
(kv.getKey, kv.getValue.defaultValueString)
}.toMap
}
// Implicit widenings from a typed ConfigEntry to its raw String key, so call
// sites can pass the entry object directly to APIs that expect a key string
// (e.g. conf.get(FRONTEND_BIND_HOST) instead of conf.get(FRONTEND_BIND_HOST.key)).
// NOTE(review): implicit conversions are easy to trigger accidentally — these
// are narrow (ConfigEntry[T] -> String only) but confirm they stay scoped to
// importers of this object.
implicit def convertBooleanConf(config: ConfigEntry[Boolean]): String = config.key
implicit def convertIntConf(config: ConfigEntry[Int]): String = config.key
implicit def convertLongConf(config: ConfigEntry[Long]): String = config.key
implicit def convertStringConf(config: ConfigEntry[String]): String = config.key
}

View File

@ -112,21 +112,22 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
}
override def init(conf: SparkConf): Unit = synchronized {
this.conf = conf
hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
val serverHost = conf.get(FRONTEND_BIND_HOST.key)
val serverHost = conf.get(FRONTEND_BIND_HOST)
try {
if (serverHost.nonEmpty) {
serverIPAddress = InetAddress.getByName(serverHost)
} else {
serverIPAddress = InetAddress.getLocalHost
}
portNum = conf.get(FRONTEND_BIND_PORT.key).toInt
portNum = conf.get(FRONTEND_BIND_PORT).toInt
serverSocket = new ServerSocket(portNum, 1, serverIPAddress)
} catch {
case e: Exception => throw new ServiceException(e.getMessage + ": " + portNum, e)
}
portNum = serverSocket.getLocalPort
// conf.set(FRONTEND_BIND_PORT, portNum.toString)
// conf.set(FRONTEND_BIND_HOST, serverIPAddress.getCanonicalHostName)
super.init(conf)
}
@ -152,7 +153,7 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
def getServerIPAddress: InetAddress = serverIPAddress
private[this] def isKerberosAuthMode = {
conf.get(KyuubiConf.AUTHENTICATION_METHOD.key).equalsIgnoreCase(AuthType.KERBEROS.name)
conf.get(KyuubiConf.AUTHENTICATION_METHOD).equalsIgnoreCase(AuthType.KERBEROS.name)
}
private[this] def getUserName(req: TOpenSessionReq) = {
@ -184,7 +185,7 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
private[this] def getProxyUser(sessionConf: Map[String, String], ipAddress: String): String = {
Option(sessionConf).flatMap(_.get(KyuubiAuthFactory.HS2_PROXY_USER)) match {
case None => realUser
case Some(_) if !conf.get(FRONTEND_ALLOW_USER_SUBSTITUTION.key).toBoolean =>
case Some(_) if !conf.get(FRONTEND_ALLOW_USER_SUBSTITUTION).toBoolean =>
throw new KyuubiSQLException("Proxy user substitution is not allowed")
case Some(p) if !isKerberosAuthMode => p
case Some(p) => // Verify proxy user privilege of the realUser for the proxyUser
@ -223,7 +224,7 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
val ipAddress = getIpAddress
val protocol = getMinVersion(BackendService.SERVER_VERSION, req.getClient_protocol)
val sessionHandle =
if (conf.get(FRONTEND_ENABLE_DOAS.key).toBoolean && (userName != null)) {
if (conf.get(FRONTEND_ENABLE_DOAS).toBoolean && (userName != null)) {
beService.openSessionWithImpersonation(
protocol, userName, req.getPassword, ipAddress, req.getConfiguration.asScala.toMap, null)
} else {
@ -567,12 +568,12 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
override def run(): Unit = {
try {
// Server thread pool
val minThreads = conf.get(FRONTEND_MIN_WORKER_THREADS.key).toInt
val maxThreads = conf.get(FRONTEND_MAX_WORKER_THREADS.key).toInt
val minThreads = conf.get(FRONTEND_MIN_WORKER_THREADS).toInt
val maxThreads = conf.get(FRONTEND_MAX_WORKER_THREADS).toInt
val executorService = new ThreadPoolExecutor(
minThreads,
maxThreads,
conf.getTimeAsSeconds(FRONTEND_WORKER_KEEPALIVE_TIME.key),
conf.getTimeAsSeconds(FRONTEND_WORKER_KEEPALIVE_TIME),
TimeUnit.SECONDS,
new SynchronousQueue[Runnable],
new NamedThreadFactory(threadPoolName))
@ -584,9 +585,9 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
val tSocket = new TServerSocket(serverSocket)
// Server args
val maxMessageSize = conf.get(FRONTEND_MAX_MESSAGE_SIZE.key).toInt
val requestTimeout = conf.getTimeAsSeconds(FRONTEND_LOGIN_TIMEOUT.key).toInt
val beBackoffSlotLength = conf.getTimeAsMs(FRONTEND_LOGIN_BEBACKOFF_SLOT_LENGTH.key).toInt
val maxMessageSize = conf.get(FRONTEND_MAX_MESSAGE_SIZE).toInt
val requestTimeout = conf.getTimeAsSeconds(FRONTEND_LOGIN_TIMEOUT).toInt
val beBackoffSlotLength = conf.getTimeAsMs(FRONTEND_LOGIN_BEBACKOFF_SLOT_LENGTH).toInt
val args = new TThreadPoolServer.Args(tSocket)
.processorFactory(processorFactory)
.transportFactory(transportFactory)

View File

@ -17,7 +17,6 @@
package yaooqinn.kyuubi.spark
import java.security.PrivilegedExceptionAction
import java.util.concurrent.TimeUnit
import scala.collection.mutable.{HashSet => MHSet}
@ -36,7 +35,7 @@ import org.apache.spark.ui.KyuubiServerTab
import yaooqinn.kyuubi.{KyuubiSQLException, Logging}
import yaooqinn.kyuubi.author.AuthzHelper
import yaooqinn.kyuubi.ui.{KyuubiServerListener, KyuubiServerMonitor}
import yaooqinn.kyuubi.utils.ReflectUtils
import yaooqinn.kyuubi.utils.{KyuubiHadoopUtil, ReflectUtils}
class SparkSessionWithUGI(
user: UserGroupInformation,
@ -124,9 +123,9 @@ class SparkSessionWithUGI(
}
private[this] def getOrCreate(sessionConf: Map[String, String]): Unit = synchronized {
val totalRounds = math.max(conf.get(BACKEND_SESSION_WAIT_OTHER_TIMES.key).toInt, 15)
val totalRounds = math.max(conf.get(BACKEND_SESSION_WAIT_OTHER_TIMES).toInt, 15)
var checkRound = totalRounds
val interval = conf.getTimeAsMs(BACKEND_SESSION_WAIT_OTHER_INTERVAL.key)
val interval = conf.getTimeAsMs(BACKEND_SESSION_WAIT_OTHER_INTERVAL)
// if user's sc is being constructed by another
while (SparkSessionWithUGI.isPartiallyConstructed(userName)) {
wait(interval)
@ -151,25 +150,30 @@ class SparkSessionWithUGI(
private[this] def create(sessionConf: Map[String, String]): Unit = {
info(s"--------- Create new SparkSession for $userName ----------")
val appName = s"KyuubiSession[$userName]@" + conf.get(FRONTEND_BIND_HOST.key)
// Application name format: kyuubi|user name|canonical host name|port
val appName = Seq(
"kyuubi", userName, conf.get(FRONTEND_BIND_HOST), conf.get(FRONTEND_BIND_PORT)).mkString("|")
conf.setAppName(appName)
configureSparkConf(sessionConf)
val totalWaitTime: Long = conf.getTimeAsSeconds(BACKEND_SESSTION_INIT_TIMEOUT.key)
val totalWaitTime: Long = conf.getTimeAsSeconds(BACKEND_SESSTION_INIT_TIMEOUT)
try {
user.doAs(new PrivilegedExceptionAction[Unit] {
override def run(): Unit = {
newContext().start()
val context =
Await.result(promisedSparkContext.future, Duration(totalWaitTime, TimeUnit.SECONDS))
_sparkSession = ReflectUtils.newInstance(
classOf[SparkSession].getName,
Seq(classOf[SparkContext]),
Seq(context)).asInstanceOf[SparkSession]
}
})
KyuubiHadoopUtil.doAs(user) {
newContext().start()
val context =
Await.result(promisedSparkContext.future, Duration(totalWaitTime, TimeUnit.SECONDS))
_sparkSession = ReflectUtils.newInstance(
classOf[SparkSession].getName,
Seq(classOf[SparkContext]),
Seq(context)).asInstanceOf[SparkSession]
}
cache.set(userName, _sparkSession)
} catch {
case e: Exception =>
if (conf.getOption("spark.master").contains("yarn")) {
KyuubiHadoopUtil.doAs(user) {
KyuubiHadoopUtil.killYarnAppByName(appName)
}
}
stopContext()
val ke = new KyuubiSQLException(
s"Get SparkSession for [$userName] failed", "08S01", 1001, findCause(e))
@ -193,9 +197,9 @@ class SparkSessionWithUGI(
try {
initialDatabase.foreach { db =>
user.doAs(new PrivilegedExceptionAction[Unit] {
override def run(): Unit = _sparkSession.sql(db)
})
KyuubiHadoopUtil.doAs(user) {
_sparkSession.sql(db)
}
}
} catch {
case e: Exception =>

View File

@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package yaooqinn.kyuubi.utils
import java.security.PrivilegedExceptionAction
import scala.collection.JavaConverters._
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api.records.ApplicationReport
import org.apache.hadoop.yarn.client.api.YarnClient
import org.apache.hadoop.yarn.conf.YarnConfiguration
/**
 * YARN/Hadoop helper utilities used by Kyuubi, e.g. to tear down a Spark
 * application on the cluster when session construction fails.
 */
private[kyuubi] object KyuubiHadoopUtil {

  // One client shared by all callers: YarnClient is thread safe, so a single
  // lazily-created instance (initialized from the default YarnConfiguration)
  // is sufficient for the whole JVM.
  private lazy val yarnClient: YarnClient = {
    val client = YarnClient.createYarnClient()
    client.init(new YarnConfiguration())
    client.start()
    client
  }

  /** Ask the ResourceManager to kill the application described by `report`. */
  def killYarnApp(report: ApplicationReport): Unit =
    yarnClient.killApplication(report.getApplicationId)

  /** Applications of type "SPARK" currently known to the ResourceManager. */
  def getApplications: Seq[ApplicationReport] =
    yarnClient.getApplications(Set("SPARK").asJava).asScala

  /** Kill every SPARK application whose name is exactly `appName`. */
  def killYarnAppByName(appName: String): Unit = {
    val matching = getApplications.filter(report => report.getName.equals(appName))
    matching.foreach(killYarnApp)
  }

  /** Run the thunk `f` with the privileges of `user` (UGI doAs wrapper). */
  def doAs[T](user: UserGroupInformation)(f: => T): T = {
    val action = new PrivilegedExceptionAction[T] {
      override def run(): T = f
    }
    user.doAs(action)
  }
}

View File

@ -20,6 +20,8 @@ package yaooqinn.kyuubi.spark
import scala.concurrent.{Promise, TimeoutException}
import scala.concurrent.ExecutionContext.Implicits.global
import com.github.sakserv.minicluster.impl.YarnLocalCluster
import com.github.sakserv.minicluster.impl.YarnLocalCluster.Builder
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark._
import org.apache.spark.sql.SparkSession