kyuubi/python/pyhive/tests/test_hive.py
Harry 06af125b9f
[KYUUBI #6281][PY] Enable hive test in python client
# 🔍 Description
## Issue References 🔗

This pull request enables running Hive test cases in the Python client; however, there is one trivial case not covered yet and two others that require a proper container setup

## Types of changes 🔖

- [ ] Bugfix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)

## Test Plan 🧪

#### Behavior Without This Pull Request ⚰️
Hive test disabled in #6343

#### Behavior With This Pull Request 🎉
Can cover hive test cases

#### Related Unit Tests
No

---

# Checklist 📝

- [x] This patch was not authored or co-authored using [Generative Tooling](https://www.apache.org/legal/generative-tooling.html)

**Be nice. Be informative.**

Closes #6381 from sudohainguyen/ci/hive.

Closes #6281

a861382b1 [Harry] [KYUUBI #6281][PY] Enable hive test in python client

Authored-by: Harry <quanghai.ng1512@gmail.com>
Signed-off-by: Cheng Pan <chengpan@apache.org>
2024-05-15 14:55:44 +08:00

247 lines
10 KiB
Python

"""Hive integration tests.
These rely on having a Hive+Hadoop cluster set up with HiveServer2 running.
They also require the tables created by make_test_tables.sh.
"""
from __future__ import absolute_import
from __future__ import unicode_literals
import contextlib
import datetime
import os
import socket
import subprocess
import time
import unittest
from decimal import Decimal
import mock
import pytest
import thrift.transport.TSocket
import thrift.transport.TTransport
import thrift_sasl
from thrift.transport.TTransport import TTransportException
from TCLIService import ttypes
from pyhive import hive
from pyhive.tests.dbapi_test_case import DBAPITestCase
from pyhive.tests.dbapi_test_case import with_cursor
_HOST = 'localhost'
class TestHive(unittest.TestCase, DBAPITestCase):
    """Integration tests for the DB-API Hive client.

    These tests require a running HiveServer2 on ``localhost:10000`` and the
    tables created by make_test_tables.sh. They extend ``DBAPITestCase``,
    which supplies generic DB-API 2.0 conformance tests via ``self.connect``.
    """
    __test__ = True

    def connect(self):
        """Open a connection to the local HiveServer2, forcing local MR jobs."""
        return hive.connect(host=_HOST, port=10000, configuration={'mapred.job.tracker': 'local'})

    @with_cursor
    def test_description(self, cursor):
        """cursor.description exposes (name, type, ..., nullable) per column."""
        cursor.execute('SELECT * FROM one_row')
        desc = [('one_row.number_of_rows', 'INT_TYPE', None, None, None, None, True)]
        self.assertEqual(cursor.description, desc)

    @with_cursor
    def test_complex(self, cursor):
        """All supported Hive column types round-trip with the expected Python types."""
        cursor.execute('SELECT * FROM one_row_complex')
        self.assertEqual(cursor.description, [
            ('one_row_complex.boolean', 'BOOLEAN_TYPE', None, None, None, None, True),
            ('one_row_complex.tinyint', 'TINYINT_TYPE', None, None, None, None, True),
            ('one_row_complex.smallint', 'SMALLINT_TYPE', None, None, None, None, True),
            ('one_row_complex.int', 'INT_TYPE', None, None, None, None, True),
            ('one_row_complex.bigint', 'BIGINT_TYPE', None, None, None, None, True),
            ('one_row_complex.float', 'FLOAT_TYPE', None, None, None, None, True),
            ('one_row_complex.double', 'DOUBLE_TYPE', None, None, None, None, True),
            ('one_row_complex.string', 'STRING_TYPE', None, None, None, None, True),
            ('one_row_complex.timestamp', 'TIMESTAMP_TYPE', None, None, None, None, True),
            ('one_row_complex.binary', 'BINARY_TYPE', None, None, None, None, True),
            ('one_row_complex.array', 'ARRAY_TYPE', None, None, None, None, True),
            ('one_row_complex.map', 'MAP_TYPE', None, None, None, None, True),
            ('one_row_complex.struct', 'STRUCT_TYPE', None, None, None, None, True),
            ('one_row_complex.union', 'UNION_TYPE', None, None, None, None, True),
            ('one_row_complex.decimal', 'DECIMAL_TYPE', None, None, None, None, True),
        ])
        rows = cursor.fetchall()
        expected = [(
            True,
            127,
            32767,
            2147483647,
            9223372036854775807,
            0.5,
            0.25,
            'a string',
            datetime.datetime(1970, 1, 1, 0, 0),
            b'123',
            '[1,2]',
            '{1:2,3:4}',
            '{"a":1,"b":2}',
            '{0:1}',
            Decimal('0.1'),
        )]
        self.assertEqual(rows, expected)
        # Compare element types too, to catch unicode vs bytes regressions.
        self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0])))

    @with_cursor
    def test_async(self, cursor):
        """An async_ execute can be polled to completion and then fetched."""
        cursor.execute('SELECT * FROM one_row', async_=True)
        unfinished_states = (
            ttypes.TOperationState.INITIALIZED_STATE,
            ttypes.TOperationState.RUNNING_STATE,
        )
        # Spin until the server reports a terminal state; fetch_logs keeps the
        # operation's log stream drained while we wait.
        while cursor.poll().operationState in unfinished_states:
            cursor.fetch_logs()
        assert cursor.poll().operationState == ttypes.TOperationState.FINISHED_STATE
        self.assertEqual(len(cursor.fetchall()), 1)

    @with_cursor
    def test_cancel(self, cursor):
        """cancel() moves a running operation into CANCELED_STATE."""
        # Need to do a JOIN to force a MR job. Without it, Hive optimizes the query to a fetch
        # operator and prematurely declares the query done.
        cursor.execute(
            "SELECT reflect('java.lang.Thread', 'sleep', 1000L * 1000L * 1000L) "
            "FROM one_row a JOIN one_row b",
            async_=True
        )
        self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.RUNNING_STATE)
        cursor.cancel()
        self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.CANCELED_STATE)

    def test_noops(self):
        """The DB-API specification requires that certain actions exist, even though they might not
        be applicable."""
        # These are no-ops for Hive but must not raise.
        with contextlib.closing(self.connect()) as connection:
            with contextlib.closing(connection.cursor()) as cursor:
                self.assertEqual(cursor.rowcount, -1)
                cursor.setinputsizes([])
                cursor.setoutputsize(1, 'blah')
                connection.commit()

    @mock.patch('TCLIService.TCLIService.Client.OpenSession')
    def test_open_failed(self, open_session):
        """Connecting to a server speaking an unsupported protocol version fails."""
        open_session.return_value.serverProtocolVersion = \
            ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1
        self.assertRaises(hive.OperationalError, self.connect)

    def test_escape(self):
        """Parameter escaping survives shell-ish metacharacters."""
        # Hive thrift translates newlines into multiple rows. WTF.
        bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t '''
        self.run_escape_case(bad_str)

    @pytest.mark.skip(reason="Currently failing")
    def test_newlines(self):
        """Verify that newlines are passed through correctly"""
        cursor = self.connect().cursor()
        orig = ' \r\n \r \n '
        cursor.execute(
            'SELECT %s FROM one_row',
            (orig,)
        )
        result = cursor.fetchall()
        self.assertEqual(result, [(orig,)])

    @with_cursor
    def test_no_result_set(self, cursor):
        """Statements with no result set leave description None and fetch raises."""
        cursor.execute('USE default')
        self.assertIsNone(cursor.description)
        self.assertRaises(hive.ProgrammingError, cursor.fetchone)

    @pytest.mark.skip(reason="Need a proper setup for ldap")
    def test_ldap_connection(self):
        """LDAP auth succeeds with a valid password and fails with a wrong one.

        Swaps in an LDAP hive-site.xml, restarts HS2, and restores the
        original config in the finally block regardless of outcome.
        """
        rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        orig_ldap = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site-ldap.xml')
        orig_none = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site.xml')
        des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml')
        try:
            subprocess.check_call(['sudo', 'cp', orig_ldap, des])
            _restart_hs2()
            with contextlib.closing(hive.connect(
                    host=_HOST, username='existing', auth='LDAP', password='testpw')
            ) as connection:
                with contextlib.closing(connection.cursor()) as cursor:
                    cursor.execute('SELECT * FROM one_row')
                    self.assertEqual(cursor.fetchall(), [(1,)])
            # assertRaisesRegexp was removed in Python 3.12; use assertRaisesRegex.
            self.assertRaisesRegex(
                TTransportException, 'Error validating the login',
                lambda: hive.connect(
                    host=_HOST, username='existing', auth='LDAP', password='wrong')
            )
        finally:
            subprocess.check_call(['sudo', 'cp', orig_none, des])
            _restart_hs2()

    def test_invalid_ldap_config(self):
        """password should be set if and only if using LDAP"""
        self.assertRaisesRegex(ValueError, 'Password.*LDAP',
                               lambda: hive.connect(_HOST, password=''))
        self.assertRaisesRegex(ValueError, 'Password.*LDAP',
                               lambda: hive.connect(_HOST, auth='LDAP'))

    def test_invalid_kerberos_config(self):
        """kerberos_service_name should be set if and only if using KERBEROS"""
        self.assertRaisesRegex(ValueError, 'kerberos_service_name.*KERBEROS',
                               lambda: hive.connect(_HOST, kerberos_service_name=''))
        self.assertRaisesRegex(ValueError, 'kerberos_service_name.*KERBEROS',
                               lambda: hive.connect(_HOST, auth='KERBEROS'))

    def test_invalid_transport(self):
        """transport and auth are incompatible"""
        # 'sock' rather than 'socket' so we don't shadow the imported module.
        sock = thrift.transport.TSocket.TSocket('localhost', 10000)
        transport = thrift.transport.TTransport.TBufferedTransport(sock)
        self.assertRaisesRegex(
            ValueError, 'thrift_transport cannot be used with',
            lambda: hive.connect(_HOST, thrift_transport=transport)
        )

    def test_custom_transport(self):
        """A caller-supplied SASL thrift transport works end to end."""
        sock = thrift.transport.TSocket.TSocket('localhost', 10000)
        sasl_auth = 'PLAIN'
        transport = thrift_sasl.TSaslClientTransport(
            lambda: hive.get_installed_sasl(
                host='localhost', sasl_auth=sasl_auth,
                username='test_username', password='x'),
            sasl_auth, sock)
        conn = hive.connect(thrift_transport=transport)
        with contextlib.closing(conn):
            with contextlib.closing(conn.cursor()) as cursor:
                cursor.execute('SELECT * FROM one_row')
                self.assertEqual(cursor.fetchall(), [(1,)])

    @pytest.mark.skip(reason="Need a proper setup for custom auth")
    def test_custom_connection(self):
        """CUSTOM auth succeeds with a valid password and fails with a wrong one.

        Same config-swap/restore pattern as test_ldap_connection.
        """
        rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        orig_ldap = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site-custom.xml')
        orig_none = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site.xml')
        des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml')
        try:
            subprocess.check_call(['sudo', 'cp', orig_ldap, des])
            _restart_hs2()
            with contextlib.closing(hive.connect(
                    host=_HOST, username='the-user', auth='CUSTOM', password='p4ssw0rd')
            ) as connection:
                with contextlib.closing(connection.cursor()) as cursor:
                    cursor.execute('SELECT * FROM one_row')
                    self.assertEqual(cursor.fetchall(), [(1,)])
            self.assertRaisesRegex(
                TTransportException, 'Error validating the login',
                lambda: hive.connect(
                    host=_HOST, username='the-user', auth='CUSTOM', password='wrong')
            )
        finally:
            subprocess.check_call(['sudo', 'cp', orig_none, des])
            _restart_hs2()
def _restart_hs2():
    """Restart HiveServer2 and block until it accepts TCP connections again."""
    subprocess.check_call(['sudo', 'service', 'hive-server2', 'restart'])
    # Poll the thrift port once per second; connect_ex returns 0 on success.
    probe = socket.socket()
    with contextlib.closing(probe):
        while probe.connect_ex(('localhost', 10000)) != 0:
            time.sleep(1)