# 🔍 Description ## Issue References 🔗 This pull request enables running hive test cases in python client, however there's one trivial case not covered yet and two others require a proper container setup ## Types of changes 🔖 - [ ] Bugfix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Test Plan 🧪 #### Behavior Without This Pull Request ⚰️ Hive test disabled in #6343 #### Behavior With This Pull Request 🎉 Can cover hive test cases #### Related Unit Tests No --- # Checklist 📝 - [x] This patch was not authored or co-authored using [Generative Tooling](https://www.apache.org/legal/generative-tooling.html) **Be nice. Be informative.** Closes #6381 from sudohainguyen/ci/hive. Closes #6281 a861382b1 [Harry] [KYUUBI #6281][PY] Enable hive test in python client Authored-by: Harry <quanghai.ng1512@gmail.com> Signed-off-by: Cheng Pan <chengpan@apache.org>
247 lines
10 KiB
Python
247 lines
10 KiB
Python
"""Hive integration tests.
|
|
|
|
These rely on having a Hive+Hadoop cluster set up with HiveServer2 running.
|
|
They also require a tables created by make_test_tables.sh.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import unicode_literals
|
|
|
|
import contextlib
|
|
import datetime
|
|
import os
|
|
import socket
|
|
import subprocess
|
|
import time
|
|
import unittest
|
|
from decimal import Decimal
|
|
|
|
import mock
|
|
import pytest
|
|
import thrift.transport.TSocket
|
|
import thrift.transport.TTransport
|
|
import thrift_sasl
|
|
from thrift.transport.TTransport import TTransportException
|
|
|
|
from TCLIService import ttypes
|
|
from pyhive import hive
|
|
from pyhive.tests.dbapi_test_case import DBAPITestCase
|
|
from pyhive.tests.dbapi_test_case import with_cursor
|
|
|
|
_HOST = 'localhost'
|
|
|
|
|
|
class TestHive(unittest.TestCase, DBAPITestCase):
|
|
__test__ = True
|
|
|
|
def connect(self):
|
|
return hive.connect(host=_HOST, port=10000, configuration={'mapred.job.tracker': 'local'})
|
|
|
|
@with_cursor
|
|
def test_description(self, cursor):
|
|
cursor.execute('SELECT * FROM one_row')
|
|
|
|
desc = [('one_row.number_of_rows', 'INT_TYPE', None, None, None, None, True)]
|
|
self.assertEqual(cursor.description, desc)
|
|
|
|
@with_cursor
|
|
def test_complex(self, cursor):
|
|
cursor.execute('SELECT * FROM one_row_complex')
|
|
self.assertEqual(cursor.description, [
|
|
('one_row_complex.boolean', 'BOOLEAN_TYPE', None, None, None, None, True),
|
|
('one_row_complex.tinyint', 'TINYINT_TYPE', None, None, None, None, True),
|
|
('one_row_complex.smallint', 'SMALLINT_TYPE', None, None, None, None, True),
|
|
('one_row_complex.int', 'INT_TYPE', None, None, None, None, True),
|
|
('one_row_complex.bigint', 'BIGINT_TYPE', None, None, None, None, True),
|
|
('one_row_complex.float', 'FLOAT_TYPE', None, None, None, None, True),
|
|
('one_row_complex.double', 'DOUBLE_TYPE', None, None, None, None, True),
|
|
('one_row_complex.string', 'STRING_TYPE', None, None, None, None, True),
|
|
('one_row_complex.timestamp', 'TIMESTAMP_TYPE', None, None, None, None, True),
|
|
('one_row_complex.binary', 'BINARY_TYPE', None, None, None, None, True),
|
|
('one_row_complex.array', 'ARRAY_TYPE', None, None, None, None, True),
|
|
('one_row_complex.map', 'MAP_TYPE', None, None, None, None, True),
|
|
('one_row_complex.struct', 'STRUCT_TYPE', None, None, None, None, True),
|
|
('one_row_complex.union', 'UNION_TYPE', None, None, None, None, True),
|
|
('one_row_complex.decimal', 'DECIMAL_TYPE', None, None, None, None, True),
|
|
])
|
|
rows = cursor.fetchall()
|
|
expected = [(
|
|
True,
|
|
127,
|
|
32767,
|
|
2147483647,
|
|
9223372036854775807,
|
|
0.5,
|
|
0.25,
|
|
'a string',
|
|
datetime.datetime(1970, 1, 1, 0, 0),
|
|
b'123',
|
|
'[1,2]',
|
|
'{1:2,3:4}',
|
|
'{"a":1,"b":2}',
|
|
'{0:1}',
|
|
Decimal('0.1'),
|
|
)]
|
|
self.assertEqual(rows, expected)
|
|
# catch unicode/str
|
|
self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0])))
|
|
|
|
@with_cursor
|
|
def test_async(self, cursor):
|
|
cursor.execute('SELECT * FROM one_row', async_=True)
|
|
unfinished_states = (
|
|
ttypes.TOperationState.INITIALIZED_STATE,
|
|
ttypes.TOperationState.RUNNING_STATE,
|
|
)
|
|
while cursor.poll().operationState in unfinished_states:
|
|
cursor.fetch_logs()
|
|
assert cursor.poll().operationState == ttypes.TOperationState.FINISHED_STATE
|
|
|
|
self.assertEqual(len(cursor.fetchall()), 1)
|
|
|
|
@with_cursor
|
|
def test_cancel(self, cursor):
|
|
# Need to do a JOIN to force a MR job. Without it, Hive optimizes the query to a fetch
|
|
# operator and prematurely declares the query done.
|
|
cursor.execute(
|
|
"SELECT reflect('java.lang.Thread', 'sleep', 1000L * 1000L * 1000L) "
|
|
"FROM one_row a JOIN one_row b",
|
|
async_=True
|
|
)
|
|
self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.RUNNING_STATE)
|
|
cursor.cancel()
|
|
self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.CANCELED_STATE)
|
|
|
|
def test_noops(self):
|
|
"""The DB-API specification requires that certain actions exist, even though they might not
|
|
be applicable."""
|
|
# Wohoo inflating coverage stats!
|
|
with contextlib.closing(self.connect()) as connection:
|
|
with contextlib.closing(connection.cursor()) as cursor:
|
|
self.assertEqual(cursor.rowcount, -1)
|
|
cursor.setinputsizes([])
|
|
cursor.setoutputsize(1, 'blah')
|
|
connection.commit()
|
|
|
|
@mock.patch('TCLIService.TCLIService.Client.OpenSession')
|
|
def test_open_failed(self, open_session):
|
|
open_session.return_value.serverProtocolVersion = \
|
|
ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1
|
|
self.assertRaises(hive.OperationalError, self.connect)
|
|
|
|
def test_escape(self):
|
|
# Hive thrift translates newlines into multiple rows. WTF.
|
|
bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t '''
|
|
self.run_escape_case(bad_str)
|
|
|
|
@pytest.mark.skip(reason="Currently failing")
|
|
def test_newlines(self):
|
|
"""Verify that newlines are passed through correctly"""
|
|
cursor = self.connect().cursor()
|
|
orig = ' \r\n \r \n '
|
|
cursor.execute(
|
|
'SELECT %s FROM one_row',
|
|
(orig,)
|
|
)
|
|
result = cursor.fetchall()
|
|
self.assertEqual(result, [(orig,)])
|
|
|
|
@with_cursor
|
|
def test_no_result_set(self, cursor):
|
|
cursor.execute('USE default')
|
|
self.assertIsNone(cursor.description)
|
|
self.assertRaises(hive.ProgrammingError, cursor.fetchone)
|
|
|
|
@pytest.mark.skip(reason="Need a proper setup for ldap")
|
|
def test_ldap_connection(self):
|
|
rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
|
orig_ldap = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site-ldap.xml')
|
|
orig_none = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site.xml')
|
|
des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml')
|
|
try:
|
|
subprocess.check_call(['sudo', 'cp', orig_ldap, des])
|
|
_restart_hs2()
|
|
with contextlib.closing(hive.connect(
|
|
host=_HOST, username='existing', auth='LDAP', password='testpw')
|
|
) as connection:
|
|
with contextlib.closing(connection.cursor()) as cursor:
|
|
cursor.execute('SELECT * FROM one_row')
|
|
self.assertEqual(cursor.fetchall(), [(1,)])
|
|
|
|
self.assertRaisesRegexp(
|
|
TTransportException, 'Error validating the login',
|
|
lambda: hive.connect(
|
|
host=_HOST, username='existing', auth='LDAP', password='wrong')
|
|
)
|
|
|
|
finally:
|
|
subprocess.check_call(['sudo', 'cp', orig_none, des])
|
|
_restart_hs2()
|
|
|
|
def test_invalid_ldap_config(self):
|
|
"""password should be set if and only if using LDAP"""
|
|
self.assertRaisesRegexp(ValueError, 'Password.*LDAP',
|
|
lambda: hive.connect(_HOST, password=''))
|
|
self.assertRaisesRegexp(ValueError, 'Password.*LDAP',
|
|
lambda: hive.connect(_HOST, auth='LDAP'))
|
|
|
|
def test_invalid_kerberos_config(self):
|
|
"""kerberos_service_name should be set if and only if using KERBEROS"""
|
|
self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS',
|
|
lambda: hive.connect(_HOST, kerberos_service_name=''))
|
|
self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS',
|
|
lambda: hive.connect(_HOST, auth='KERBEROS'))
|
|
|
|
def test_invalid_transport(self):
|
|
"""transport and auth are incompatible"""
|
|
socket = thrift.transport.TSocket.TSocket('localhost', 10000)
|
|
transport = thrift.transport.TTransport.TBufferedTransport(socket)
|
|
self.assertRaisesRegexp(
|
|
ValueError, 'thrift_transport cannot be used with',
|
|
lambda: hive.connect(_HOST, thrift_transport=transport)
|
|
)
|
|
|
|
def test_custom_transport(self):
|
|
socket = thrift.transport.TSocket.TSocket('localhost', 10000)
|
|
sasl_auth = 'PLAIN'
|
|
|
|
transport = thrift_sasl.TSaslClientTransport(lambda: hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket)
|
|
conn = hive.connect(thrift_transport=transport)
|
|
with contextlib.closing(conn):
|
|
with contextlib.closing(conn.cursor()) as cursor:
|
|
cursor.execute('SELECT * FROM one_row')
|
|
self.assertEqual(cursor.fetchall(), [(1,)])
|
|
|
|
@pytest.mark.skip(reason="Need a proper setup for custom auth")
|
|
def test_custom_connection(self):
|
|
rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
|
orig_ldap = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site-custom.xml')
|
|
orig_none = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site.xml')
|
|
des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml')
|
|
try:
|
|
subprocess.check_call(['sudo', 'cp', orig_ldap, des])
|
|
_restart_hs2()
|
|
with contextlib.closing(hive.connect(
|
|
host=_HOST, username='the-user', auth='CUSTOM', password='p4ssw0rd')
|
|
) as connection:
|
|
with contextlib.closing(connection.cursor()) as cursor:
|
|
cursor.execute('SELECT * FROM one_row')
|
|
self.assertEqual(cursor.fetchall(), [(1,)])
|
|
|
|
self.assertRaisesRegexp(
|
|
TTransportException, 'Error validating the login',
|
|
lambda: hive.connect(
|
|
host=_HOST, username='the-user', auth='CUSTOM', password='wrong')
|
|
)
|
|
|
|
finally:
|
|
subprocess.check_call(['sudo', 'cp', orig_none, des])
|
|
_restart_hs2()
|
|
|
|
|
|
def _restart_hs2():
|
|
subprocess.check_call(['sudo', 'service', 'hive-server2', 'restart'])
|
|
with contextlib.closing(socket.socket()) as s:
|
|
while s.connect_ex(('localhost', 10000)) != 0:
|
|
time.sleep(1)
|