"""Hive integration tests. These rely on having a Hive+Hadoop cluster set up with HiveServer2 running. They also require a tables created by make_test_tables.sh. """ from __future__ import absolute_import from __future__ import unicode_literals import contextlib import datetime import os import socket import subprocess import time import unittest from decimal import Decimal import ssl import mock import pytest import thrift.transport.TSocket import thrift.transport.TTransport import thrift_sasl from thrift.transport.TTransport import TTransportException from TCLIService import ttypes from pyhive import hive from pyhive.tests.dbapi_test_case import DBAPITestCase from pyhive.tests.dbapi_test_case import with_cursor _HOST = 'localhost' class TestHive(unittest.TestCase, DBAPITestCase): __test__ = True def connect(self): return hive.connect(host=_HOST, port=10000, configuration={'mapred.job.tracker': 'local'}) @with_cursor def test_description(self, cursor): cursor.execute('SELECT * FROM one_row') desc = [('one_row.number_of_rows', 'INT_TYPE', None, None, None, None, True)] self.assertEqual(cursor.description, desc) @with_cursor def test_complex(self, cursor): cursor.execute('SELECT * FROM one_row_complex') self.assertEqual(cursor.description, [ ('one_row_complex.boolean', 'BOOLEAN_TYPE', None, None, None, None, True), ('one_row_complex.tinyint', 'TINYINT_TYPE', None, None, None, None, True), ('one_row_complex.smallint', 'SMALLINT_TYPE', None, None, None, None, True), ('one_row_complex.int', 'INT_TYPE', None, None, None, None, True), ('one_row_complex.bigint', 'BIGINT_TYPE', None, None, None, None, True), ('one_row_complex.float', 'FLOAT_TYPE', None, None, None, None, True), ('one_row_complex.double', 'DOUBLE_TYPE', None, None, None, None, True), ('one_row_complex.string', 'STRING_TYPE', None, None, None, None, True), ('one_row_complex.timestamp', 'TIMESTAMP_TYPE', None, None, None, None, True), ('one_row_complex.binary', 'BINARY_TYPE', None, None, None, None, True), ('one_row_complex.array', 'ARRAY_TYPE', None, None, None, None, True), ('one_row_complex.map', 'MAP_TYPE', None, None, None, None, True), ('one_row_complex.struct', 'STRUCT_TYPE', None, None, None, None, True), ('one_row_complex.union', 'UNION_TYPE', None, None, None, None, True), ('one_row_complex.decimal', 'DECIMAL_TYPE', None, None, None, None, True), ]) rows = cursor.fetchall() expected = [( True, 127, 32767, 2147483647, 9223372036854775807, 0.5, 0.25, 'a string', datetime.datetime(1970, 1, 1, 0, 0), b'123', '[1,2]', '{1:2,3:4}', '{"a":1,"b":2}', '{0:1}', Decimal('0.1'), )] self.assertEqual(rows, expected) # catch unicode/str self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) @with_cursor def test_async(self, cursor): cursor.execute('SELECT * FROM one_row', async_=True) unfinished_states = ( ttypes.TOperationState.INITIALIZED_STATE, ttypes.TOperationState.RUNNING_STATE, ) while cursor.poll().operationState in unfinished_states: cursor.fetch_logs() assert cursor.poll().operationState == ttypes.TOperationState.FINISHED_STATE self.assertEqual(len(cursor.fetchall()), 1) @with_cursor def test_cancel(self, cursor): # Need to do a JOIN to force a MR job. Without it, Hive optimizes the query to a fetch # operator and prematurely declares the query done. cursor.execute( "SELECT reflect('java.lang.Thread', 'sleep', 1000L * 1000L * 1000L) " "FROM one_row a JOIN one_row b", async_=True ) self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.RUNNING_STATE) cursor.cancel() self.assertEqual(cursor.poll().operationState, ttypes.TOperationState.CANCELED_STATE) def test_noops(self): """The DB-API specification requires that certain actions exist, even though they might not be applicable.""" # Wohoo inflating coverage stats! with contextlib.closing(self.connect()) as connection: with contextlib.closing(connection.cursor()) as cursor: self.assertEqual(cursor.rowcount, -1) cursor.setinputsizes([]) cursor.setoutputsize(1, 'blah') connection.commit() @mock.patch('TCLIService.TCLIService.Client.OpenSession') def test_open_failed(self, open_session): open_session.return_value.serverProtocolVersion = \ ttypes.TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V1 self.assertRaises(hive.OperationalError, self.connect) def test_escape(self): # Hive thrift translates newlines into multiple rows. WTF. bad_str = '''`~!@#$%^&*()_+-={}[]|\\;:'",./<>?\t ''' self.run_escape_case(bad_str) @pytest.mark.skip(reason="Currently failing") def test_newlines(self): """Verify that newlines are passed through correctly""" cursor = self.connect().cursor() orig = ' \r\n \r \n ' cursor.execute( 'SELECT %s FROM one_row', (orig,) ) result = cursor.fetchall() self.assertEqual(result, [(orig,)]) @with_cursor def test_no_result_set(self, cursor): cursor.execute('USE default') self.assertIsNone(cursor.description) self.assertRaises(hive.ProgrammingError, cursor.fetchone) @pytest.mark.skip(reason="Need a proper setup for ldap") def test_ldap_connection(self): rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) orig_ldap = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site-ldap.xml') orig_none = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site.xml') des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml') try: subprocess.check_call(['sudo', 'cp', orig_ldap, des]) _restart_hs2() with contextlib.closing(hive.connect( host=_HOST, username='existing', auth='LDAP', password='testpw') ) as connection: with contextlib.closing(connection.cursor()) as cursor: cursor.execute('SELECT * FROM one_row') self.assertEqual(cursor.fetchall(), [(1,)]) self.assertRaisesRegexp( TTransportException, 'Error validating the login', lambda: hive.connect( host=_HOST, username='existing', auth='LDAP', password='wrong') ) finally: subprocess.check_call(['sudo', 'cp', orig_none, des]) _restart_hs2() def test_invalid_ldap_config(self): """password should be set if and only if using LDAP""" self.assertRaisesRegexp(ValueError, 'Password.*LDAP', lambda: hive.connect(_HOST, password='')) self.assertRaisesRegexp(ValueError, 'Password.*LDAP', lambda: hive.connect(_HOST, auth='LDAP')) def test_invalid_kerberos_config(self): """kerberos_service_name should be set if and only if using KERBEROS""" self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS', lambda: hive.connect(_HOST, kerberos_service_name='')) self.assertRaisesRegexp(ValueError, 'kerberos_service_name.*KERBEROS', lambda: hive.connect(_HOST, auth='KERBEROS')) def test_invalid_transport(self): """transport and auth are incompatible""" socket = thrift.transport.TSocket.TSocket('localhost', 10000) transport = thrift.transport.TTransport.TBufferedTransport(socket) self.assertRaisesRegexp( ValueError, 'thrift_transport cannot be used with', lambda: hive.connect(_HOST, thrift_transport=transport) ) def test_custom_transport(self): socket = thrift.transport.TSocket.TSocket('localhost', 10000) sasl_auth = 'PLAIN' transport = thrift_sasl.TSaslClientTransport(lambda: hive.get_installed_sasl(host='localhost', sasl_auth=sasl_auth, username='test_username', password='x'), sasl_auth, socket) conn = hive.connect(thrift_transport=transport) with contextlib.closing(conn): with contextlib.closing(conn.cursor()) as cursor: cursor.execute('SELECT * FROM one_row') self.assertEqual(cursor.fetchall(), [(1,)]) @pytest.mark.skip(reason="Need a proper setup for custom auth") def test_custom_connection(self): rootdir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) orig_ldap = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site-custom.xml') orig_none = os.path.join(rootdir, 'scripts', 'conf', 'hive', 'hive-site.xml') des = os.path.join('/', 'etc', 'hive', 'conf', 'hive-site.xml') try: subprocess.check_call(['sudo', 'cp', orig_ldap, des]) _restart_hs2() with contextlib.closing(hive.connect( host=_HOST, username='the-user', auth='CUSTOM', password='p4ssw0rd') ) as connection: with contextlib.closing(connection.cursor()) as cursor: cursor.execute('SELECT * FROM one_row') self.assertEqual(cursor.fetchall(), [(1,)]) self.assertRaisesRegexp( TTransportException, 'Error validating the login', lambda: hive.connect( host=_HOST, username='the-user', auth='CUSTOM', password='wrong') ) finally: subprocess.check_call(['sudo', 'cp', orig_none, des]) _restart_hs2() @pytest.mark.skip(reason="Need a proper setup for SSL context testing") def test_basic_ssl_context(self): """Test that connection works with a custom SSL context that mimics the default behavior.""" # Create an SSL context similar to what Connection creates by default ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE # Connect using the same parameters as self.connect() but with our custom context with contextlib.closing(hive.connect( host=_HOST, port=10000, configuration={'mapred.job.tracker': 'local'}, ssl_context=ssl_context )) as connection: with contextlib.closing(connection.cursor()) as cursor: # Use the same query pattern as other tests cursor.execute('SELECT 1 FROM one_row') self.assertEqual(cursor.fetchall(), [(1,)]) def _restart_hs2(): subprocess.check_call(['sudo', 'service', 'hive-server2', 'restart']) with contextlib.closing(socket.socket()) as s: while s.connect_ex(('localhost', 10000)) != 0: time.sleep(1)