"""Presto integration tests. These rely on having a Presto+Hadoop cluster set up. They also require a tables created by make_test_tables.sh. """ from __future__ import absolute_import from __future__ import unicode_literals import contextlib import os from decimal import Decimal import requests import pytest from pyhive import exc from pyhive import presto from pyhive.tests.dbapi_test_case import DBAPITestCase from pyhive.tests.dbapi_test_case import with_cursor import mock import unittest import datetime _HOST = 'localhost' _PORT = '8080' class TestPresto(unittest.TestCase, DBAPITestCase): __test__ = True def connect(self): return presto.connect(host=_HOST, port=_PORT, source=self.id()) def test_bad_protocol(self): self.assertRaisesRegexp(ValueError, 'Protocol must be', lambda: presto.connect('localhost', protocol='nonsense').cursor()) def test_escape_args(self): escaper = presto.PrestoParamEscaper() self.assertEqual(escaper.escape_args((datetime.date(2020, 4, 17),)), ("date '2020-04-17'",)) self.assertEqual(escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)), ("timestamp '2020-04-17 12:00:00.123'",)) @with_cursor def test_description(self, cursor): cursor.execute('SELECT 1 AS foobar FROM one_row') self.assertEqual(cursor.description, [('foobar', 'integer', None, None, None, None, True)]) self.assertIsNotNone(cursor.last_query_id) @with_cursor def test_complex(self, cursor): cursor.execute('SELECT * FROM one_row_complex') # TODO Presto drops the union field if os.environ.get('PRESTO') == '0.147': tinyint_type = 'integer' smallint_type = 'integer' float_type = 'double' else: # some later version made these map to more specific types tinyint_type = 'tinyint' smallint_type = 'smallint' float_type = 'real' self.assertEqual(cursor.description, [ ('boolean', 'boolean', None, None, None, None, True), ('tinyint', tinyint_type, None, None, None, None, True), ('smallint', smallint_type, None, None, None, None, True), ('int', 'integer', None, None, None, None, True), ('bigint', 'bigint', None, None, None, None, True), ('float', float_type, None, None, None, None, True), ('double', 'double', None, None, None, None, True), ('string', 'varchar', None, None, None, None, True), ('timestamp', 'timestamp', None, None, None, None, True), ('binary', 'varbinary', None, None, None, None, True), ('array', 'array(integer)', None, None, None, None, True), ('map', 'map(integer,integer)', None, None, None, None, True), ('struct', 'row(a integer,b integer)', None, None, None, None, True), # ('union', 'varchar', None, None, None, None, True), ('decimal', 'decimal(10,1)', None, None, None, None, True), ]) rows = cursor.fetchall() expected = [( True, 127, 32767, 2147483647, 9223372036854775807, 0.5, 0.25, 'a string', '1970-01-01 00:00:00.000', b'123', [1, 2], {"1": 2, "3": 4}, # Presto converts all keys to strings so that they're valid JSON [1, 2], # struct is returned as a list of elements # '{0:1}', Decimal('0.1'), )] self.assertEqual(rows, expected) # catch unicode/str self.assertEqual(list(map(type, rows[0])), list(map(type, expected[0]))) @with_cursor def test_cancel(self, cursor): cursor.execute( "SELECT a.a * rand(), b.a * rand()" "FROM many_rows a " "CROSS JOIN many_rows b " ) self.assertIn(cursor.poll()['stats']['state'], ( 'STARTING', 'PLANNING', 'RUNNING', 'WAITING_FOR_RESOURCES', 'QUEUED')) cursor.cancel() self.assertIsNotNone(cursor.last_query_id) self.assertIsNone(cursor.poll()) def test_noops(self): """The DB-API specification requires that certain actions exist, even though they might not be applicable.""" # Wohoo inflating coverage stats! connection = self.connect() cursor = connection.cursor() self.assertEqual(cursor.rowcount, -1) cursor.setinputsizes([]) cursor.setoutputsize(1, 'blah') self.assertIsNone(cursor.last_query_id) connection.commit() @mock.patch('requests.post') def test_non_200(self, post): cursor = self.connect().cursor() post.return_value.status_code = 404 self.assertRaises(exc.OperationalError, lambda: cursor.execute('show tables')) @with_cursor def test_poll(self, cursor): self.assertRaises(presto.ProgrammingError, cursor.poll) cursor.execute('SELECT * FROM one_row') while True: status = cursor.poll() if status is None: break self.assertIn('stats', status) def fail(*args, **kwargs): self.fail("Should not need requests.get after done polling") # pragma: no cover with mock.patch('requests.get', fail): self.assertEqual(cursor.fetchall(), [(1,)]) @with_cursor def test_set_session(self, cursor): id = None self.assertIsNone(cursor.last_query_id) cursor.execute("SET SESSION query_max_run_time = '1234m'") self.assertIsNotNone(cursor.last_query_id) id = cursor.last_query_id cursor.fetchall() self.assertEqual(id, cursor.last_query_id) cursor.execute('SHOW SESSION') self.assertIsNotNone(cursor.last_query_id) self.assertNotEqual(id, cursor.last_query_id) id = cursor.last_query_id rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] self.assertEqual(len(rows), 1) session_prop = rows[0] self.assertEqual(session_prop[1], '1234m') self.assertEqual(id, cursor.last_query_id) cursor.execute('RESET SESSION query_max_run_time') self.assertIsNotNone(cursor.last_query_id) self.assertNotEqual(id, cursor.last_query_id) id = cursor.last_query_id cursor.fetchall() self.assertEqual(id, cursor.last_query_id) cursor.execute('SHOW SESSION') self.assertIsNotNone(cursor.last_query_id) self.assertNotEqual(id, cursor.last_query_id) id = cursor.last_query_id rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] self.assertEqual(len(rows), 1) session_prop = rows[0] self.assertNotEqual(session_prop[1], '1234m') self.assertEqual(id, cursor.last_query_id) def test_set_session_in_constructor(self): conn = presto.connect( host=_HOST, source=self.id(), session_props={'query_max_run_time': '1234m'} ) with contextlib.closing(conn): with contextlib.closing(conn.cursor()) as cursor: cursor.execute('SHOW SESSION') rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] assert len(rows) == 1 session_prop = rows[0] assert session_prop[1] == '1234m' cursor.execute('RESET SESSION query_max_run_time') cursor.fetchall() cursor.execute('SHOW SESSION') rows = [r for r in cursor.fetchall() if r[0] == 'query_max_run_time'] assert len(rows) == 1 session_prop = rows[0] assert session_prop[1] != '1234m' def test_invalid_protocol_config(self): """protocol should be https when passing password""" self.assertRaisesRegexp( ValueError, 'Protocol.*https.*password', lambda: presto.connect( host=_HOST, username='user', password='secret', protocol='http').cursor() ) def test_invalid_password_and_kwargs(self): """password and requests_kwargs authentication are incompatible""" self.assertRaisesRegexp( ValueError, 'Cannot use both', lambda: presto.connect( host=_HOST, username='user', password='secret', protocol='https', requests_kwargs={'auth': requests.auth.HTTPBasicAuth('user', 'secret')} ).cursor() ) def test_invalid_kwargs(self): """some kwargs are reserved""" self.assertRaisesRegexp( ValueError, 'Cannot override', lambda: presto.connect( host=_HOST, username='user', requests_kwargs={'url': 'test'} ).cursor() ) @pytest.mark.skip(reason='This test requires a proxy server running on localhost:9999') def test_requests_kwargs(self): connection = presto.connect( host=_HOST, port=_PORT, source=self.id(), requests_kwargs={'proxies': {'http': 'localhost:9999'}}, ) cursor = connection.cursor() self.assertRaises(requests.exceptions.ProxyError, lambda: cursor.execute('SELECT * FROM one_row')) def test_requests_session(self): with requests.Session() as session: connection = presto.connect( host=_HOST, port=_PORT, source=self.id(), requests_session=session ) cursor = connection.cursor() cursor.execute('SELECT * FROM one_row') self.assertEqual(cursor.fetchall(), [(1,)])