kyuubi/python/pyhive/hive.py
John Zhang c19d923b85
[KYUUBI #7048] Fix KeyError when parsing unknown Hive type_id in schema inspection
This patch adds a try/except block to prevent `KeyError` when mapping an unknown `type_id` in Hive schema parsing. Now, if a `type_id` is not recognized, `type_code` is set to `None` instead of raising an exception.

### Why are the changes needed?

Previously, when parsing Hive table schemas, the code attempted to map each `type_id` to a human-readable type name via `ttypes.TTypeId._VALUES_TO_NAMES[type_id]`. If Hive introduced an unknown or custom type (e.g. when running a non-standard Hive version, or when data is pumped from a completely different source such as *Oracle* into *Hive* databases), a `KeyError` was raised, interrupting the entire SQL query process. This patch adds a `try/except` block so that an unrecognized `type_id` sets `type_code` to `None` instead of raising an error, letting the downstream user decide how to handle it rather than simply receiving an exception. This makes schema inspection more robust and compatible with evolving Hive data types.

### How was this patch tested?

The patch was tested by running schema inspection on tables containing both standard and unknown/custom Hive column types. For known types, parsing behaves as before. For unknown types, the parser sets `type_code` to `None` without raising an exception, and the rest of the process completes successfully. No unit test was added, since this edge case depends on unreachable or custom Hive types, but the change was tested on typical use cases.

### Was this patch authored or co-authored using generative AI tooling?

No. 😂 It's a minor patch.

Closes #7048 from ZsgsDesign/patch-1.

Closes #7048

4d246d0ec [John Zhang] fix: handle KeyError when parsing Hive type_id mapping

Authored-by: John Zhang <zsgsdesign@gmail.com>
Signed-off-by: Kent Yao <yao@apache.org>
2025-04-29 10:41:16 +08:00

617 lines
23 KiB
Python

"""DB-API implementation backed by HiveServer2 (Thrift API)
See http://www.python.org/dev/peps/pep-0249/
Many docstrings in this file are based on the PEP, which is in the public domain.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
import base64
import datetime
import re
from decimal import Decimal
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
from TCLIService import TCLIService
from TCLIService import constants
from TCLIService import ttypes
from pyhive import common
from pyhive.common import DBAPITypeObject
# Make all exceptions visible in this module per DB-API
from pyhive.exc import * # noqa
from builtins import range
import contextlib
from future.utils import iteritems
import getpass
import logging
import sys
import thrift.transport.THttpClient
import thrift.protocol.TBinaryProtocol
import thrift.transport.TSocket
import thrift.transport.TTransport
# Module globals mandated by PEP 249.
apilevel = '2.0'
# Threads may share the module and connections.
threadsafety = 2
# Extended Python format codes, e.g. ...WHERE name=%(name)s
paramstyle = 'pyformat'

_logger = logging.getLogger(__name__)

# Leading "YYYY-MM-DD HH:MM:SS" optionally followed by a fraction of at most six
# digits; group 2 holds the fraction (including the dot) when present.
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')

# Maps the ``ssl_cert`` connection argument to an ``ssl`` verify mode.
ssl_cert_parameter_map = {
    "none": CERT_NONE,
    "optional": CERT_OPTIONAL,
    "required": CERT_REQUIRED,
}
def get_sasl_client(host, sasl_auth, service=None, username=None, password=None):
    """Build and initialise a client from the C-backed ``sasl`` package.

    :param host: server host name, stored on the client as the ``host`` attribute.
    :param sasl_auth: ``'GSSAPI'`` (sets ``service``) or ``'PLAIN'`` (sets
        ``username`` and ``password``).
    :returns: the initialised ``sasl.Client``.
    :raises ImportError: if the ``sasl`` package cannot be imported.
    :raises ValueError: for any other ``sasl_auth`` value.
    """
    import sasl
    client = sasl.Client()
    client.setAttr('host', host)
    if sasl_auth == 'GSSAPI':
        mechanism_attrs = (('service', service),)
    elif sasl_auth == 'PLAIN':
        mechanism_attrs = (('username', username), ('password', password))
    else:
        raise ValueError("sasl_auth only supports GSSAPI and PLAIN")
    for attr_name, attr_value in mechanism_attrs:
        client.setAttr(attr_name, attr_value)
    client.init()
    return client
def get_pure_sasl_client(host, sasl_auth, service=None, username=None, password=None):
    """Build a ``PureSASLClient`` from ``pyhive.sasl_compat``.

    :param host: server host name passed to the client.
    :param sasl_auth: ``'GSSAPI'`` (forwards ``service``) or ``'PLAIN'``
        (forwards ``username`` and ``password``).
    :returns: a ``PureSASLClient`` instance.
    :raises ValueError: for any other ``sasl_auth`` value.
    """
    from pyhive.sasl_compat import PureSASLClient
    if sasl_auth == 'PLAIN':
        mechanism_kwargs = {'username': username, 'password': password}
    elif sasl_auth == 'GSSAPI':
        mechanism_kwargs = {'service': service}
    else:
        raise ValueError("sasl_auth only supports GSSAPI and PLAIN")
    return PureSASLClient(host=host, **mechanism_kwargs)
def get_installed_sasl(host, sasl_auth, service=None, username=None, password=None):
    """Return a SASL client, preferring the ``sasl`` library over ``pure-sasl``.

    Any ``ImportError`` raised while building the ``sasl``-based client causes a
    fallback to the pure-Python client; all arguments are forwarded unchanged.
    """
    client_kwargs = dict(
        host=host,
        sasl_auth=sasl_auth,
        service=service,
        username=username,
        password=password,
    )
    try:
        # Use the sasl library when it is available.
        return get_sasl_client(**client_kwargs)
    except ImportError:
        # Fall back to the pure-sasl library.
        return get_pure_sasl_client(**client_kwargs)
def _parse_timestamp(value):
    """Convert a Hive timestamp string into a naive ``datetime.datetime``.

    The value must start with ``YYYY-MM-DD HH:MM:SS``, optionally followed by a
    fraction. Only the prefix matched by ``_TIMESTAMP_PATTERN`` is parsed, so
    fractional digits beyond microsecond precision and any trailing text are
    ignored.

    :param value: the raw column value; falsy values (``None``, ``''``) map to ``None``.
    :returns: a ``datetime.datetime``, or ``None`` for falsy input.
    :raises ValueError: if ``value`` does not start with a recognisable timestamp.
        ``ValueError`` is an ``Exception`` subclass, so existing ``except Exception``
        handlers still catch it.
    """
    if not value:
        return None
    match = _TIMESTAMP_PATTERN.match(value)
    if not match:
        raise ValueError('Cannot convert "{}" into a datetime'.format(value))
    # Parse only the matched prefix, even when there is no fractional part, so
    # trailing text never reaches strptime. The pattern also accepts a bare
    # trailing '.' (zero fraction digits); strip it because %f needs >= 1 digit.
    text = match.group(1).rstrip('.')
    if '.' in text:
        fmt = '%Y-%m-%d %H:%M:%S.%f'
    else:
        fmt = '%Y-%m-%d %H:%M:%S'
    return datetime.datetime.strptime(text, fmt)
# Converters applied by ``_unwrap_column`` to non-empty raw values, keyed by the
# type name that ``Cursor.description`` reports for the column.
TYPES_CONVERTER = {
    "DECIMAL_TYPE": Decimal,
    "TIMESTAMP_TYPE": _parse_timestamp,
}
class HiveParamEscaper(common.ParamEscaper):
    """Parameter escaper that renders Python strings as single-quoted Hive literals."""

    # Characters rewritten inside a quoted literal, keyed by code point so the
    # table works with both ``unicode.translate`` and ``str.translate``. One
    # translate pass gives the same result as doubling backslashes first and
    # then escaping the quote, CR, LF and TAB characters.
    _LITERAL_ESCAPES = {
        ord('\\'): '\\\\',
        ord("'"): "\\'",
        ord('\r'): '\\r',
        ord('\n'): '\\n',
        ord('\t'): '\\t',
    }

    def escape_string(self, item):
        """Quote ``item`` as a Hive string literal.

        Byte strings are decoded as UTF-8 first. Old SQLAlchemy versions always
        encode Unicode parameters to bytes (newer ones consult
        ``dialect.supports_unicode_binds``), which would otherwise break the
        string formatting below.
        """
        # TODO verify against parser
        if isinstance(item, bytes):
            item = item.decode('utf-8')
        escaped = item.translate(self._LITERAL_ESCAPES)
        return "'{}'".format(escaped)


_escaper = HiveParamEscaper()
def connect(*args, **kwargs):
    """Open a connection to HiveServer2.

    All positional and keyword arguments are forwarded unchanged to
    :py:class:`Connection`; see that class for their meaning.

    :returns: a new :py:class:`Connection` object.
    """
    connection = Connection(*args, **kwargs)
    return connection
class Connection(object):
    """Wraps a Thrift session"""

    def __init__(
        self,
        host=None,
        port=None,
        scheme=None,
        username=None,
        database='default',
        auth=None,
        configuration=None,
        kerberos_service_name=None,
        password=None,
        check_hostname=None,
        ssl_cert=None,
        thrift_transport=None,
        ssl_context=None
    ):
        """Connect to HiveServer2

        :param host: What host HiveServer2 runs on
        :param port: What port HiveServer2 runs on. Defaults to 10000 for the binary
            transport. In HTTP(S) mode the code falls back to 1000.
            NOTE(review): 1000 looks like a typo (perhaps 10000/10001 was meant); confirm
            before changing it.
        :param scheme: ``'http'`` or ``'https'`` builds an HTTP Thrift transport to
            ``{scheme}://{host}:{port}/cliservice/``; any other value uses a plain socket.
        :param username: Session user; defaults to ``getpass.getuser()``.
        :param database: Database selected with ``USE`` right after the session opens.
        :param auth: The value of hive.server2.authentication used by HiveServer2.
            Defaults to ``NONE``.
        :param configuration: A dictionary of Hive settings (functionally same as the `set` command)
        :param kerberos_service_name: Use with auth='KERBEROS' only
        :param password: Use with auth='LDAP' or auth='CUSTOM' only
        :param check_hostname: HTTPS only; the string ``"true"`` enables hostname checking
            on the SSL context created here (any other value disables it).
        :param ssl_cert: HTTPS only; one of ``"none"``, ``"optional"``, ``"required"``
            (default ``"none"``), mapped via ``ssl_cert_parameter_map``.
        :param thrift_transport: A ``TTransportBase`` for custom advanced usage.
            Incompatible with host, port, auth, kerberos_service_name, and password.
        :param ssl_context: A custom SSL context to use for HTTPS connections. If provided,
            this overrides check_hostname and ssl_cert parameters.

        The way to support LDAP and GSSAPI is originated from cloudera/Impyla:
        https://github.com/cloudera/impyla/blob/255b07ed973d47a3395214ed92d35ec0615ebf62
        /impala/_thrift_api.py#L152-L160
        """
        if scheme in ("https", "http") and thrift_transport is None:
            port = port or 1000
            if scheme == "https":
                if ssl_context is None:
                    ssl_context = create_default_context()
                    ssl_context.check_hostname = check_hostname == "true"
                    ssl_cert = ssl_cert or "none"
                    ssl_context.verify_mode = ssl_cert_parameter_map.get(ssl_cert, CERT_NONE)
            thrift_transport = thrift.transport.THttpClient.THttpClient(
                uri_or_host="{scheme}://{host}:{port}/cliservice/".format(
                    scheme=scheme, host=host, port=port
                ),
                ssl_context=ssl_context,
            )

            if auth in ("BASIC", "NOSASL", "NONE", None):
                # Always needs the Authorization header
                self._set_authorization_header(thrift_transport, username, password)
            elif auth == "KERBEROS" and kerberos_service_name:
                self._set_kerberos_header(thrift_transport, kerberos_service_name, host)
            else:
                raise ValueError(
                    "Authentication is not valid use one of: "
                    "BASIC, NOSASL, KERBEROS, NONE"
                )
            # Authentication now lives in the HTTP headers; clear these so the
            # custom-transport validation below does not reject them.
            host, port, auth, kerberos_service_name, password = (
                None, None, None, None, None
            )

        username = username or getpass.getuser()
        configuration = configuration or {}

        if (password is not None) != (auth in ('LDAP', 'CUSTOM')):
            raise ValueError("Password should be set if and only if in LDAP or CUSTOM mode; "
                             "Remove password or use one of those modes")
        if (kerberos_service_name is not None) != (auth == 'KERBEROS'):
            raise ValueError("kerberos_service_name should be set if and only if in KERBEROS mode")
        if thrift_transport is not None:
            has_incompatible_arg = (
                host is not None
                or port is not None
                or auth is not None
                or kerberos_service_name is not None
                or password is not None
            )
            if has_incompatible_arg:
                raise ValueError("thrift_transport cannot be used with "
                                 "host/port/auth/kerberos_service_name/password")

        if thrift_transport is not None:
            self._transport = thrift_transport
        else:
            if port is None:
                port = 10000
            if auth is None:
                auth = 'NONE'
            socket = thrift.transport.TSocket.TSocket(host, port)
            if auth == 'NOSASL':
                # NOSASL corresponds to hive.server2.authentication=NOSASL in hive-site.xml
                self._transport = thrift.transport.TTransport.TBufferedTransport(socket)
            elif auth in ('LDAP', 'KERBEROS', 'NONE', 'CUSTOM'):
                # Defer import so package dependency is optional
                import thrift_sasl

                if auth == 'KERBEROS':
                    # KERBEROS mode in hive.server2.authentication is GSSAPI in sasl library
                    sasl_auth = 'GSSAPI'
                else:
                    sasl_auth = 'PLAIN'
                    if password is None:
                        # Password doesn't matter in NONE mode, just needs to be nonempty.
                        password = 'x'

                self._transport = thrift_sasl.TSaslClientTransport(
                    lambda: get_installed_sasl(host=host, sasl_auth=sasl_auth,
                                               service=kerberos_service_name,
                                               username=username, password=password),
                    sasl_auth, socket)
            else:
                # All HS2 config options:
                # https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2#SettingUpHiveServer2-Configuration
                # PAM currently left to end user via thrift_transport option.
                raise NotImplementedError(
                    "Only NONE, NOSASL, LDAP, KERBEROS, CUSTOM "
                    "authentication are supported, got {}".format(auth))

        protocol = thrift.protocol.TBinaryProtocol.TBinaryProtocol(self._transport)
        self._client = TCLIService.Client(protocol)
        # oldest version that still contains features we care about
        # "V6 uses binary type for binary payload (was string) and uses columnar result set"
        protocol_version = ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6

        try:
            self._transport.open()
            open_session_req = ttypes.TOpenSessionReq(
                client_protocol=protocol_version,
                configuration=configuration,
                username=username,
            )
            response = self._client.OpenSession(open_session_req)
            _check_status(response)
            assert response.sessionHandle is not None, "Expected a session from OpenSession"
            self._sessionHandle = response.sessionHandle
            assert response.serverProtocolVersion == protocol_version, \
                "Unable to handle protocol version {}".format(response.serverProtocolVersion)
            with contextlib.closing(self.cursor()) as cursor:
                cursor.execute('USE `{}`'.format(database))
        except BaseException:
            # Close the transport on any failure (including KeyboardInterrupt),
            # then propagate the original error.
            self._transport.close()
            raise

    @staticmethod
    def _set_authorization_header(transport, username=None, password=None):
        """Attach an HTTP Basic ``Authorization`` header to ``transport``.

        Missing credentials fall back to the placeholders ``"user"``/``"pass"``.
        """
        username = username or "user"
        password = password or "pass"
        auth_credentials = "{username}:{password}".format(
            username=username, password=password
        ).encode("UTF-8")
        auth_credentials_base64 = base64.standard_b64encode(auth_credentials).decode(
            "UTF-8"
        )
        transport.setCustomHeaders(
            {
                "Authorization": "Basic {auth_credentials_base64}".format(
                    auth_credentials_base64=auth_credentials_base64
                )
            }
        )

    @staticmethod
    def _set_kerberos_header(transport, kerberos_service_name, host):
        """Attach an HTTP ``Negotiate`` header built with the ``kerberos`` package."""
        import kerberos
        __, krb_context = kerberos.authGSSClientInit(
            service="{kerberos_service_name}@{host}".format(
                kerberos_service_name=kerberos_service_name, host=host
            )
        )
        kerberos.authGSSClientClean(krb_context, "")
        kerberos.authGSSClientStep(krb_context, "")
        auth_header = kerberos.authGSSClientResponse(krb_context)

        transport.setCustomHeaders(
            {
                "Authorization": "Negotiate {auth_header}".format(
                    auth_header=auth_header
                )
            }
        )

    def __enter__(self):
        """Transport should already be opened by __init__"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Call close"""
        self.close()

    def close(self):
        """Close the underlying session and Thrift transport.

        The transport is closed even if the CloseSession call itself raises, so
        the connection does not leak; a non-success status is reported afterwards.
        """
        req = ttypes.TCloseSessionReq(sessionHandle=self._sessionHandle)
        try:
            response = self._client.CloseSession(req)
        finally:
            self._transport.close()
        _check_status(response)

    def commit(self):
        """Hive does not support transactions, so this does nothing."""
        pass

    def cursor(self, *args, **kwargs):
        """Return a new :py:class:`Cursor` object using the connection."""
        return Cursor(self, *args, **kwargs)

    @property
    def client(self):
        return self._client

    @property
    def sessionHandle(self):
        return self._sessionHandle

    def rollback(self):
        raise NotSupportedError("Hive does not have transactions")  # pragma: no cover
class Cursor(common.DBAPICursor):
    """These objects represent a database cursor, which is used to manage the context of a fetch
    operation.

    Cursors are not isolated, i.e., any changes done to the database by a cursor are immediately
    visible by other cursors or connections.
    """

    def __init__(self, connection, arraysize=1000):
        # Set before the base initializer, which presumably calls _reset_state()
        # (that method reads _operationHandle) — confirm in common.DBAPICursor.
        self._operationHandle = None
        super(Cursor, self).__init__()
        self._arraysize = arraysize
        self._connection = connection

    def _reset_state(self):
        """Reset state about the previous query in preparation for running another query"""
        super(Cursor, self)._reset_state()
        self._description = None
        if self._operationHandle is not None:
            request = ttypes.TCloseOperationReq(self._operationHandle)
            try:
                response = self._connection.client.CloseOperation(request)
                _check_status(response)
            finally:
                # Forget the handle even if closing it on the server failed.
                self._operationHandle = None

    @property
    def arraysize(self):
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value):
        """Array size cannot be None, and should be an integer"""
        default_arraysize = 1000
        try:
            self._arraysize = int(value) or default_arraysize
        except TypeError:
            self._arraysize = default_arraysize

    @property
    def description(self):
        """This read-only attribute is a sequence of 7-item sequences.

        Each of these sequences contains information describing one result column:

        - name
        - type_code (``None`` when the server reports a type id unknown to this client)
        - display_size (None in current implementation)
        - internal_size (None in current implementation)
        - precision (None in current implementation)
        - scale (None in current implementation)
        - null_ok (always True in current implementation)

        This attribute will be ``None`` for operations that do not return rows or if the cursor has
        not had an operation invoked via the :py:meth:`execute` method yet.

        The ``type_code`` can be interpreted by comparing it to the Type Objects specified in the
        section below.
        """
        if self._operationHandle is None or not self._operationHandle.hasResultSet:
            return None
        if self._description is None:
            req = ttypes.TGetResultSetMetadataReq(self._operationHandle)
            response = self._connection.client.GetResultSetMetadata(req)
            _check_status(response)
            columns = response.schema.columns
            # Build locally and cache only once complete, so a failure part-way
            # through never leaves a truncated description cached.
            description = []
            for col in columns:
                primary_type_entry = col.typeDesc.types[0]
                if primary_type_entry.primitiveEntry is None:
                    # All fancy stuff maps to string
                    type_code = ttypes.TTypeId._VALUES_TO_NAMES[ttypes.TTypeId.STRING_TYPE]
                else:
                    type_id = primary_type_entry.primitiveEntry.type
                    # Type ids unknown to this client map to None instead of raising.
                    type_code = ttypes.TTypeId._VALUES_TO_NAMES.get(type_id)
                if sys.version_info[0] == 2:
                    column_name = col.columnName.decode('utf-8')
                    # type_code may be None (unknown type); None has no .decode().
                    if type_code is not None:
                        type_code = type_code.decode('utf-8')
                else:
                    column_name = col.columnName
                description.append((
                    column_name, type_code, None, None, None, None, True
                ))
            self._description = description
        return self._description

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the operation handle"""
        self._reset_state()

    def execute(self, operation, parameters=None, **kwargs):
        """Prepare and execute a database operation (query or command).

        Return values are not defined.
        """
        # backward compatibility with Python < 3.7
        for kw in ['async', 'async_']:
            if kw in kwargs:
                async_ = kwargs[kw]
                break
        else:
            async_ = False

        # Prepare statement
        if parameters is None:
            sql = operation
        else:
            sql = operation % _escaper.escape_args(parameters)

        self._reset_state()

        self._state = self._STATE_RUNNING
        _logger.info('%s', sql)

        req = ttypes.TExecuteStatementReq(self._connection.sessionHandle,
                                          sql, runAsync=async_)
        _logger.debug(req)
        response = self._connection.client.ExecuteStatement(req)
        _check_status(response)
        self._operationHandle = response.operationHandle

    def cancel(self):
        req = ttypes.TCancelOperationReq(
            operationHandle=self._operationHandle,
        )
        response = self._connection.client.CancelOperation(req)
        _check_status(response)

    def _fetch_more(self):
        """Send another TFetchResultsReq and update state"""
        assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
        assert(self._operationHandle is not None), "Should have an op handle in _fetch_more"
        if not self._operationHandle.hasResultSet:
            raise ProgrammingError("No result set")
        req = ttypes.TFetchResultsReq(
            operationHandle=self._operationHandle,
            orientation=ttypes.TFetchOrientation.FETCH_NEXT,
            maxRows=self.arraysize,
        )
        response = self._connection.client.FetchResults(req)
        _check_status(response)
        schema = self.description
        assert not response.results.rows, 'expected data in columnar format'
        columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
                   zip(response.results.columns, schema)]
        new_data = list(zip(*columns))
        self._data += new_data
        # response.hasMoreRows seems to always be False, so we instead check the number of rows
        # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
        # if not response.hasMoreRows:
        if not new_data:
            self._state = self._STATE_FINISHED

    def poll(self, get_progress_update=True):
        """Poll for and return the raw status data provided by the Hive Thrift REST API.

        :returns: ``ttypes.TGetOperationStatusResp``
        :raises: ``ProgrammingError`` when no query has been started

        .. note::
            This is not a part of DB-API.
        """
        if self._state == self._STATE_NONE:
            raise ProgrammingError("No query yet")

        req = ttypes.TGetOperationStatusReq(
            operationHandle=self._operationHandle,
            getProgressUpdate=get_progress_update,
        )
        response = self._connection.client.GetOperationStatus(req)
        _check_status(response)

        return response

    def fetch_logs(self):
        """Retrieve the logs produced by the execution of the query.

        Can be called multiple times to fetch the logs produced after the previous call.

        :returns: list<str>
        :raises: ``ProgrammingError`` when no query has been started

        .. note::
            This is not a part of DB-API.
        """
        if self._state == self._STATE_NONE:
            raise ProgrammingError("No query yet")

        try:  # Older Hive instances require logs to be retrieved using GetLog
            req = ttypes.TGetLogReq(operationHandle=self._operationHandle)
            logs = self._connection.client.GetLog(req).log.splitlines()
        except ttypes.TApplicationException as e:  # Otherwise, retrieve logs using newer method
            if e.type != ttypes.TApplicationException.UNKNOWN_METHOD:
                raise
            logs = []
            while True:
                req = ttypes.TFetchResultsReq(
                    operationHandle=self._operationHandle,
                    orientation=ttypes.TFetchOrientation.FETCH_NEXT,
                    maxRows=self.arraysize,
                    fetchType=1  # 0: results, 1: logs
                )
                response = self._connection.client.FetchResults(req)
                _check_status(response)
                assert not response.results.rows, 'expected data in columnar format'
                assert len(response.results.columns) == 1, response.results.columns
                new_logs = _unwrap_column(response.results.columns[0])
                logs += new_logs

                if not new_logs:
                    break

        return logs
#
# Type Objects and Constructors
#
# Expose one DBAPITypeObject per primitive Thrift type as a module attribute
# named after the type (e.g. the name looked up for STRING_TYPE).
for type_id in constants.PRIMITIVE_TYPES:
    name = ttypes.TTypeId._VALUES_TO_NAMES[type_id]
    globals()[name] = DBAPITypeObject([name])
#
# Private utilities
#
def _unwrap_column(col, type_=None):
    """Return a list of raw values from a TColumn instance.

    The first attribute of ``col`` that is not ``None`` is taken as the typed
    value container. Each set bit in its ``nulls`` bitmap (bit ``b`` of byte
    ``i`` -> row ``i * 8 + b``) replaces that row's value with ``None``; this
    modifies the container's ``values`` list in place. When ``type_`` has an
    entry in ``TYPES_CONVERTER``, every truthy value is passed through it.

    :param col: a Thrift ``TColumn``.
    :param type_: column type name from the cursor description, or ``None``.
    :returns: the list of column values.
    :raises DataError: if every attribute of ``col`` is ``None``.
    """
    for wrapper in col.__dict__.values():
        if wrapper is None:
            continue
        values = wrapper.values
        null_bitmap = wrapper.nulls
        assert isinstance(null_bitmap, bytes)
        for byte_index, raw_byte in enumerate(null_bitmap):
            bits = ord(raw_byte) if sys.version_info[0] == 2 else raw_byte
            row_offset = byte_index * 8
            for bit in range(8):
                if bits & (1 << bit):
                    values[row_offset + bit] = None
        convert = TYPES_CONVERTER.get(type_, None)
        if convert and type_:
            values = [convert(v) if v else v for v in values]
        return values
    raise DataError("Got empty column value {}".format(col))  # pragma: no cover
def _check_status(response):
    """Raise an OperationalError unless the Thrift response reports success.

    The whole response is logged at DEBUG level before the check.
    """
    _logger.debug(response)
    succeeded = response.status.statusCode == ttypes.TStatusCode.SUCCESS_STATUS
    if not succeeded:
        raise OperationalError(response)