**Why are the changes needed:** Currently looking to connect to a HiveServer2 behind an NGINX proxy that is requiring mTLS communication. pyHive seems to lack the capability to establish an mTLS connection in applications such as Airflow directly communicating to the HiveServer2 instance. The change needed is to be able to pass in the parameters for a proper mTLS ssl context to be established. I believe that creating your own ssl_context object is the quickest and cleanest way to do so, leaving the responsibility of configuring it to further implementations and users. Also cuts down on code length. **How was this patch tested:** Corresponding pytest fixtures have been added, using the mock module to see if ssl_context object was properly accessed, or if the default one created in the Connection initialization was properly configured. Was not able to run pytest fixtures specifically, was lacking JDBC driver, first time contributing to open source, happy to run tests if provided guidance. Passed a clean build and test of the entire kyuubi project in local dev environment. **Was this patch authored or co-authored using generative AI tooling** Yes, Generated-by Cursor-AI with Claude Sonnet 3.5 agent Closes #6935 from alexio215/connection-class-ssl-context-param. Closes #6908 539b29962 [Cheng Pan] Update python/pyhive/tests/test_hive.py 14c607489 [Alex Wojtowicz] Simplified testing, following pattern of other tests, need proper SSL setup with nginx to test ssl_context fully b947f2454 [Alex Wojtowicz] Added exception handling since JDBC driver will not run in python tests 11f9002bf [Alex Wojtowicz] Passing in fully configured mock object before creating connection 009c5cf24 [Alex Wojtowicz] Added back doc string documentation e3280bcd8 [Alex Wojtowicz] Python testing 529de8a12 [Alex Wojtowicz] Added ssl_context object. 
If no obj is provided, then it continues to use default provided parameters Lead-authored-by: Alex Wojtowicz <awojtowi@akamai.com> Co-authored-by: Cheng Pan <pan3793@gmail.com> Signed-off-by: Cheng Pan <chengpan@apache.org>
268 lines
11 KiB
Python
268 lines
11 KiB
Python
"""Hive integration tests.
|
|
|
|
These rely on having a Hive+Hadoop cluster set up with HiveServer2 running.
|
|
They also require a tables created by make_test_tables.sh.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import unicode_literals
|
|
|
|
import contextlib
|
|
import datetime
|
|
import os
|
|
import socket
|
|
import subprocess
|
|
import time
|
|
import unittest
|
|
from decimal import Decimal
|
|
import ssl
|
|
|
|
import mock
|
|
import pytest
|
|
import thrift.transport.TSocket
|
|
import thrift.transport.TTransport
|
|
import thrift_sasl
|
|
from thrift.transport.TTransport import TTransportException
|
|
|
|
from TCLIService import ttypes
|
|
from pyhive import hive
|
|
from pyhive.tests.dbapi_test_case import DBAPITestCase
|
|
from pyhive.tests.dbapi_test_case import with_cursor
|
|
|
|
_HOST = 'localhost'
|
|
|
|
|
|
class TestHive(unittest.TestCase, DBAPITestCase):
    """Integration tests exercising the Hive DB-API driver against a live HiveServer2.

    Inherits generic DB-API conformance cases from ``DBAPITestCase``; the tests
    here cover Hive-specific behavior (complex types, async execution, auth
    modes, custom transports, SSL contexts).
    """

    __test__ = True

    def connect(self):
        """Open a connection to the local test HiveServer2 (used by ``with_cursor``)."""
        return hive.connect(host=_HOST, port=10000, configuration={'mapred.job.tracker': 'local'})

    @with_cursor
    def test_description(self, cursor):
        """``cursor.description`` exposes column metadata after a SELECT."""
        cursor.execute('SELECT * FROM one_row')

        desc = [('one_row.number_of_rows', 'INT_TYPE', None, None, None, None, True)]
        self.assertEqual(cursor.description, desc)

    @with_cursor
    def test_complex(self, cursor):
        """All Hive complex/primitive types round-trip with the expected Python types."""
        cursor.execute('SELECT * FROM one_row_complex')
        self.assertEqual(cursor.description, [
            ('one_row_complex.boolean', 'BOOLEAN_TYPE', None, None, None, None, True),
            ('one_row_complex.tinyint', 'TINYINT_TYPE', None, None, None, None, True),
            ('one_row_complex.smallint', 'SMALLINT_TYPE', None, None, None, None, True),
            ('one_row_complex.int', 'INT_TYPE', None, None, None, None, True),
            ('one_row_complex.bigint', 'BIGINT_TYPE', None, None, None, None, True),
            ('one_row_complex.float', 'FLOAT_TYPE', None, None, None, None, True),
            ('one_row_complex.double', 'DOUBLE_TYPE', None, None, None, None, True),
            ('one_row_complex.string', 'STRING_TYPE', None, None, None, None, True),
            ('one_row_complex.timestamp', 'TIMESTAMP_TYPE', None, None, None, None, True),
            ('one_row_complex.binary', 'BINARY_TYPE', None, None, None, None, True),
            ('one_row_complex.array', 'ARRAY_TYPE', None, None, None, None, True),
            ('one_row_complex.map', 'MAP_TYPE', None, None, None, None, True),
            ('one_row_complex.struct', 'STRUCT_TYPE', None, None, None, None, True),
            ('one_row_complex.union', 'UNION_TYPE', None, None, None, None, True),
            ('one_row_complex.decimal', 'DECIMAL_TYPE', None, None, None, None, True),
        ])
        rows = cursor.fetchall()
        expected = [(
            True,
            127,
            32767,
            2147483647,
            9223372036854775807,
            0.5,
            0.25,
            'a string',
            datetime.datetime(1970, 1, 1, 0, 0),
            b'123',
            '[1,2]',
            '{1:2,3:4}',
            '{"a":1,"b":2}',
            '{0:1}',
            Decimal('0.1'),
        )]
        self.assertEqual(rows, expected)
        # catch unicode/str
        self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0])))

    @with_cursor
    def test_async(self, cursor):
        """``async_=True`` execution can be polled to completion and then fetched."""
        cursor.execute('SELECT * FROM one_row', async_=True)
        unfinished_states = (
            ttypes.TOperationState.INITIALIZED_STATE,
            ttypes.TOperationState.RUNNING_STATE,
        )
        # Drain logs while the query is still running to also exercise fetch_logs().
        while cursor.poll().operationState in unfinished_states:
            cursor.fetch_logs()
        assert cursor.poll().operationState == ttypes.TOperationState.FINISHED_STATE

        self.assertEqual(len(cursor.fetchall()), 1)

    @with_cursor
    def test_cancel(self, cursor):
        """cancel() moves a running async operation into CANCELED_STATE."""
        # Need to do a JOIN to force a MR job. Without it, Hive optimizes the query to a fetch
        # operator and prematurely declares the query done.
        cursor.execute(
            "SELECT reflect('java.lang.Thread', 'sleep', 1000L * 1000L * 1000L) "
            "FROM one_row a JOIN one_row b",
            async_=True
        )
        self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.RUNNING_STATE)
        cursor.cancel()
        self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.CANCELED_STATE)

    def test_noops(self):
        """The DB-API specification requires that certain actions exist, even though they might not
        be applicable."""
        # Wohoo inflating coverage stats!
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                self.assertEqual(cursor.rowcount, -1)
                cursor.setinputsizes([])
                cursor.setoutputsize(1, 'blah')
                connection.commit()

    @mock.patch('TCLIService.TCLIService.Client.OpenSession')
    def test_open_failed(self, open_session):
        """An unsupported server protocol version surfaces as OperationalError."""
        open_session.return_value.serverProtocolVersion = \
            ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1
        self.assertRaises(hive.OperationalError, self.connect)

    def test_escape(self):
        """Parameter escaping survives a string full of special characters."""
        # Hive thrift translates newlines into multiple rows. WTF.
        bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t '''
        self.run_escape_case(bad_str)

    @pytest.mark.skip(reason="Currently failing")
    def test_newlines(self):
        """Verify that newlines are passed through correctly"""
        cursor = self.connect().cursor()
        orig = ' \r\n \r \n '
        cursor.execute(
            'SELECT %s FROM one_row',
            (orig,)
        )
        result = cursor.fetchall()
        self.assertEqual(result, [(orig,)])

    @with_cursor
    def test_no_result_set(self, cursor):
        """Statements that produce no result set leave description None and fetch raising."""
        cursor.execute('USE default')
        self.assertIsNone(cursor.description)
        self.assertRaises(hive.ProgrammingError, cursor.fetchone)

    @pytest.mark.skip(reason="Need a proper setup for ldap")
    def test_ldap_connection(self):
        """LDAP auth accepts a valid password and rejects an invalid one.

        Temporarily swaps in an LDAP-enabled hive-site.xml and restarts HS2;
        the ``finally`` block restores the original config either way.
        """
        rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        orig_ldap = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site-ldap.xml')
        orig_none = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site.xml')
        des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml')
        try:
            subprocess.check_call(['sudo', 'cp', orig_ldap, des])
            _restart_hs2()
            with contextlib.closing(hive.connect(
                    host=_HOST, username='existing', auth='LDAP', password='testpw')
            ) as connection:
                with contextlib.closing(connection.cursor()) as cursor:
                    cursor.execute('SELECT * FROM one_row')
                    self.assertEqual(cursor.fetchall(), [(1,)])

            # assertRaisesRegex: assertRaisesRegexp is a deprecated alias removed in Python 3.12.
            self.assertRaisesRegex(
                TTransportException, 'Error validating the login',
                lambda: hive.connect(
                    host=_HOST, username='existing', auth='LDAP', password='wrong')
            )

        finally:
            subprocess.check_call(['sudo', 'cp', orig_none, des])
            _restart_hs2()

    def test_invalid_ldap_config(self):
        """password should be set if and only if using LDAP"""
        self.assertRaisesRegex(ValueError, 'Password.*LDAP',
                               lambda: hive.connect(_HOST, password=''))
        self.assertRaisesRegex(ValueError, 'Password.*LDAP',
                               lambda: hive.connect(_HOST, auth='LDAP'))

    def test_invalid_kerberos_config(self):
        """kerberos_service_name should be set if and only if using KERBEROS"""
        self.assertRaisesRegex(ValueError, 'kerberos_service_name.*KERBEROS',
                               lambda: hive.connect(_HOST, kerberos_service_name=''))
        self.assertRaisesRegex(ValueError, 'kerberos_service_name.*KERBEROS',
                               lambda: hive.connect(_HOST, auth='KERBEROS'))

    def test_invalid_transport(self):
        """transport and auth are incompatible"""
        # Renamed from ``socket`` so the local does not shadow the imported socket module.
        sock = thrift.transport.TSocket.TSocket('localhost', 10000)
        transport = thrift.transport.TTransport.TBufferedTransport(sock)
        self.assertRaisesRegex(
            ValueError, 'thrift_transport cannot be used with',
            lambda: hive.connect(_HOST, thrift_transport=transport)
        )

    def test_custom_transport(self):
        """A caller-supplied SASL transport can be passed straight through to connect()."""
        # Renamed from ``socket`` so the local does not shadow the imported socket module.
        sock = thrift.transport.TSocket.TSocket('localhost', 10000)
        sasl_auth = 'PLAIN'

        transport = thrift_sasl.TSaslClientTransport(
            lambda: hive.get_installed_sasl(
                host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'),
            sasl_auth, sock)
        conn = hive.connect(thrift_transport=transport)
        with contextlib.closing(conn):
            with contextlib.closing(conn.cursor()) as cursor:
                cursor.execute('SELECT * FROM one_row')
                self.assertEqual(cursor.fetchall(), [(1,)])

    @pytest.mark.skip(reason="Need a proper setup for custom auth")
    def test_custom_connection(self):
        """CUSTOM auth accepts a valid password and rejects an invalid one.

        Same config-swap pattern as ``test_ldap_connection``; restores the
        original hive-site.xml in the ``finally`` block.
        """
        rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        orig_ldap = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site-custom.xml')
        orig_none = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site.xml')
        des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml')
        try:
            subprocess.check_call(['sudo', 'cp', orig_ldap, des])
            _restart_hs2()
            with contextlib.closing(hive.connect(
                    host=_HOST, username='the-user', auth='CUSTOM', password='p4ssw0rd')
            ) as connection:
                with contextlib.closing(connection.cursor()) as cursor:
                    cursor.execute('SELECT * FROM one_row')
                    self.assertEqual(cursor.fetchall(), [(1,)])

            # assertRaisesRegex: assertRaisesRegexp is a deprecated alias removed in Python 3.12.
            self.assertRaisesRegex(
                TTransportException, 'Error validating the login',
                lambda: hive.connect(
                    host=_HOST, username='the-user', auth='CUSTOM', password='wrong')
            )

        finally:
            subprocess.check_call(['sudo', 'cp', orig_none, des])
            _restart_hs2()

    @pytest.mark.skip(reason="Need a proper setup for SSL context testing")
    def test_basic_ssl_context(self):
        """Test that connection works with a custom SSL context that mimics the default behavior."""
        # Create an SSL context similar to what Connection creates by default
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

        # Connect using the same parameters as self.connect() but with our custom context
        with contextlib.closing(hive.connect(
                host=_HOST,
                port=10000,
                configuration={'mapred.job.tracker': 'local'},
                ssl_context=ssl_context
        )) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                # Use the same query pattern as other tests
                cursor.execute('SELECT 1 FROM one_row')
                self.assertEqual(cursor.fetchall(), [(1,)])
|
|
|
|
|
|
def _restart_hs2():
    """Restart the local HiveServer2 service and block until its port accepts connections."""
    subprocess.check_call(['sudo', 'service', 'hive-server2', 'restart'])
    # Poll port 10000 once per second until the server is reachable again.
    with contextlib.closing(socket.socket()) as probe:
        while probe.connect_ex(('localhost', 10000)) != 0:
            time.sleep(1)