# 🔍 Description ## Issue References 🔗 This pull request fixes #6485 ## Describe Your Solution 🔧 Match table names case-insensitively in the regular expression that detects a missing table. ## Types of changes 🔖 - [x] Bugfix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) ## Test Plan 🧪 Added unit tests for table names that contain capital letters. --- # Checklist 📝 - [x] This patch was not authored or co-authored using [Generative Tooling](https://www.apache.org/legal/generative-tooling.html) **Be nice. Be informative.** Closes #6605 from BruceWong96/fix-presto-regex. Closes #6485 06f737f24 [Bruce Wong] Fix typos 93071754a [Bruce Wong] Added unit tests for table names with both upper and lower case letters 9837030a1 [Bruce Wong] fix table not found Authored-by: Bruce Wong <603334301@qq.com> Signed-off-by: Cheng Pan <chengpan@apache.org>
224 lines
7.8 KiB
Python
224 lines
7.8 KiB
Python
"""Integration between SQLAlchemy and Presto.

Some code based on
https://github.com/zzzeek/sqlalchemy/blob/rel_0_5/lib/sqlalchemy/databases/sqlite.py
which is released under the MIT license.
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import unicode_literals
|
|
|
|
import re
|
|
import sqlalchemy
|
|
from sqlalchemy import exc
|
|
from sqlalchemy import types
|
|
from sqlalchemy import util
|
|
# TODO shouldn't use mysql type
|
|
from sqlalchemy.sql import text
|
|
# Presto's ``tinyint`` is reflected with the MySQL dialect's TINYINT type (see the
# ``'tinyint'`` entry of ``_type_map``). The ``sqlalchemy.databases`` compatibility
# package is tried first; when it cannot be imported, the dialect package is used.
try:
    from sqlalchemy.databases import mysql
except ImportError:
    # Required for SQLAlchemy>=2.0
    from sqlalchemy.dialects import mysql
    mysql_tinyinteger = mysql.base.MSTinyInteger
else:
    mysql_tinyinteger = mysql.MSTinyInteger
|
|
from sqlalchemy.engine import default
|
|
from sqlalchemy.sql import compiler
|
|
from sqlalchemy.sql.compiler import SQLCompiler
|
|
|
|
from pyhive import presto
|
|
from pyhive.common import UniversalSet
|
|
|
|
# "MAJOR.MINOR" of the installed SQLAlchemy as a float (e.g. "1.4.49" -> 1.4), used to
# choose how result rows are accessed during reflection. Only the leading MAJOR.MINOR
# is required, so a version string without a patch component (e.g. "2.0") no longer
# makes the match fail and crash at import time.
# NOTE(review): a float cannot order two-digit minor versions (1.10 parses as 1.1);
# it is kept as a float because existing code compares it with float literals.
sqlalchemy_version = float(re.match(r"(\d+\.\d+)", sqlalchemy.__version__).group(1))
|
|
|
|
class PrestoIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer that quotes every identifier.

    ``reserved_words`` is a ``UniversalSet`` so that every word is treated as reserved
    and therefore quoted, rather than maintaining a list of Presto keywords.
    """

    # Just quote everything to make things simpler / easier to upgrade
    reserved_words = UniversalSet()
|
|
|
|
|
|
# Maps Presto type names, as reported in the ``Type`` column of ``SHOW COLUMNS``, to
# SQLAlchemy types. Type names without an entry are reflected as ``NullType`` with a
# warning.
_type_map = {
    'boolean': types.Boolean,
    'tinyint': mysql_tinyinteger,
    'smallint': types.SmallInteger,
    'integer': types.Integer,
    'bigint': types.BigInteger,
    'real': types.Float,
    'double': types.Float,
    'decimal': types.DECIMAL,
    'char': types.CHAR,
    'varchar': types.String,
    'timestamp': types.TIMESTAMP,
    'date': types.DATE,
    'time': types.TIME,
    'varbinary': types.VARBINARY,
}
|
|
|
|
|
|
class PrestoCompiler(SQLCompiler):
    """SQL statement compiler with Presto-specific function renderings."""

    def visit_char_length_func(self, fn, **kw):
        # Render SQLAlchemy's ``char_length`` function as ``length(<args>)``;
        # presumably because Presto names this function ``length`` -- TODO confirm.
        rendered_args = self.function_argspec(fn, **kw)
        return 'length{}'.format(rendered_args)
|
|
|
|
|
|
class PrestoTypeCompiler(compiler.GenericTypeCompiler):
    """Render SQLAlchemy column types as Presto type names.

    FLOAT is rendered as DOUBLE and TEXT as VARCHAR; CLOB, NCLOB and DATETIME have
    no rendering and raise ``ValueError``.
    """

    def visit_FLOAT(self, type_, **kw):
        # Every FLOAT becomes DOUBLE; nothing on ``type_`` (e.g. precision) is consulted.
        return 'DOUBLE'

    def visit_TEXT(self, type_, **kw):
        # A falsy length (None or 0) produces an unbounded VARCHAR.
        if not type_.length:
            return 'VARCHAR'
        return 'VARCHAR({:d})'.format(type_.length)

    def visit_CLOB(self, type_, **kw):
        unsupported = "Presto does not support the CLOB column type."
        raise ValueError(unsupported)

    def visit_NCLOB(self, type_, **kw):
        unsupported = "Presto does not support the NCLOB column type."
        raise ValueError(unsupported)

    def visit_DATETIME(self, type_, **kw):
        unsupported = "Presto does not support the DATETIME column type."
        raise ValueError(unsupported)
|
|
|
|
|
|
class PrestoDialect(default.DefaultDialect):
    """SQLAlchemy dialect for Presto, using the ``pyhive.presto`` DB-API module.

    Reflection is built on ``SHOW SCHEMAS``, ``SHOW TABLES`` and ``SHOW COLUMNS``.
    No primary or foreign keys are reported, partition columns are exposed as a single
    pseudo-index named ``partition``, and rollback is a no-op.
    """

    name = 'presto'
    driver = 'rest'
    paramstyle = 'pyformat'
    preparer = PrestoIdentifierPreparer
    statement_compiler = PrestoCompiler
    supports_alter = False
    supports_pk_autoincrement = False
    supports_default_values = False
    supports_empty_insert = False
    supports_multivalues_insert = True
    supports_unicode_statements = True
    supports_unicode_binds = True
    supports_statement_cache = False
    returns_unicode_strings = True
    description_encoding = None
    supports_native_boolean = True
    type_compiler = PrestoTypeCompiler

    @classmethod
    def dbapi(cls):
        """Return the DB-API module (``pyhive.presto``)."""
        return presto

    @classmethod
    def import_dbapi(cls):
        """Return the DB-API module (``pyhive.presto``).

        NOTE(review): presumably the hook name newer SQLAlchemy versions call instead
        of ``dbapi``; both are kept so either version finds the module.
        """
        return presto

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into DB-API ``connect`` arguments.

        The URL's database part is ``catalog`` or ``catalog/schema`` (the catalog
        defaults to ``hive``), and the port defaults to 8080. URL query parameters are
        merged after host/port/credentials and may override them; catalog and schema
        from the database part are applied last.

        Returns:
            A ``([], kwargs)`` pair of positional and keyword arguments.

        Raises:
            ValueError: if the database part has more than two ``/``-separated parts.
        """
        db_parts = (url.database or 'hive').split('/')
        kwargs = {
            'host': url.host,
            'port': url.port or 8080,
            'username': url.username,
            'password': url.password
        }
        kwargs.update(url.query)
        if len(db_parts) == 1:
            kwargs['catalog'] = db_parts[0]
        elif len(db_parts) == 2:
            kwargs['catalog'] = db_parts[0]
            kwargs['schema'] = db_parts[1]
        else:
            raise ValueError("Unexpected database format {}".format(url.database))
        return [], kwargs

    def get_schema_names(self, connection, **kw):
        """Return the ``Schema`` column of ``SHOW SCHEMAS``."""
        return [row.Schema for row in connection.execute(text('SHOW SCHEMAS'))]

    def _get_table_columns(self, connection, table_name, schema):
        """Run ``SHOW COLUMNS`` for the (optionally schema-qualified) table.

        Both names are quoted with the dialect's identifier preparer.

        Raises:
            sqlalchemy.exc.NoSuchTableError: if the error message says the table does
                not exist (matched case-insensitively). Any other database error is
                re-raised unchanged.
        """
        full_table = self.identifier_preparer.quote_identifier(table_name)
        if schema:
            full_table = self.identifier_preparer.quote_identifier(schema) + '.' + full_table
        try:
            return connection.execute(text('SHOW COLUMNS FROM {}'.format(full_table)))
        except (presto.DatabaseError, exc.DatabaseError) as e:
            # Normally SQLAlchemy should wrap this exception in sqlalchemy.exc.DatabaseError, which
            # it successfully does in the Hive version. The difference with Presto is that this
            # error is raised when fetching the cursor's description rather than the initial execute
            # call. SQLAlchemy doesn't handle this. Thus, we catch the unwrapped
            # presto.DatabaseError here.
            # Does the table exist?
            msg = None
            if e.args:
                detail = e.args[0]
                if isinstance(detail, dict):
                    # Dict payload (presumably the server's error JSON); text is under 'message'.
                    msg = detail.get('message')
                elif isinstance(detail, str):
                    msg = detail
            # The reported name may differ in case from the requested one, so the
            # match ignores case.
            regex = r"Table\ \'.*{}\'\ does\ not\ exist".format(re.escape(table_name))
            if msg and re.search(regex, msg, re.IGNORECASE):
                raise exc.NoSuchTableError(table_name)
            raise

    def has_table(self, connection, table_name, schema=None, **kw):
        """Return whether the table exists by probing it with ``SHOW COLUMNS``.

        Only a "table does not exist" error yields ``False``; other database errors
        propagate.
        """
        try:
            self._get_table_columns(connection, table_name, schema)
            return True
        except exc.NoSuchTableError:
            return False

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Reflect the table's columns from ``SHOW COLUMNS``.

        Parameterized type names such as ``varchar(255)``, ``decimal(10,2)`` or
        ``timestamp(3) with time zone`` are looked up in ``_type_map`` by their leading
        word. Unknown types produce a warning and ``NullType``.

        Returns:
            A list of dicts with ``name``, ``type``, ``nullable`` and ``default`` keys.
        """
        rows = self._get_table_columns(connection, table_name, schema)
        result = []
        for row in rows:
            # Strip type parameters, e.g. 'varchar(255)' -> 'varchar'.
            match = re.match(r'\w+', row.Type or '')
            base_type = match.group(0) if match else row.Type
            try:
                coltype = _type_map[base_type]
            except KeyError:
                util.warn("Did not recognize type '%s' of column '%s'" % (row.Type, row.Column))
                coltype = types.NullType
            result.append({
                'name': row.Column,
                'type': coltype,
                # newer Presto no longer includes this column
                'nullable': getattr(row, 'Null', True),
                'default': None,
            })
        return result

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        """Return no foreign keys; this dialect does not reflect any for Presto tables."""
        return []

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        """Report that the table has no primary key.

        SQLAlchemy expects a dict with ``constrained_columns`` and ``name`` keys here,
        not a list; an empty column list means "no primary key".
        """
        return {'constrained_columns': [], 'name': None}

    def get_indexes(self, connection, table_name, schema=None, **kw):
        """Expose the table's partition columns as one pseudo-index named ``partition``.

        Returns:
            ``[{'name': 'partition', 'column_names': [...], 'unique': False}]`` for a
            partitioned table, otherwise ``[]``.
        """
        rows = self._get_table_columns(connection, table_name, schema)
        part_key = 'Partition Key'
        # From SQLAlchemy 1.4, key-based access and ``in`` tests go through ``row._mapping``.
        use_mapping = sqlalchemy_version >= 1.4
        col_names = []
        for row in rows:
            if use_mapping:
                row = row._mapping
            # Presto puts this information in one of 3 places depending on version
            # - a boolean column named "Partition Key"
            # - a string in the "Comment" column
            # - a string in the "Extra" column
            # Comment/Extra may be NULL, so a missing value is treated as ''.
            is_partition_key = (
                (part_key in row and row[part_key])
                or (row['Comment'] or '').startswith(part_key)
                or ('Extra' in row and 'partition key' in (row['Extra'] or ''))
            )
            if is_partition_key:
                col_names.append(row['Column'])
        if col_names:
            return [{'name': 'partition', 'column_names': col_names, 'unique': False}]
        return []

    def get_table_names(self, connection, schema=None, **kw):
        """Return the ``Table`` column of ``SHOW TABLES`` (scoped to ``schema`` if given)."""
        query = 'SHOW TABLES'
        if schema:
            query += ' FROM ' + self.identifier_preparer.quote_identifier(schema)
        return [row.Table for row in connection.execute(text(query))]

    def do_rollback(self, dbapi_connection):
        # No transactions for Presto
        pass

    def _check_unicode_returns(self, connection, additional_tests=None):
        # requests gives back Unicode strings
        return True

    def _check_unicode_description(self, connection):
        # requests gives back Unicode strings
        return True
|