Rename Hive-related class names to avoid class conflicts

This commit is contained in:
Kent Yao 2018-01-15 18:58:24 +08:00
parent 3b9ce670b9
commit 66f86172a6
13 changed files with 54 additions and 532 deletions

View File

@ -19,14 +19,9 @@ package org.apache.hive.service.auth;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.net.ssl.SSLServerSocket;
import javax.security.auth.login.LoginException;
import javax.security.sasl.Sasl;
@ -40,26 +35,20 @@ import org.apache.hadoop.hive.shims.ShimLoader;
import org.apache.hadoop.hive.thrift.DBTokenStore;
import org.apache.hadoop.hive.thrift.HadoopThriftAuthBridge;
import org.apache.hadoop.hive.thrift.HadoopThriftAuthBridge.Server.ServerMode;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authorize.ProxyUsers;
import org.apache.hive.service.cli.HiveSQLException;
import org.apache.hive.service.cli.thrift.TCLIService;
import org.apache.thrift.TProcessorFactory;
import org.apache.thrift.transport.TSSLTransportFactory;
import org.apache.thrift.transport.TServerSocket;
import org.apache.thrift.transport.TSocket;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.apache.thrift.transport.TTransportFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This class helps in some aspects of authentication. It creates the proper Thrift classes for the
* given configuration as well as helps with authenticating requests.
*/
public class HiveAuthFactory {
public class KyuubiAuthFactory {
public enum AuthTypes {
NOSASL("NOSASL"),
NONE("NONE"),
@ -85,7 +74,7 @@ public class HiveAuthFactory {
public static final String HS2_PROXY_USER = "hive.server2.proxy.user";
public HiveAuthFactory(HiveConf conf) throws TTransportException {
public KyuubiAuthFactory(HiveConf conf) throws TTransportException {
this.conf = conf;
authTypeStr = conf.getVar(HiveConf.ConfVars.HIVE_SERVER2_AUTHENTICATION);
@ -134,15 +123,15 @@ public class HiveAuthFactory {
throw new LoginException(e.getMessage());
}
} else if (authTypeStr.equalsIgnoreCase(AuthTypes.NONE.getAuthName())) {
transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
transportFactory = KyuubiPlainSaslHelper.getPlainTransportFactory(authTypeStr);
} else if (authTypeStr.equalsIgnoreCase(AuthTypes.LDAP.getAuthName())) {
transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
transportFactory = KyuubiPlainSaslHelper.getPlainTransportFactory(authTypeStr);
} else if (authTypeStr.equalsIgnoreCase(AuthTypes.PAM.getAuthName())) {
transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
transportFactory = KyuubiPlainSaslHelper.getPlainTransportFactory(authTypeStr);
} else if (authTypeStr.equalsIgnoreCase(AuthTypes.NOSASL.getAuthName())) {
transportFactory = new TTransportFactory();
} else if (authTypeStr.equalsIgnoreCase(AuthTypes.CUSTOM.getAuthName())) {
transportFactory = PlainSaslHelper.getPlainTransportFactory(authTypeStr);
transportFactory = KyuubiPlainSaslHelper.getPlainTransportFactory(authTypeStr);
} else {
throw new LoginException("Unsupported authentication type " + authTypeStr);
}
@ -157,9 +146,9 @@ public class HiveAuthFactory {
*/
public TProcessorFactory getAuthProcFactory(TCLIService.Iface service) throws LoginException {
if (authTypeStr.equalsIgnoreCase(AuthTypes.KERBEROS.getAuthName())) {
return KerberosSaslHelper.getKerberosProcessorFactory(saslServer, service);
return KyuubiKerberosSaslHelper.getKerberosProcessorFactory(saslServer, service);
} else {
return PlainSaslHelper.getPlainProcessorFactory(service);
return KyuubiPlainSaslHelper.getPlainProcessorFactory(service);
}
}

View File

@ -24,14 +24,14 @@ import org.apache.thrift.TProcessor;
import org.apache.thrift.TProcessorFactory;
import org.apache.thrift.transport.TTransport;
public final class KerberosSaslHelper {
public final class KyuubiKerberosSaslHelper {
public static TProcessorFactory getKerberosProcessorFactory(Server saslServer,
TCLIService.Iface service) {
return new CLIServiceProcessorFactory(saslServer, service);
}
private KerberosSaslHelper() {
private KyuubiKerberosSaslHelper() {
throw new UnsupportedOperationException("Can't initialize class");
}

View File

@ -41,7 +41,7 @@ import org.apache.thrift.transport.TSaslServerTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportFactory;
public final class PlainSaslHelper {
public final class KyuubiPlainSaslHelper {
public static TProcessorFactory getPlainProcessorFactory(TCLIService.Iface service) {
return new SQLPlainProcessorFactory(service);
@ -64,13 +64,7 @@ public final class PlainSaslHelper {
return saslFactory;
}
public static TTransport getPlainTransport(String username, String password,
TTransport underlyingTransport) throws SaslException {
return new TSaslClientTransport("PLAIN", null, null, null, new HashMap<String, String>(),
new PlainCallbackHandler(username, password), underlyingTransport);
}
private PlainSaslHelper() {
private KyuubiPlainSaslHelper() {
throw new UnsupportedOperationException("Can't initialize class");
}

View File

@ -20,11 +20,6 @@ package org.apache.hive.service.cli;
import java.util.List;
import java.util.Map;
import org.apache.hive.service.auth.HiveAuthFactory;
public interface ICLIService {
void closeSession(SessionHandle sessionHandle)

View File

@ -1,460 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.io.File
import java.net.Socket
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable
import scala.util.Properties
import com.google.common.collect.MapMaker
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.{RpcUtils, Utils}
/**
* :: DeveloperApi ::
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, RpcEnv, block manager, map output tracker, etc. Currently
* Spark code finds the SparkEnv through a global variable, so all the threads can access the same
* SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext).
*
* NOTE: This is not intended for external use. This is exposed for Shark and may be made private
* in a future release.
*/
@DeveloperApi
class SparkEnv (
val executorId: String,
private[spark] val rpcEnv: RpcEnv,
val serializer: Serializer,
val closureSerializer: Serializer,
val serializerManager: SerializerManager,
val mapOutputTracker: MapOutputTracker,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val securityManager: SecurityManager,
val metricsSystem: MetricsSystem,
val memoryManager: MemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
// A general, soft-reference map for metadata needed during HadoopRDD split computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
private[spark] var driverTmpDir: Option[String] = None
private[spark] def stop() {
if (!isStopped) {
isStopped = true
pythonWorkers.values.foreach(_.stop())
mapOutputTracker.stop()
shuffleManager.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
metricsSystem.stop()
outputCommitCoordinator.stop()
rpcEnv.shutdown()
rpcEnv.awaitTermination()
// If we only stop sc, but the driver process still run as a services then we need to delete
// the tmp dir, if not, it will create too many tmp dirs.
// We only need to delete the tmp dir create by driver
driverTmpDir match {
case Some(path) =>
try {
Utils.deleteRecursively(new File(path))
} catch {
case e: Exception =>
logWarning(s"Exception while deleting Spark temp dir: $path", e)
}
case None => // We just need to delete tmp dir created by driver, so do nothing on executor
}
}
}
private[spark]
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
}
}
private[spark]
def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}
private[spark]
def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers.get(key).foreach(_.releaseWorker(worker))
}
}
}
object SparkEnv extends Logging {
type USER = String
@volatile private var env = new ConcurrentHashMap[USER, SparkEnv]()
private[spark] val driverSystemName = "sparkDriver"
private[spark] val executorSystemName = "sparkExecutor"
private[this] def getCurrentUserName = UserGroupInformation.getCurrentUser.getShortUserName
def set(e: SparkEnv): Unit = {
if (e == null) {
env.remove(getCurrentUserName)
} else {
env.put(getCurrentUserName, e)
}
}
/**
* Returns the SparkEnv.
*/
def get: SparkEnv = {
env.get(getCurrentUserName)
}
/**
* Create a SparkEnv for the driver.
*/
private[spark] def createDriverEnv(
conf: SparkConf,
isLocal: Boolean,
listenerBus: LiveListenerBus,
numCores: Int,
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
assert(conf.contains(DRIVER_HOST_ADDRESS),
s"${DRIVER_HOST_ADDRESS.key} is not set on the driver!")
assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
val bindAddress = conf.get(DRIVER_BIND_ADDRESS)
val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS)
val port = conf.get("spark.driver.port").toInt
val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
Some(CryptoStreamUtils.createKey(conf))
} else {
None
}
create(
conf,
SparkContext.DRIVER_IDENTIFIER,
bindAddress,
advertiseAddress,
port,
isLocal,
numCores,
ioEncryptionKey,
listenerBus = listenerBus,
mockOutputCommitCoordinator = mockOutputCommitCoordinator
)
}
/**
* Create a SparkEnv for an executor.
* In coarse-grained mode, the executor provides an RpcEnv that is already instantiated.
*/
private[spark] def createExecutorEnv(
conf: SparkConf,
executorId: String,
hostname: String,
port: Int,
numCores: Int,
ioEncryptionKey: Option[Array[Byte]],
isLocal: Boolean): SparkEnv = {
val env = create(
conf,
executorId,
hostname,
hostname,
port,
isLocal,
numCores,
ioEncryptionKey
)
SparkEnv.set(env)
env
}
/**
* Helper method to create a SparkEnv for a driver or an executor.
*/
private def create(
conf: SparkConf,
executorId: String,
bindAddress: String,
advertiseAddress: String,
port: Int,
isLocal: Boolean,
numUsableCores: Int,
ioEncryptionKey: Option[Array[Byte]],
listenerBus: LiveListenerBus = null,
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
// Listener bus is only used on the driver
if (isDriver) {
assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
}
val securityManager = new SecurityManager(conf, ioEncryptionKey)
ioEncryptionKey.foreach { _ =>
if (!securityManager.isSaslEncryptionEnabled()) {
logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " +
"wire.")
}
}
val systemName = if (isDriver) driverSystemName else executorSystemName
val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf,
securityManager, clientMode = !isDriver)
// Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied.
// In the non-driver case, the RPC env's address may be null since it may not be listening
// for incoming connections.
if (isDriver) {
conf.set("spark.driver.port", rpcEnv.address.port.toString)
} else if (rpcEnv.address != null) {
conf.set("spark.executor.port", rpcEnv.address.port.toString)
logInfo(s"Setting spark.executor.port to: ${rpcEnv.address.port.toString}")
}
// Create an instance of the class with the given name, possibly initializing it with our conf
def instantiateClass[T](className: String): T = {
val cls = Utils.classForName(className)
// Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
// SparkConf, then one taking no arguments
try {
cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
.newInstance(conf, new java.lang.Boolean(isDriver))
.asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
try {
cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
cls.getConstructor().newInstance().asInstanceOf[T]
}
}
}
// Create an instance of the class named by the given SparkConf property, or defaultClassName
// if the property is not set, possibly initializing it with our conf
def instantiateClassFromConf[T](propertyName: String, defaultClassName: String): T = {
instantiateClass[T](conf.get(propertyName, defaultClassName))
}
val serializer = instantiateClassFromConf[Serializer](
"spark.serializer", "org.apache.spark.serializer.JavaSerializer")
logDebug(s"Using serializer: ${serializer.getClass}")
val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)
val closureSerializer = new JavaSerializer(conf)
def registerOrLookupEndpoint(
name: String, endpointCreator: => RpcEndpoint):
RpcEndpointRef = {
if (isDriver) {
logInfo("Registering " + name)
rpcEnv.setupEndpoint(name, endpointCreator)
} else {
RpcUtils.makeDriverRef(name, conf, rpcEnv)
}
}
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
} else {
new MapOutputTrackerWorker(conf)
}
// Have to assign trackerEndpoint after initialization as MapOutputTrackerEndpoint
// requires the MapOutputTracker itself
mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(
rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
// Let the user specify short names for shuffle managers
val shortShuffleMgrNames = Map(
"sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName,
"tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName)
val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false)
val memoryManager: MemoryManager =
if (useLegacyMemoryManager) {
new StaticMemoryManager(conf, numUsableCores)
} else {
UnifiedMemoryManager(conf, numUsableCores)
}
val blockManagerPort = if (isDriver) {
conf.get(DRIVER_BLOCK_MANAGER_PORT)
} else {
conf.get(BLOCK_MANAGER_PORT)
}
val blockTransferService =
new NettyBlockTransferService(conf, securityManager, bindAddress, advertiseAddress,
blockManagerPort, numUsableCores)
val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
BlockManagerMaster.DRIVER_ENDPOINT_NAME,
new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)),
conf, isDriver)
// NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
blockTransferService, securityManager, numUsableCores)
val metricsSystem = if (isDriver) {
// Don't start metrics system right now for Driver.
// We need to wait for the task scheduler to give us an app ID.
// Then we can start the metrics system.
MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
// We need to set the executor ID before the MetricsSystem is created because sources and
// sinks specified in the metrics configuration file will want to incorporate this executor's
// ID into the metrics they report.
conf.set("spark.executor.id", executorId)
val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager)
ms.start()
ms
}
val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
new OutputCommitCoordinator(conf, isDriver)
}
val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator",
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
val envInstance = new SparkEnv(
executorId,
rpcEnv,
serializer,
closureSerializer,
serializerManager,
mapOutputTracker,
shuffleManager,
broadcastManager,
blockManager,
securityManager,
metricsSystem,
memoryManager,
outputCommitCoordinator,
conf)
// Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is
// called, and we only need to do it for driver. Because driver may run as a service, and if we
// don't delete this tmp dir when sc is stopped, then will create too many tmp dirs.
if (isDriver) {
val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
envInstance.driverTmpDir = Some(sparkFilesDir)
}
envInstance
}
/**
* Return a map representation of jvm information, Spark properties, system properties, and
* class paths. Map keys define the category, and map values represent the corresponding
* attributes as a sequence of KV pairs. This is used mainly for SparkListenerEnvironmentUpdate.
*/
private[spark]
def environmentDetails(
conf: SparkConf,
schedulingMode: String,
addedJars: Seq[String],
addedFiles: Seq[String]): Map[String, Seq[(String, String)]] = {
import Properties._
val jvmInformation = Seq(
("Java Version", s"$javaVersion ($javaVendor)"),
("Java Home", javaHome),
("Scala Version", versionString)
).sorted
// Spark properties
// This includes the scheduling mode whether or not it is configured (used by SparkUI)
val schedulerMode =
if (!conf.contains("spark.scheduler.mode")) {
Seq(("spark.scheduler.mode", schedulingMode))
} else {
Seq[(String, String)]()
}
val sparkProperties = (conf.getAll ++ schedulerMode).sorted
// System properties that are not java classpaths
val systemProperties = Utils.getSystemProperties.toSeq
val otherProperties = systemProperties.filter { case (k, _) =>
k != "java.class.path" && !k.startsWith("spark.")
}.sorted
// Class paths including all added jars and files
val classPathEntries = javaClassPath
.split(File.pathSeparator)
.filterNot(_.isEmpty)
.map((_, "System Classpath"))
val addedJarsAndFiles = (addedJars ++ addedFiles).map((_, "Added By User"))
val classPaths = (addedJarsAndFiles ++ classPathEntries).sorted
Map[String, Seq[(String, String)]](
"JVM Information" -> jvmInformation,
"Spark Properties" -> sparkProperties,
"System Properties" -> otherProperties,
"Classpath Entries" -> classPaths)
}
}

View File

@ -45,7 +45,7 @@ object SparkUtils {
Utils.getCurrentUserName()
}
def getSparkClassLoader(): ClassLoader = {
def getContextOrSparkClassLoader(): ClassLoader = {
Utils.getContextOrSparkClassLoader
}

View File

@ -26,7 +26,7 @@ import scala.util.{Failure, Try}
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hive.service.auth.{HiveAuthFactory, TSetIpAddressProcessor}
import org.apache.hive.service.auth.{KyuubiAuthFactory, TSetIpAddressProcessor}
import org.apache.hive.service.cli._
import org.apache.hive.service.cli.thrift._
import org.apache.hive.service.server.ThreadFactoryWithGarbageCleanup
@ -48,7 +48,7 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
extends AbstractService(name) with TCLIService.Iface with Runnable with Logging {
private[this] var hiveConf: HiveConf = _
private[this] var hiveAuthFactory: HiveAuthFactory = _
private[this] var hiveAuthFactory: KyuubiAuthFactory = _
private[this] val OK_STATUS = new TStatus(TStatusCode.SUCCESS_STATUS)
@ -60,7 +60,7 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
private[this] var portNum = 0
private[this] var serverIPAddress: InetAddress = _
private[this] val threadPoolName = "KyuubiServer-Handler-Pool"
private[this] val threadPoolName = classOf[KyuubiServer].getSimpleName + "-Handler-Pool"
private[this] var minWorkerThreads = 0
private[this] var maxWorkerThreads = 0
private[this] var workerKeepAliveTime = 0L
@ -160,7 +160,7 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
private[this] def isKerberosAuthMode = {
hiveConf.getVar(ConfVars.HIVE_SERVER2_AUTHENTICATION)
.equalsIgnoreCase(HiveAuthFactory.AuthTypes.KERBEROS.toString)
.equalsIgnoreCase(KyuubiAuthFactory.AuthTypes.KERBEROS.toString)
}
private[this] def getUserName(req: TOpenSessionReq) = {
@ -191,8 +191,8 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
@throws[HiveSQLException]
private[this] def getProxyUser(sessionConf: JMap[String, String], ipAddress: String): String = {
var proxyUser: String = null
if (sessionConf != null && sessionConf.containsKey(HiveAuthFactory.HS2_PROXY_USER)) {
proxyUser = sessionConf.get(HiveAuthFactory.HS2_PROXY_USER)
if (sessionConf != null && sessionConf.containsKey(KyuubiAuthFactory.HS2_PROXY_USER)) {
proxyUser = sessionConf.get(KyuubiAuthFactory.HS2_PROXY_USER)
}
if (proxyUser == null) {
return realUser
@ -202,12 +202,12 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
throw new HiveSQLException("Proxy user substitution is not allowed")
}
// If there's no authentication, then directly substitute the user
if (HiveAuthFactory.AuthTypes.NONE.toString
if (KyuubiAuthFactory.AuthTypes.NONE.toString
.equalsIgnoreCase(hiveConf.getVar(ConfVars.HIVE_SERVER2_AUTHENTICATION))) {
return proxyUser
}
// Verify proxy user privilege of the realUser for the proxyUser
HiveAuthFactory.verifyProxyAccess(realUser, proxyUser, ipAddress, hiveConf)
KyuubiAuthFactory.verifyProxyAccess(realUser, proxyUser, ipAddress, hiveConf)
proxyUser
}
@ -549,10 +549,10 @@ private[kyuubi] class FrontendService private(name: String, beService: BackendSe
new ThreadFactoryWithGarbageCleanup(threadPoolName))
// Thrift configs
hiveAuthFactory = new HiveAuthFactory(hiveConf)
hiveAuthFactory = new KyuubiAuthFactory(hiveConf)
val transportFactory = hiveAuthFactory.getAuthTransFactory
val processorFactory = hiveAuthFactory.getAuthProcFactory(this)
val serverSocket: TServerSocket = HiveAuthFactory.getServerSocket(serverHost, portNum)
val serverSocket: TServerSocket = KyuubiAuthFactory.getServerSocket(serverHost, portNum)
val sslVersionBlacklist = new JList[String]
for (sslVersion <- hiveConf.getVar(ConfVars.HIVE_SSL_PROTOCOL_BLACKLIST).split(",")) {
sslVersionBlacklist.add(sslVersion)

View File

@ -100,10 +100,10 @@ object KyuubiServer extends Logging {
* @param conf the default [[SparkConf]]
*/
private[this] def setupCommonConfig(conf: SparkConf): Unit = {
if (!conf.getBoolean("spark.driver.userClassPathFirst", defaultValue = false)) {
error("SET spark.driver.userClassPathFirst to true")
System.exit(-1)
}
// if (!conf.getBoolean("spark.driver.userClassPathFirst", defaultValue = false)) {
// error("SET spark.driver.userClassPathFirst to true")
// System.exit(-1)
// }
// overwrite later for each SparkC
conf.set("spark.app.name", classOf[KyuubiServer].getSimpleName)
// avoid max port retries reached

View File

@ -107,8 +107,10 @@ private[kyuubi] class KyuubiSession(
sessionUGI.doAs(new PrivilegedExceptionAction[Unit] {
override def run(): Unit = {
val sc = new SparkContext(conf)
_sparkSession = ReflectUtils.instantiateClass(classOf[SparkSession].getName,
Seq(classOf[SparkContext]), Seq(sc)).asInstanceOf[SparkSession]
_sparkSession = ReflectUtils.instantiateClass(
classOf[SparkSession].getName,
Seq(classOf[SparkContext]),
Seq(sc)).asInstanceOf[SparkSession]
}
})

View File

@ -50,7 +50,7 @@ private[kyuubi] class SessionManager private(
private[this] val userToSparkSession =
new ConcurrentHashMap[String, (SparkSession, AtomicInteger)]
private[this] val userSparkContextBeingConstruct = new HashSet[String]()
private[this] var backgroundOperationPool: ThreadPoolExecutor = _
private[this] var execPool: ThreadPoolExecutor = _
private[this] var isOperationLogEnabled = false
private[this] var operationLogRootDir: File = _
private[this] var checkInterval: Long = _
@ -67,12 +67,12 @@ private[kyuubi] class SessionManager private(
if (conf.get(KYUUBI_LOGGING_OPERATION_ENABLED.key).toBoolean) {
initOperationLogRootDir()
}
createBackgroundOperationPool()
createExecPool()
addService(operationManager)
super.init(conf)
}
private[this] def createBackgroundOperationPool(): Unit = {
private[this] def createExecPool(): Unit = {
val poolSize = conf.get(KYUUBI_ASYNC_EXEC_THREADS.key).toInt
info("Background operation thread pool size: " + poolSize)
val poolQueueSize = conf.get(KYUUBI_ASYNC_EXEC_WAIT_QUEUE_SIZE.key).toInt
@ -80,7 +80,7 @@ private[kyuubi] class SessionManager private(
val keepAliveTime = conf.getTimeAsSeconds(KYUUBI_EXEC_KEEPALIVE_TIME.key)
info("Background operation thread keepalive time: " + keepAliveTime + " seconds")
val threadPoolName = classOf[KyuubiServer].getSimpleName + "-Background-Pool"
backgroundOperationPool =
execPool =
new ThreadPoolExecutor(
poolSize,
poolSize,
@ -88,7 +88,7 @@ private[kyuubi] class SessionManager private(
TimeUnit.SECONDS,
new LinkedBlockingQueue[Runnable](poolQueueSize),
new ThreadFactoryWithGarbageCleanup(threadPoolName))
backgroundOperationPool.allowCoreThreadTimeOut(true)
execPool.allowCoreThreadTimeOut(true)
checkInterval = conf.getTimeAsMs(KYUUBI_SESSION_CHECK_INTERVAL.key)
sessionTimeout = conf.getTimeAsMs(KYUUBI_IDLE_SESSION_TIMEOUT.key)
checkOperation = conf.get(KYUUBI_IDLE_SESSION_CHECK_OPERATION.key).toBoolean
@ -164,7 +164,7 @@ private[kyuubi] class SessionManager private(
}
}
}
backgroundOperationPool.execute(timeoutChecker)
execPool.execute(timeoutChecker)
}
/**
@ -191,7 +191,7 @@ private[kyuubi] class SessionManager private(
}
}
}
backgroundOperationPool.execute(sessionCleaner)
execPool.execute(sessionCleaner)
}
private[this] def sleepInterval(interval: Long): Unit = {
@ -275,17 +275,17 @@ private[kyuubi] class SessionManager private(
override def stop(): Unit = {
super.stop()
shutdown = true
if (backgroundOperationPool != null) {
backgroundOperationPool.shutdown()
if (execPool != null) {
execPool.shutdown()
val timeout = conf.getTimeAsSeconds(KYUUBI_ASYNC_EXEC_SHUTDOWN_TIMEOUT.key)
try {
backgroundOperationPool.awaitTermination(timeout, TimeUnit.SECONDS)
execPool.awaitTermination(timeout, TimeUnit.SECONDS)
} catch {
case e: InterruptedException =>
warn("KYUUBI_ASYNC_EXEC_SHUTDOWN_TIMEOUT = " + timeout +
" seconds has been exceeded. RUNNING background operations will be shut down", e)
}
backgroundOperationPool = null
execPool = null
}
cleanupLoggingRootDir()
userToSparkSession.asScala.values.foreach { kv => kv._1.stop() }
@ -306,7 +306,7 @@ private[kyuubi] class SessionManager private(
def getOpenSessionCount: Int = handleToSession.size
def submitBackgroundOperation(r: Runnable): Future[_] = backgroundOperationPool.submit(r)
def submitBackgroundOperation(r: Runnable): Future[_] = execPool.submit(r)
def getExistSparkSession(user: String): Option[(SparkSession, AtomicInteger)] = {
Some(userToSparkSession.get(user))

View File

@ -22,7 +22,6 @@ import scala.util.{Failure, Success, Try}
import yaooqinn.kyuubi.Logging
object ReflectUtils extends Logging {
/**
* Init a class via Reflection
* @param className class name
@ -33,15 +32,16 @@ object ReflectUtils extends Logging {
def instantiateClass(
className: String,
argTypes: Seq[Class[_]],
params: Seq[AnyRef]): Any = {
params: Seq[AnyRef],
classLoader: ClassLoader = Thread.currentThread().getContextClassLoader): Any = {
require(className != null, "class name could not be null!")
try {
if (argTypes!= null && argTypes.nonEmpty) {
require(argTypes.length == params.length, "each params should have a class type!")
this.getClass.getClassLoader.loadClass(className).getConstructor(argTypes: _*)
classLoader.loadClass(className).getConstructor(argTypes: _*)
.newInstance(params: _*)
} else {
this.getClass.getClassLoader.loadClass(className).getConstructor().newInstance()
classLoader.loadClass(className).getConstructor().newInstance()
}
} catch {
case e: Exception => throw e
@ -53,8 +53,10 @@ object ReflectUtils extends Logging {
* @param className class name
* @return
*/
def instantiateClass(className: String): Any = {
instantiateClass(className, Seq.empty, Seq.empty)
def instantiateClassByName(
className: String,
classLoader: ClassLoader = Thread.currentThread().getContextClassLoader): Any = {
instantiateClass(className, Seq.empty, Seq.empty, classLoader)
}
/**

View File

@ -35,7 +35,7 @@ class SparkContextReflectionSuite extends SparkFunSuite {
test("SparkContext initialization with this()") {
intercept[InvocationTargetException](ReflectUtils
.instantiateClass(classOf[SparkContext].getName)
.instantiateClassByName(classOf[SparkContext].getName)
.asInstanceOf[SparkContext])
}

View File

@ -29,7 +29,7 @@ class ReflectUtilsSuite extends SparkFunSuite {
test("reflect utils init class without param") {
try {
val testClassInstance =
ReflectUtils.instantiateClass(classOf[TestClass0].getName)
ReflectUtils.instantiateClassByName(classOf[TestClass0].getName)
assert(testClassInstance.asInstanceOf[TestClass0].isInstanceOf[TestClass0])
} catch {
case e: Exception => throw e
@ -62,7 +62,7 @@ class ReflectUtilsSuite extends SparkFunSuite {
test("reflect utils fail init class not exist ") {
intercept[ClassNotFoundException](
ReflectUtils.instantiateClass("yaooqinn.kyuubi.NonExistTestClass"))
ReflectUtils.instantiateClassByName("yaooqinn.kyuubi.NonExistTestClass"))
}
test("find class by name") {