From db08a699489c9b0259579d7ff7fd6bf3496ca3a2 Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Thu, 14 Oct 2021 21:45:57 +0200 Subject: [PATCH] rearchitect reflection for batched performance Rearchitected the schema reflection API to allow some dialects to make use of high performing batch queries to reflect the schemas of many tables at once using much fewer queries. The new performance features are targeted first at the PostgreSQL and Oracle backends, and may be applied to any dialect that makes use of SELECT queries against system catalog tables to reflect tables (currently this omits the MySQL and SQLite dialects which instead make use of parsing the "CREATE TABLE" statement, however these dialects do not have a pre-existing performance issue with reflection. MS SQL Server is still a TODO). The new API is backwards compatible with the previous system, and should require no changes to third party dialects to retain compatibility; third party dialects can also opt into the new system by implementing batched queries for schema reflection. Along with this change is an updated reflection API that is fully :pep:`484` typed, features many new methods and some changes. Fixes: #4379 Change-Id: I897ec09843543aa7012bcdce758792ed3d415d08 --- README.unittests.rst | 5 + doc/build/changelog/unreleased_20/4379.rst | 69 + .../changelog/unreleased_20/oracle_views.rst | 11 + .../unreleased_20/postgresql_version.rst | 5 + doc/build/tutorial/metadata.rst | 2 +- lib/sqlalchemy/dialects/mssql/base.py | 172 +- lib/sqlalchemy/dialects/mysql/base.py | 47 +- lib/sqlalchemy/dialects/oracle/base.py | 2240 ++++++++------ lib/sqlalchemy/dialects/oracle/cx_oracle.py | 4 +- lib/sqlalchemy/dialects/oracle/dictionary.py | 495 ++++ lib/sqlalchemy/dialects/oracle/provision.py | 54 +- lib/sqlalchemy/dialects/oracle/types.py | 233 ++ .../dialects/postgresql/__init__.py | 31 +- .../dialects/postgresql/_psycopg_common.py | 18 + lib/sqlalchemy/dialects/postgresql/asyncpg.py | 5 + lib/sqlalchemy/dialects/postgresql/base.py | 2624 ++++++++--------- lib/sqlalchemy/dialects/postgresql/pg8000.py | 7 + .../dialects/postgresql/pg_catalog.py | 292 ++ lib/sqlalchemy/dialects/postgresql/types.py | 485 +++ lib/sqlalchemy/dialects/sqlite/base.py | 161 +- lib/sqlalchemy/engine/__init__.py | 2 + lib/sqlalchemy/engine/default.py | 131 +- lib/sqlalchemy/engine/interfaces.py | 417 ++- lib/sqlalchemy/engine/reflection.py | 1475 +++++++-- lib/sqlalchemy/sql/base.py | 2 +- lib/sqlalchemy/sql/cache_key.py | 38 + lib/sqlalchemy/sql/schema.py | 42 +- lib/sqlalchemy/testing/assertions.py | 30 +- lib/sqlalchemy/testing/plugin/pytestplugin.py | 16 +- lib/sqlalchemy/testing/provision.py | 70 +- lib/sqlalchemy/testing/requirements.py | 46 + lib/sqlalchemy/testing/schema.py | 15 +- .../testing/suite/test_reflection.py | 1544 ++++++++-- lib/sqlalchemy/testing/suite/test_sequence.py | 33 +- lib/sqlalchemy/testing/util.py | 39 +- lib/sqlalchemy/util/topological.py | 10 +- lib/sqlalchemy/util/typing.py | 2 + pyproject.toml | 13 - setup.cfg | 7 +- test/dialect/mysql/test_reflection.py | 18 +- test/dialect/oracle/test_reflection.py | 554 +++- test/dialect/postgresql/test_dialect.py | 20 +- test/dialect/postgresql/test_reflection.py | 441 ++- test/dialect/postgresql/test_types.py | 35 +- test/dialect/test_sqlite.py | 31 +- test/engine/test_reflection.py | 57 +- test/perf/many_table_reflection.py | 617 ++++ test/requirements.py | 80 +- 48 files changed, 9522 insertions(+), 3223 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/4379.rst create 
mode 100644 doc/build/changelog/unreleased_20/oracle_views.rst create mode 100644 doc/build/changelog/unreleased_20/postgresql_version.rst create mode 100644 lib/sqlalchemy/dialects/oracle/dictionary.py create mode 100644 lib/sqlalchemy/dialects/oracle/types.py create mode 100644 lib/sqlalchemy/dialects/postgresql/pg_catalog.py create mode 100644 lib/sqlalchemy/dialects/postgresql/types.py create mode 100644 test/perf/many_table_reflection.py diff --git a/README.unittests.rst b/README.unittests.rst index 0e59d8e923..23f7292ccd 100644 --- a/README.unittests.rst +++ b/README.unittests.rst @@ -310,8 +310,13 @@ be used with pytest by using ``--db docker_mssql``. >> sqlplus system/tiger@//localhost/XEPDB1 < 0) - - super(NUMBER, self).__init__( - precision=precision, scale=scale, asdecimal=asdecimal - ) - - def adapt(self, impltype): - ret = super(NUMBER, self).adapt(impltype) - # leave a hint for the DBAPI handler - ret._is_oracle_number = True - return ret - - @property - def _type_affinity(self): - if bool(self.scale and self.scale > 0): - return sqltypes.Numeric - else: - return sqltypes.Integer - - -class FLOAT(sqltypes.FLOAT): - """Oracle FLOAT. - - This is the same as :class:`_sqltypes.FLOAT` except that - an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision` - parameter is accepted, and - the :paramref:`_sqltypes.Float.precision` parameter is not accepted. - - Oracle FLOAT types indicate precision in terms of "binary precision", which - defaults to 126. For a REAL type, the value is 63. This parameter does not - cleanly map to a specific number of decimal places but is roughly - equivalent to the desired number of decimal places divided by 0.3103. - - .. versionadded:: 2.0 - - """ - - __visit_name__ = "FLOAT" - - def __init__( - self, - binary_precision=None, - asdecimal=False, - decimal_return_scale=None, - ): - r""" - Construct a FLOAT - - :param binary_precision: Oracle binary precision value to be rendered - in DDL. This may be approximated to the number of decimal characters - using the formula "decimal precision = 0.30103 * binary precision". - The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126. 
- - :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal` - - :param decimal_return_scale: See - :paramref:`_sqltypes.Float.decimal_return_scale` - - """ - super().__init__( - asdecimal=asdecimal, decimal_return_scale=decimal_return_scale - ) - self.binary_precision = binary_precision - - -class BINARY_DOUBLE(sqltypes.Float): - __visit_name__ = "BINARY_DOUBLE" - - -class BINARY_FLOAT(sqltypes.Float): - __visit_name__ = "BINARY_FLOAT" - - -class BFILE(sqltypes.LargeBinary): - __visit_name__ = "BFILE" - - -class LONG(sqltypes.Text): - __visit_name__ = "LONG" - - -class _OracleDateLiteralRender: - def _literal_processor_datetime(self, dialect): - def process(value): - if value is not None: - if getattr(value, "microsecond", None): - value = ( - f"""TO_TIMESTAMP""" - f"""('{value.isoformat().replace("T", " ")}', """ - """'YYYY-MM-DD HH24:MI:SS.FF')""" - ) - else: - value = ( - f"""TO_DATE""" - f"""('{value.isoformat().replace("T", " ")}', """ - """'YYYY-MM-DD HH24:MI:SS')""" - ) - return value - - return process - - def _literal_processor_date(self, dialect): - def process(value): - if value is not None: - if getattr(value, "microsecond", None): - value = ( - f"""TO_TIMESTAMP""" - f"""('{value.isoformat().split("T")[0]}', """ - """'YYYY-MM-DD')""" - ) - else: - value = ( - f"""TO_DATE""" - f"""('{value.isoformat().split("T")[0]}', """ - """'YYYY-MM-DD')""" - ) - return value - - return process - - -class DATE(_OracleDateLiteralRender, sqltypes.DateTime): - """Provide the oracle DATE type. - - This type has no special Python behavior, except that it subclasses - :class:`_types.DateTime`; this is to suit the fact that the Oracle - ``DATE`` type supports a time value. - - .. versionadded:: 0.9.4 - - """ - - __visit_name__ = "DATE" - - def literal_processor(self, dialect): - return self._literal_processor_datetime(dialect) - - def _compare_type_affinity(self, other): - return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) - - -class _OracleDate(_OracleDateLiteralRender, sqltypes.Date): - def literal_processor(self, dialect): - return self._literal_processor_date(dialect) - - -class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): - __visit_name__ = "INTERVAL" - - def __init__(self, day_precision=None, second_precision=None): - """Construct an INTERVAL. - - Note that only DAY TO SECOND intervals are currently supported. - This is due to a lack of support for YEAR TO MONTH intervals - within available DBAPIs. - - :param day_precision: the day precision value. this is the number of - digits to store for the day field. Defaults to "2" - :param second_precision: the second precision value. this is the - number of digits to store for the fractional seconds field. - Defaults to "6". - - """ - self.day_precision = day_precision - self.second_precision = second_precision - - @classmethod - def _adapt_from_generic_interval(cls, interval): - return INTERVAL( - day_precision=interval.day_precision, - second_precision=interval.second_precision, - ) - - @property - def _type_affinity(self): - return sqltypes.Interval - - def as_generic(self, allow_nulltype=False): - return sqltypes.Interval( - native=True, - second_precision=self.second_precision, - day_precision=self.day_precision, - ) - - -class ROWID(sqltypes.TypeEngine): - """Oracle ROWID type. - - When used in a cast() or similar, generates ROWID. 
- - """ - - __visit_name__ = "ROWID" - - -class _OracleBoolean(sqltypes.Boolean): - def get_dbapi_type(self, dbapi): - return dbapi.NUMBER - - colspecs = { sqltypes.Boolean: _OracleBoolean, sqltypes.Interval: INTERVAL, @@ -1541,6 +1349,13 @@ class OracleExecutionContext(default.DefaultExecutionContext): type_, ) + def pre_exec(self): + if self.statement and "_oracle_dblink" in self.execution_options: + self.statement = self.statement.replace( + dictionary.DB_LINK_PLACEHOLDER, + self.execution_options["_oracle_dblink"], + ) + class OracleDialect(default.DefaultDialect): name = "oracle" @@ -1675,6 +1490,10 @@ class OracleDialect(default.DefaultDialect): # it may work also on versions before the 18 return self.server_version_info and self.server_version_info >= (18,) + @property + def _supports_except_all(self): + return self.server_version_info and self.server_version_info >= (21,) + def do_release_savepoint(self, connection, name): # Oracle does not support RELEASE SAVEPOINT pass @@ -1700,45 +1519,99 @@ class OracleDialect(default.DefaultDialect): except: return "READ COMMITTED" - def has_table(self, connection, table_name, schema=None): + def _execute_reflection( + self, connection, query, dblink, returns_long, params=None + ): + if dblink and not dblink.startswith("@"): + dblink = f"@{dblink}" + execution_options = { + # handle db links + "_oracle_dblink": dblink or "", + # override any schema translate map + "schema_translate_map": None, + } + + if dblink and returns_long: + # Oracle seems to error with + # "ORA-00997: illegal use of LONG datatype" when returning + # LONG columns via a dblink in a query with bind params + # This type seems to be very hard to cast into something else + # so it seems easier to just use bind param in this case + def visit_bindparam(bindparam): + bindparam.literal_execute = True + + query = visitors.cloned_traverse( + query, {}, {"bindparam": visit_bindparam} + ) + return connection.execute( + query, params, execution_options=execution_options + ) + + @util.memoized_property + def _has_table_query(self): + # materialized views are returned by all_tables + tables = ( + select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.owner, + ) + .union_all( + select( + dictionary.all_views.c.view_name.label("table_name"), + dictionary.all_views.c.owner, + ) + ) + .subquery("tables_and_views") + ) + + query = select(tables.c.table_name).where( + tables.c.table_name == bindparam("table_name"), + tables.c.owner == bindparam("owner"), + ) + return query + + @reflection.cache + def has_table( + self, connection, table_name, schema=None, dblink=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" self._ensure_has_table_connection(connection) if not schema: schema = self.default_schema_name - cursor = connection.execute( - sql.text( - """SELECT table_name FROM all_tables - WHERE table_name = CAST(:name AS VARCHAR2(128)) - AND owner = CAST(:schema_name AS VARCHAR2(128)) - UNION ALL - SELECT view_name FROM all_views - WHERE view_name = CAST(:name AS VARCHAR2(128)) - AND owner = CAST(:schema_name AS VARCHAR2(128)) - """ - ), - dict( - name=self.denormalize_name(table_name), - schema_name=self.denormalize_name(schema), - ), + params = { + "table_name": self.denormalize_name(table_name), + "owner": self.denormalize_name(schema), + } + cursor = self._execute_reflection( + connection, + self._has_table_query, + dblink, + returns_long=False, + params=params, ) - return cursor.first() is not None + return bool(cursor.scalar()) - def 
has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence( + self, connection, sequence_name, schema=None, dblink=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" if not schema: schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT sequence_name FROM all_sequences " - "WHERE sequence_name = :name AND " - "sequence_owner = :schema_name" - ), - dict( - name=self.denormalize_name(sequence_name), - schema_name=self.denormalize_name(schema), - ), + + query = select(dictionary.all_sequences.c.sequence_name).where( + dictionary.all_sequences.c.sequence_name + == self.denormalize_name(sequence_name), + dictionary.all_sequences.c.sequence_owner + == self.denormalize_name(schema), ) - return cursor.first() is not None + + cursor = self._execute_reflection( + connection, query, dblink, returns_long=False + ) + return bool(cursor.scalar()) def _get_default_schema_name(self, connection): return self.normalize_name( @@ -1747,329 +1620,633 @@ class OracleDialect(default.DefaultDialect): ).scalar() ) - def _resolve_synonym( - self, - connection, - desired_owner=None, - desired_synonym=None, - desired_table=None, + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("dblink", InternalTraversal.dp_string), + ) + def _get_synonyms(self, connection, schema, filter_names, dblink, **kw): + owner = self.denormalize_name(schema or self.default_schema_name) + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = select( + dictionary.all_synonyms.c.synonym_name, + dictionary.all_synonyms.c.table_name, + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.db_link, + ).where(dictionary.all_synonyms.c.owner == owner) + if has_filter_names: + query = query.where( + dictionary.all_synonyms.c.synonym_name.in_( + params["filter_names"] + ) + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).mappings() + return result.all() + + @lru_cache() + def _all_objects_query( + self, owner, scope, kind, has_filter_names, has_mat_views ): - """search for a local synonym matching the given desired owner/name. - - if desired_owner is None, attempts to locate a distinct owner. - - returns the actual name, owner, dblink name, and synonym name if - found. - """ - - q = ( - "SELECT owner, table_owner, table_name, db_link, " - "synonym_name FROM all_synonyms WHERE " + query = ( + select(dictionary.all_objects.c.object_name) + .select_from(dictionary.all_objects) + .where(dictionary.all_objects.c.owner == owner) ) - clauses = [] - params = {} - if desired_synonym: - clauses.append( - "synonym_name = CAST(:synonym_name AS VARCHAR2(128))" + + # NOTE: materialized views are listed in all_objects twice; + # once as MATERIALIZE VIEW and once as TABLE + if kind is ObjectKind.ANY: + # materilaized view are listed also as tables so there is no + # need to add them to the in_. 
+ query = query.where( + dictionary.all_objects.c.object_type.in_(("TABLE", "VIEW")) ) - params["synonym_name"] = desired_synonym - if desired_owner: - clauses.append("owner = CAST(:desired_owner AS VARCHAR2(128))") - params["desired_owner"] = desired_owner - if desired_table: - clauses.append("table_name = CAST(:tname AS VARCHAR2(128))") - params["tname"] = desired_table - - q += " AND ".join(clauses) - - result = connection.execution_options(future_result=True).execute( - sql.text(q), params - ) - if desired_owner: - row = result.mappings().first() - if row: - return ( - row["table_name"], - row["table_owner"], - row["db_link"], - row["synonym_name"], - ) - else: - return None, None, None, None else: - rows = result.mappings().all() - if len(rows) > 1: - raise AssertionError( - "There are multiple tables visible to the schema, you " - "must specify owner" - ) - elif len(rows) == 1: - row = rows[0] - return ( - row["table_name"], - row["table_owner"], - row["db_link"], - row["synonym_name"], - ) - else: - return None, None, None, None + object_type = [] + if ObjectKind.VIEW in kind: + object_type.append("VIEW") + if ( + ObjectKind.MATERIALIZED_VIEW in kind + and ObjectKind.TABLE not in kind + ): + # materilaized view are listed also as tables so there is no + # need to add them to the in_ if also selecting tables. + object_type.append("MATERIALIZED VIEW") + if ObjectKind.TABLE in kind: + object_type.append("TABLE") + if has_mat_views and ObjectKind.MATERIALIZED_VIEW not in kind: + # materialized view are listed also as tables, + # so they need to be filtered out + # EXCEPT ALL / MINUS profiles as faster than using + # NOT EXISTS or NOT IN with a subquery, but it's in + # general faster to get the mat view names and exclude + # them only when needed + query = query.where( + dictionary.all_objects.c.object_name.not_in( + bindparam("mat_views") + ) + ) + query = query.where( + dictionary.all_objects.c.object_type.in_(object_type) + ) - @reflection.cache - def _prepare_reflection_args( - self, - connection, - table_name, - schema=None, - resolve_synonyms=False, - dblink="", - **kw, - ): + # handles scope + if scope is ObjectScope.DEFAULT: + query = query.where(dictionary.all_objects.c.temporary == "N") + elif scope is ObjectScope.TEMPORARY: + query = query.where(dictionary.all_objects.c.temporary == "Y") - if resolve_synonyms: - actual_name, owner, dblink, synonym = self._resolve_synonym( - connection, - desired_owner=self.denormalize_name(schema), - desired_synonym=self.denormalize_name(table_name), + if has_filter_names: + query = query.where( + dictionary.all_objects.c.object_name.in_( + bindparam("filter_names") + ) ) - else: - actual_name, owner, dblink, synonym = None, None, None, None - if not actual_name: - actual_name = self.denormalize_name(table_name) - - if dblink: - # using user_db_links here since all_db_links appears - # to have more restricted permissions. - # https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm - # will need to hear from more users if we are doing - # the right thing here. 
See [ticket:2619] - owner = connection.scalar( - sql.text( - "SELECT username FROM user_db_links " "WHERE db_link=:link" - ), - dict(link=dblink), + return query + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("scope", InternalTraversal.dp_plain_obj), + ("kind", InternalTraversal.dp_plain_obj), + ("filter_names", InternalTraversal.dp_string_list), + ("dblink", InternalTraversal.dp_string), + ) + def _get_all_objects( + self, connection, schema, scope, kind, filter_names, dblink, **kw + ): + owner = self.denormalize_name(schema or self.default_schema_name) + + has_filter_names, params = self._prepare_filter_names(filter_names) + has_mat_views = False + if ( + ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # see note in _all_objects_query + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw ) - dblink = "@" + dblink - elif not owner: - owner = self.denormalize_name(schema or self.default_schema_name) + if mat_views: + params["mat_views"] = mat_views + has_mat_views = True + + query = self._all_objects_query( + owner, scope, kind, has_filter_names, has_mat_views + ) - return (actual_name, owner, dblink or "", synonym) + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ).scalars() - @reflection.cache - def get_schema_names(self, connection, **kw): - s = "SELECT username FROM all_users ORDER BY username" - cursor = connection.exec_driver_sql(s) - return [self.normalize_name(row[0]) for row in cursor] + return result.all() + + def _handle_synonyms_decorator(fn): + @wraps(fn) + def wrapper(self, *args, **kwargs): + return self._handle_synonyms(fn, *args, **kwargs) + + return wrapper + + def _handle_synonyms(self, fn, connection, *args, **kwargs): + if not kwargs.get("oracle_resolve_synonyms", False): + return fn(self, connection, *args, **kwargs) + + original_kw = kwargs.copy() + schema = kwargs.pop("schema", None) + result = self._get_synonyms( + connection, + schema=schema, + filter_names=kwargs.pop("filter_names", None), + dblink=kwargs.pop("dblink", None), + info_cache=kwargs.get("info_cache", None), + ) + + dblinks_owners = defaultdict(dict) + for row in result: + key = row["db_link"], row["table_owner"] + tn = self.normalize_name(row["table_name"]) + dblinks_owners[key][tn] = row["synonym_name"] + + if not dblinks_owners: + # No synonym, do the plain thing + return fn(self, connection, *args, **original_kw) + + data = {} + for (dblink, table_owner), mapping in dblinks_owners.items(): + call_kw = { + **original_kw, + "schema": table_owner, + "dblink": self.normalize_name(dblink), + "filter_names": mapping.keys(), + } + call_result = fn(self, connection, *args, **call_kw) + for (_, tn), value in call_result: + synonym_name = self.normalize_name(mapping[tn]) + data[(schema, synonym_name)] = value + return data.items() @reflection.cache - def get_table_names(self, connection, schema=None, **kw): - schema = self.denormalize_name(schema or self.default_schema_name) + def get_schema_names(self, connection, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + query = select(dictionary.all_users.c.username).order_by( + dictionary.all_users.c.username + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + @reflection.cache + def get_table_names(self, connection, schema=None, dblink=None, **kw): + 
"""Supported kw arguments are: ``dblink`` to reflect via a db link.""" # note that table_names() isn't loading DBLINKed or synonym'ed tables if schema is None: schema = self.default_schema_name - sql_str = "SELECT table_name FROM all_tables WHERE " + den_schema = self.denormalize_name(schema) + if kw.get("oracle_resolve_synonyms", False): + tables = ( + select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.owner, + dictionary.all_tables.c.iot_name, + dictionary.all_tables.c.duration, + dictionary.all_tables.c.tablespace_name, + ) + .union_all( + select( + dictionary.all_synonyms.c.synonym_name.label( + "table_name" + ), + dictionary.all_synonyms.c.owner, + dictionary.all_tables.c.iot_name, + dictionary.all_tables.c.duration, + dictionary.all_tables.c.tablespace_name, + ) + .select_from(dictionary.all_tables) + .join( + dictionary.all_synonyms, + and_( + dictionary.all_tables.c.table_name + == dictionary.all_synonyms.c.table_name, + dictionary.all_tables.c.owner + == func.coalesce( + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.owner, + ), + ), + ) + ) + .subquery("available_tables") + ) + else: + tables = dictionary.all_tables + + query = select(tables.c.table_name) if self.exclude_tablespaces: - sql_str += ( - "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " - % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) + query = query.where( + func.coalesce( + tables.c.tablespace_name, "no tablespace" + ).not_in(self.exclude_tablespaces) ) - sql_str += ( - "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL" + query = query.where( + tables.c.owner == den_schema, + tables.c.iot_name.is_(null()), + tables.c.duration.is_(null()), ) - cursor = connection.execute(sql.text(sql_str), dict(owner=schema)) - return [self.normalize_name(row[0]) for row in cursor] + # remove materialized views + mat_query = select( + dictionary.all_mviews.c.mview_name.label("table_name") + ).where(dictionary.all_mviews.c.owner == den_schema) + + query = ( + query.except_all(mat_query) + if self._supports_except_all + else query.except_(mat_query) + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] @reflection.cache - def get_temp_table_names(self, connection, **kw): + def get_temp_table_names(self, connection, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" schema = self.denormalize_name(self.default_schema_name) - sql_str = "SELECT table_name FROM all_tables WHERE " + query = select(dictionary.all_tables.c.table_name) if self.exclude_tablespaces: - sql_str += ( - "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " - % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) + query = query.where( + func.coalesce( + dictionary.all_tables.c.tablespace_name, "no tablespace" + ).not_in(self.exclude_tablespaces) ) - sql_str += ( - "OWNER = :owner " - "AND IOT_NAME IS NULL " - "AND DURATION IS NOT NULL" + query = query.where( + dictionary.all_tables.c.owner == schema, + dictionary.all_tables.c.iot_name.is_(null()), + dictionary.all_tables.c.duration.is_not(null()), ) - cursor = connection.execute(sql.text(sql_str), dict(owner=schema)) - return [self.normalize_name(row[0]) for row in cursor] + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] @reflection.cache - def get_view_names(self, 
connection, schema=None, **kw): - schema = self.denormalize_name(schema or self.default_schema_name) - s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner") - cursor = connection.execute( - s, dict(owner=self.denormalize_name(schema)) + def get_materialized_view_names( + self, connection, schema=None, dblink=None, _normalize=True, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + + query = select(dictionary.all_mviews.c.mview_name).where( + dictionary.all_mviews.c.owner == self.denormalize_name(schema) ) - return [self.normalize_name(row[0]) for row in cursor] + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + if _normalize: + return [self.normalize_name(row) for row in result] + else: + return result.all() @reflection.cache - def get_sequence_names(self, connection, schema=None, **kw): + def get_view_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" if not schema: schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT sequence_name FROM all_sequences " - "WHERE sequence_owner = :schema_name" - ), - dict(schema_name=self.denormalize_name(schema)), + + query = select(dictionary.all_views.c.view_name).where( + dictionary.all_views.c.owner == self.denormalize_name(schema) ) - return [self.normalize_name(row[0]) for row in cursor] + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] @reflection.cache - def get_table_options(self, connection, table_name, schema=None, **kw): - options = {} + def get_sequence_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + query = select(dictionary.all_sequences.c.sequence_name).where( + dictionary.all_sequences.c.sequence_owner + == self.denormalize_name(schema) + ) - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + def _value_or_raise(self, data, table, schema): + table = self.normalize_name(str(table)) + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + fn = [self.denormalize_name(name) for name in filter_names] + return True, {"filter_names": fn} + else: + return False, {} + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_table_options( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - params = {"table_name": table_name} + @lru_cache() + def _table_options_query( + self, owner, scope, kind, 
has_filter_names, has_mat_views + ): + query = select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.compression, + dictionary.all_tables.c.compress_for, + ).where(dictionary.all_tables.c.owner == owner) + if has_filter_names: + query = query.where( + dictionary.all_tables.c.table_name.in_( + bindparam("filter_names") + ) + ) + if scope is ObjectScope.DEFAULT: + query = query.where(dictionary.all_tables.c.duration.is_(null())) + elif scope is ObjectScope.TEMPORARY: + query = query.where( + dictionary.all_tables.c.duration.is_not(null()) + ) - columns = ["table_name"] - if self._supports_table_compression: - columns.append("compression") - if self._supports_table_compress_for: - columns.append("compress_for") + if ( + has_mat_views + and ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # cant use EXCEPT ALL / MINUS here because we don't have an + # excludable row vs. the query above + # outerjoin + where null works better on oracle 21 but 11 does + # not like it at all. this is the next best thing + + query = query.where( + dictionary.all_tables.c.table_name.not_in( + bindparam("mat_views") + ) + ) + elif ( + ObjectKind.TABLE not in kind + and ObjectKind.MATERIALIZED_VIEW in kind + ): + query = query.where( + dictionary.all_tables.c.table_name.in_(bindparam("mat_views")) + ) + return query - text = ( - "SELECT %(columns)s " - "FROM ALL_TABLES%(dblink)s " - "WHERE table_name = CAST(:table_name AS VARCHAR(128))" - ) + @_handle_synonyms_decorator + def get_multi_table_options( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + owner = self.denormalize_name(schema or self.default_schema_name) - if schema is not None: - params["owner"] = schema - text += " AND owner = CAST(:owner AS VARCHAR(128)) " - text = text % {"dblink": dblink, "columns": ", ".join(columns)} + has_filter_names, params = self._prepare_filter_names(filter_names) + has_mat_views = False - result = connection.execute(sql.text(text), params) + if ( + ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # see note in _table_options_query + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + if mat_views: + params["mat_views"] = mat_views + has_mat_views = True + elif ( + ObjectKind.TABLE not in kind + and ObjectKind.MATERIALIZED_VIEW in kind + ): + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + params["mat_views"] = mat_views - enabled = dict(DISABLED=False, ENABLED=True) + options = {} + default = ReflectionDefaults.table_options - row = result.first() - if row: - if "compression" in row._fields and enabled.get( - row.compression, False - ): - if "compress_for" in row._fields: - options["oracle_compress"] = row.compress_for + if ObjectKind.TABLE in kind or ObjectKind.MATERIALIZED_VIEW in kind: + query = self._table_options_query( + owner, scope, kind, has_filter_names, has_mat_views + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ) + + for table, compression, compress_for in result: + if compression == "ENABLED": + data = {"oracle_compress": compress_for} else: - options["oracle_compress"] = True + data = default() + options[(schema, self.normalize_name(table))] = data + if ObjectKind.VIEW in kind and 
ObjectScope.DEFAULT in scope: + # add the views (no temporary views) + for view in self.get_view_names(connection, schema, dblink, **kw): + if not filter_names or view in filter_names: + options[(schema, view)] = default() - return options + return options.items() @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ - kw arguments can be: + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def _run_batches( + self, connection, query, dblink, returns_long, mappings, all_objects + ): + each_batch = 500 + batches = list(all_objects) + while batches: + batch = batches[0:each_batch] + batches[0:each_batch] = [] + + result = self._execute_reflection( + connection, + query, + dblink, + returns_long=returns_long, + params={"all_objects": batch}, + ) + if mappings: + yield from result.mappings() + else: + yield from result + + @lru_cache() + def _column_query(self, owner): + all_cols = dictionary.all_tab_cols + all_comments = dictionary.all_col_comments + all_ids = dictionary.all_tab_identity_cols - oracle_resolve_synonyms + if self.server_version_info >= (12,): + add_cols = ( + all_cols.c.default_on_null, + sql.case( + (all_ids.c.table_name.is_(None), sql.null()), + else_=all_ids.c.generation_type + + "," + + all_ids.c.identity_options, + ).label("identity_options"), + ) + join_identity_cols = True + else: + add_cols = ( + sql.null().label("default_on_null"), + sql.null().label("identity_options"), + ) + join_identity_cols = False + + # NOTE: on oracle cannot create tables/views without columns and + # a table cannot have all column hidden: + # ORA-54039: table must have at least one column that is not invisible + # all_tab_cols returns data for tables/views/mat-views. 
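+        # all_tab_cols also lists hidden and virtual columns; hidden
+        # columns are filtered out below via hidden_column = 'NO'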
+ # all_tab_cols does not return recycled tables + + query = ( + select( + all_cols.c.table_name, + all_cols.c.column_name, + all_cols.c.data_type, + all_cols.c.char_length, + all_cols.c.data_precision, + all_cols.c.data_scale, + all_cols.c.nullable, + all_cols.c.data_default, + all_comments.c.comments, + all_cols.c.virtual_column, + *add_cols, + ).select_from(all_cols) + # NOTE: all_col_comments has a row for each column even if no + # comment is present, so a join could be performed, but there + # seems to be no difference compared to an outer join + .outerjoin( + all_comments, + and_( + all_cols.c.table_name == all_comments.c.table_name, + all_cols.c.column_name == all_comments.c.column_name, + all_cols.c.owner == all_comments.c.owner, + ), + ) + ) + if join_identity_cols: + query = query.outerjoin( + all_ids, + and_( + all_cols.c.table_name == all_ids.c.table_name, + all_cols.c.column_name == all_ids.c.column_name, + all_cols.c.owner == all_ids.c.owner, + ), + ) - dblink + query = query.where( + all_cols.c.table_name.in_(bindparam("all_objects")), + all_cols.c.hidden_column == "NO", + all_cols.c.owner == owner, + ).order_by(all_cols.c.table_name, all_cols.c.column_id) + return query + @_handle_synonyms_decorator + def get_multi_columns( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ + owner = self.denormalize_name(schema or self.default_schema_name) + query = self._column_query(owner) - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") + if ( + filter_names + and kind is ObjectKind.ANY + and scope is ObjectScope.ANY + ): + all_objects = [self.denormalize_name(n) for n in filter_names] + else: + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + columns = defaultdict(list) + + # all_tab_cols.data_default is LONG + result = self._run_batches( connection, - table_name, - schema, - resolve_synonyms, + query, dblink, - info_cache=info_cache, + returns_long=True, + mappings=True, + all_objects=all_objects, ) - columns = [] - if self._supports_char_length: - char_length_col = "char_length" - else: - char_length_col = "data_length" - if self.server_version_info >= (12,): - identity_cols = """\ - col.default_on_null, - ( - SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS - FROM ALL_TAB_IDENTITY_COLS%(dblink)s id - WHERE col.table_name = id.table_name - AND col.column_name = id.column_name - AND col.owner = id.owner - ) AS identity_options""" % { - "dblink": dblink - } - else: - identity_cols = "NULL as default_on_null, NULL as identity_options" - - params = {"table_name": table_name} - - text = """ - SELECT - col.column_name, - col.data_type, - col.%(char_length_col)s, - col.data_precision, - col.data_scale, - col.nullable, - col.data_default, - com.comments, - col.virtual_column, - %(identity_cols)s - FROM all_tab_cols%(dblink)s col - LEFT JOIN all_col_comments%(dblink)s com - ON col.table_name = com.table_name - AND col.column_name = com.column_name - AND col.owner = com.owner - WHERE col.table_name = CAST(:table_name AS VARCHAR2(128)) - AND col.hidden_column = 'NO' - """ - if schema is not None: - params["owner"] = schema - text += " AND col.owner = :owner " - text += " ORDER BY col.column_id" - text = 
text % { - "dblink": dblink, - "char_length_col": char_length_col, - "identity_cols": identity_cols, - } - - c = connection.execute(sql.text(text), params) - - for row in c: - colname = self.normalize_name(row[0]) - orig_colname = row[0] - coltype = row[1] - length = row[2] - precision = row[3] - scale = row[4] - nullable = row[5] == "Y" - default = row[6] - comment = row[7] - generated = row[8] - default_on_nul = row[9] - identity_options = row[10] + for row_dict in result: + table_name = self.normalize_name(row_dict["table_name"]) + orig_colname = row_dict["column_name"] + colname = self.normalize_name(orig_colname) + coltype = row_dict["data_type"] + precision = row_dict["data_precision"] if coltype == "NUMBER": + scale = row_dict["data_scale"] if precision is None and scale == 0: coltype = INTEGER() else: @@ -2089,7 +2266,9 @@ class OracleDialect(default.DefaultDialect): coltype = FLOAT(binary_precision=precision) elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"): - coltype = self.ischema_names.get(coltype)(length) + coltype = self.ischema_names.get(coltype)( + row_dict["char_length"] + ) elif "WITH TIME ZONE" in coltype: coltype = TIMESTAMP(timezone=True) else: @@ -2103,15 +2282,17 @@ class OracleDialect(default.DefaultDialect): ) coltype = sqltypes.NULLTYPE - if generated == "YES": + default = row_dict["data_default"] + if row_dict["virtual_column"] == "YES": computed = dict(sqltext=default) default = None else: computed = None + identity_options = row_dict["identity_options"] if identity_options is not None: identity = self._parse_identity_options( - identity_options, default_on_nul + identity_options, row_dict["default_on_null"] ) default = None else: @@ -2120,10 +2301,9 @@ class OracleDialect(default.DefaultDialect): cdict = { "name": colname, "type": coltype, - "nullable": nullable, + "nullable": row_dict["nullable"] == "Y", "default": default, - "autoincrement": "auto", - "comment": comment, + "comment": row_dict["comments"], } if orig_colname.lower() == orig_colname: cdict["quote"] = True @@ -2132,10 +2312,17 @@ class OracleDialect(default.DefaultDialect): if identity is not None: cdict["identity"] = identity - columns.append(cdict) - return columns + columns[(schema, table_name)].append(cdict) - def _parse_identity_options(self, identity_options, default_on_nul): + # NOTE: default not needed since all tables have columns + # default = ReflectionDefaults.columns + # return ( + # (key, value if value else default()) + # for key, value in columns.items() + # ) + return columns.items() + + def _parse_identity_options(self, identity_options, default_on_null): # identity_options is a string that starts with 'ALWAYS,' or # 'BY DEFAULT,' and continues with # START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1, @@ -2144,7 +2331,7 @@ class OracleDialect(default.DefaultDialect): parts = [p.strip() for p in identity_options.split(",")] identity = { "always": parts[0] == "ALWAYS", - "on_null": default_on_nul == "YES", + "on_null": default_on_null == "YES", } for part in parts[1:]: @@ -2168,384 +2355,641 @@ class OracleDialect(default.DefaultDialect): return identity @reflection.cache - def get_table_comment( + def get_table_comment(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_table_comment( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + 
return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _comment_query(self, owner, scope, kind, has_filter_names): + # NOTE: all_tab_comments / all_mview_comments have a row for all + # object even if they don't have comments + queries = [] + if ObjectKind.TABLE in kind or ObjectKind.VIEW in kind: + # all_tab_comments returns also plain views + tbl_view = select( + dictionary.all_tab_comments.c.table_name, + dictionary.all_tab_comments.c.comments, + ).where( + dictionary.all_tab_comments.c.owner == owner, + dictionary.all_tab_comments.c.table_name.not_like("BIN$%"), + ) + if ObjectKind.VIEW not in kind: + tbl_view = tbl_view.where( + dictionary.all_tab_comments.c.table_type == "TABLE" + ) + elif ObjectKind.TABLE not in kind: + tbl_view = tbl_view.where( + dictionary.all_tab_comments.c.table_type == "VIEW" + ) + queries.append(tbl_view) + if ObjectKind.MATERIALIZED_VIEW in kind: + mat_view = select( + dictionary.all_mview_comments.c.mview_name.label("table_name"), + dictionary.all_mview_comments.c.comments, + ).where( + dictionary.all_mview_comments.c.owner == owner, + dictionary.all_mview_comments.c.mview_name.not_like("BIN$%"), + ) + queries.append(mat_view) + if len(queries) == 1: + query = queries[0] + else: + union = sql.union_all(*queries).subquery("tables_and_views") + query = select(union.c.table_name, union.c.comments) + + name_col = query.selected_columns.table_name + + if scope in (ObjectScope.DEFAULT, ObjectScope.TEMPORARY): + temp = "Y" if scope is ObjectScope.TEMPORARY else "N" + # need distinct since materialized view are listed also + # as tables in all_objects + query = query.distinct().join( + dictionary.all_objects, + and_( + dictionary.all_objects.c.owner == owner, + dictionary.all_objects.c.object_name == name_col, + dictionary.all_objects.c.temporary == temp, + ), + ) + if has_filter_names: + query = query.where(name_col.in_(bindparam("filter_names"))) + return query + + @_handle_synonyms_decorator + def get_multi_table_comment( self, connection, - table_name, - schema=None, - resolve_synonyms=False, - dblink="", + *, + schema, + filter_names, + scope, + kind, + dblink=None, **kw, ): - - info_cache = kw.get("info_cache") - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( - connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, - ) - - if not schema: - schema = self.default_schema_name - - COMMENT_SQL = """ - SELECT comments - FROM all_tab_comments - WHERE table_name = CAST(:table_name AS VARCHAR(128)) - AND owner = CAST(:schema_name AS VARCHAR(128)) + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ + owner = self.denormalize_name(schema or self.default_schema_name) + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._comment_query(owner, scope, kind, has_filter_names) - c = connection.execute( - sql.text(COMMENT_SQL), - dict(table_name=table_name, schema_name=schema), + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ) + default = ReflectionDefaults.table_comment + # materialized views by default seem to have a comment like + # "snapshot table for snapshot owner.mat_view_name" + ignore_mat_view = "snapshot table for snapshot " + return ( + ( + (schema, self.normalize_name(table)), + {"text": comment} + if comment is not None + and not comment.startswith(ignore_mat_view) + else default(), + ) + for table, comment in result ) - 
return {"text": c.scalar()} @reflection.cache - def get_indexes( - self, - connection, - table_name, - schema=None, - resolve_synonyms=False, - dblink="", - **kw, - ): - - info_cache = kw.get("info_cache") - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + def get_indexes(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_indexes( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) - indexes = [] - - params = {"table_name": table_name} - text = ( - "SELECT a.index_name, a.column_name, " - "\nb.index_type, b.uniqueness, b.compression, b.prefix_length " - "\nFROM ALL_IND_COLUMNS%(dblink)s a, " - "\nALL_INDEXES%(dblink)s b " - "\nWHERE " - "\na.index_name = b.index_name " - "\nAND a.table_owner = b.table_owner " - "\nAND a.table_name = b.table_name " - "\nAND a.table_name = CAST(:table_name AS VARCHAR(128))" + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _index_query(self, owner): + return ( + select( + dictionary.all_ind_columns.c.table_name, + dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_columns.c.column_name, + dictionary.all_indexes.c.index_type, + dictionary.all_indexes.c.uniqueness, + dictionary.all_indexes.c.compression, + dictionary.all_indexes.c.prefix_length, + ) + .select_from(dictionary.all_ind_columns) + .join( + dictionary.all_indexes, + sql.and_( + dictionary.all_ind_columns.c.index_name + == dictionary.all_indexes.c.index_name, + dictionary.all_ind_columns.c.table_owner + == dictionary.all_indexes.c.table_owner, + # NOTE: this condition on table_name is not required + # but it improves the query performance noticeably + dictionary.all_ind_columns.c.table_name + == dictionary.all_indexes.c.table_name, + ), + ) + .where( + dictionary.all_ind_columns.c.table_owner == owner, + dictionary.all_ind_columns.c.table_name.in_( + bindparam("all_objects") + ), + ) + .order_by( + dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_columns.c.column_position, + ) ) - if schema is not None: - params["schema"] = schema - text += "AND a.table_owner = :schema " + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("dblink", InternalTraversal.dp_string), + ("all_objects", InternalTraversal.dp_string_list), + ) + def _get_indexes_rows(self, connection, schema, dblink, all_objects, **kw): + owner = self.denormalize_name(schema or self.default_schema_name) - text += "ORDER BY a.index_name, a.column_position" + query = self._index_query(owner) - text = text % {"dblink": dblink} + pks = { + row_dict["constraint_name"] + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ) + if row_dict["constraint_type"] == "P" + } - q = sql.text(text) - rp = connection.execute(q, params) - indexes = [] - last_index_name = None - pk_constraint = self.get_pk_constraint( + result = self._run_batches( connection, - table_name, - schema, - resolve_synonyms=resolve_synonyms, - dblink=dblink, - info_cache=kw.get("info_cache"), + query, + dblink, + returns_long=False, + mappings=True, + all_objects=all_objects, ) - uniqueness = dict(NONUNIQUE=False, UNIQUE=True) - enabled = dict(DISABLED=False, ENABLED=True) + return [ + row_dict + for row_dict in result + if row_dict["index_name"] 
not in pks + ] - oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) + @_handle_synonyms_decorator + def get_multi_indexes( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) - index = None - for rset in rp: - index_name_normalized = self.normalize_name(rset.index_name) + uniqueness = {"NONUNIQUE": False, "UNIQUE": True} + enabled = {"DISABLED": False, "ENABLED": True} + is_bitmap = {"BITMAP", "FUNCTION-BASED BITMAP"} - # skip primary key index. This is refined as of - # [ticket:5421]. Note that ALL_INDEXES.GENERATED will by "Y" - # if the name of this index was generated by Oracle, however - # if a named primary key constraint was created then this flag - # is false. - if ( - pk_constraint - and index_name_normalized == pk_constraint["name"] - ): - continue + oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) - if rset.index_name != last_index_name: - index = dict( - name=index_name_normalized, - column_names=[], - dialect_options={}, - ) - indexes.append(index) - index["unique"] = uniqueness.get(rset.uniqueness, False) + indexes = defaultdict(dict) + + for row_dict in self._get_indexes_rows( + connection, schema, dblink, all_objects, **kw + ): + index_name = self.normalize_name(row_dict["index_name"]) + table_name = self.normalize_name(row_dict["table_name"]) + table_indexes = indexes[(schema, table_name)] + + if index_name not in table_indexes: + table_indexes[index_name] = index_dict = { + "name": index_name, + "column_names": [], + "dialect_options": {}, + "unique": uniqueness.get(row_dict["uniqueness"], False), + } + do = index_dict["dialect_options"] + if row_dict["index_type"] in is_bitmap: + do["oracle_bitmap"] = True + if enabled.get(row_dict["compression"], False): + do["oracle_compress"] = row_dict["prefix_length"] - if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"): - index["dialect_options"]["oracle_bitmap"] = True - if enabled.get(rset.compression, False): - index["dialect_options"][ - "oracle_compress" - ] = rset.prefix_length + else: + index_dict = table_indexes[index_name] # filter out Oracle SYS_NC names. could also do an outer join - # to the all_tab_columns table and check for real col names there. - if not oracle_sys_col.match(rset.column_name): - index["column_names"].append( - self.normalize_name(rset.column_name) + # to the all_tab_columns table and check for real col names + # there. 
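+            # (SYS_NC names, e.g. SYS_NC00011$, are Oracle-generated
+            # hidden virtual columns backing function-based index
+            # expressions, not real table columns)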
+ if not oracle_sys_col.match(row_dict["column_name"]): + index_dict["column_names"].append( + self.normalize_name(row_dict["column_name"]) ) - last_index_name = rset.index_name - return indexes + default = ReflectionDefaults.indexes - @reflection.cache - def _get_constraint_data( - self, connection, table_name, schema=None, dblink="", **kw - ): - - params = {"table_name": table_name} - - text = ( - "SELECT" - "\nac.constraint_name," # 0 - "\nac.constraint_type," # 1 - "\nloc.column_name AS local_column," # 2 - "\nrem.table_name AS remote_table," # 3 - "\nrem.column_name AS remote_column," # 4 - "\nrem.owner AS remote_owner," # 5 - "\nloc.position as loc_pos," # 6 - "\nrem.position as rem_pos," # 7 - "\nac.search_condition," # 8 - "\nac.delete_rule" # 9 - "\nFROM all_constraints%(dblink)s ac," - "\nall_cons_columns%(dblink)s loc," - "\nall_cons_columns%(dblink)s rem" - "\nWHERE ac.table_name = CAST(:table_name AS VARCHAR2(128))" - "\nAND ac.constraint_type IN ('R','P', 'U', 'C')" - ) - - if schema is not None: - params["owner"] = schema - text += "\nAND ac.owner = CAST(:owner AS VARCHAR2(128))" - - text += ( - "\nAND ac.owner = loc.owner" - "\nAND ac.constraint_name = loc.constraint_name" - "\nAND ac.r_owner = rem.owner(+)" - "\nAND ac.r_constraint_name = rem.constraint_name(+)" - "\nAND (rem.position IS NULL or loc.position=rem.position)" - "\nORDER BY ac.constraint_name, loc.position" + return ( + (key, list(indexes[key].values()) if key in indexes else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) ) - text = text % {"dblink": dblink} - rp = connection.execute(sql.text(text), params) - constraint_data = rp.fetchall() - return constraint_data - @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_pk_constraint( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) - pkeys = [] - constraint_name = None - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _constraint_query(self, owner): + local = dictionary.all_cons_columns.alias("local") + remote = dictionary.all_cons_columns.alias("remote") + return ( + select( + dictionary.all_constraints.c.table_name, + dictionary.all_constraints.c.constraint_type, + dictionary.all_constraints.c.constraint_name, + local.c.column_name.label("local_column"), + remote.c.table_name.label("remote_table"), + remote.c.column_name.label("remote_column"), + remote.c.owner.label("remote_owner"), + dictionary.all_constraints.c.search_condition, + dictionary.all_constraints.c.delete_rule, + ) + .select_from(dictionary.all_constraints) + .join( + local, + and_( + local.c.owner == dictionary.all_constraints.c.owner, + dictionary.all_constraints.c.constraint_name + == local.c.constraint_name, + ), + ) + .outerjoin( + remote, + and_( + dictionary.all_constraints.c.r_owner == remote.c.owner, + 
dictionary.all_constraints.c.r_constraint_name + == remote.c.constraint_name, + or_( + remote.c.position.is_(sql.null()), + local.c.position == remote.c.position, + ), + ), + ) + .where( + dictionary.all_constraints.c.owner == owner, + dictionary.all_constraints.c.table_name.in_( + bindparam("all_objects") + ), + dictionary.all_constraints.c.constraint_type.in_( + ("R", "P", "U", "C") + ), + ) + .order_by( + dictionary.all_constraints.c.constraint_name, local.c.position + ) ) - for row in constraint_data: - ( - cons_name, - cons_type, - local_column, - remote_table, - remote_column, - remote_owner, - ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) - if cons_type == "P": - if constraint_name is None: - constraint_name = self.normalize_name(cons_name) - pkeys.append(local_column) - return {"constrained_columns": pkeys, "name": constraint_name} + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("dblink", InternalTraversal.dp_string), + ("all_objects", InternalTraversal.dp_string_list), + ) + def _get_all_constraint_rows( + self, connection, schema, dblink, all_objects, **kw + ): + owner = self.denormalize_name(schema or self.default_schema_name) + query = self._constraint_query(owner) - @reflection.cache - def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # since the result is cached a list must be created + values = list( + self._run_batches( + connection, + query, + dblink, + returns_long=False, + mappings=True, + all_objects=all_objects, + ) + ) + return values + + @_handle_synonyms_decorator + def get_multi_pk_constraint( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) - kw arguments can be: + primary_keys = defaultdict(dict) + default = ReflectionDefaults.pk_constraint - oracle_resolve_synonyms + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "P": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + column_name = self.normalize_name(row_dict["local_column"]) + + table_pk = primary_keys[(schema, table_name)] + if not table_pk: + table_pk["name"] = constraint_name + table_pk["constrained_columns"] = [column_name] + else: + table_pk["constrained_columns"].append(column_name) - dblink + return ( + (key, primary_keys[key] if key in primary_keys else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + @reflection.cache + def get_foreign_keys( + self, + connection, + table_name, + schema=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ - requested_schema = schema # to check later on - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + data = self.get_multi_foreign_keys( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return 
self._value_or_raise(data, table_name, schema) - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + @_handle_synonyms_decorator + def get_multi_foreign_keys( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw ) - def fkey_rec(): - return { - "name": None, - "constrained_columns": [], - "referred_schema": None, - "referred_table": None, - "referred_columns": [], - "options": {}, - } + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - fkeys = util.defaultdict(fkey_rec) + owner = self.denormalize_name(schema or self.default_schema_name) - for row in constraint_data: - ( - cons_name, - cons_type, - local_column, - remote_table, - remote_column, - remote_owner, - ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) - - cons_name = self.normalize_name(cons_name) - - if cons_type == "R": - if remote_table is None: - # ticket 363 - util.warn( - ( - "Got 'None' querying 'table_name' from " - "all_cons_columns%(dblink)s - does the user have " - "proper rights to the table?" - ) - % {"dblink": dblink} - ) - continue + all_remote_owners = set() + fkeys = defaultdict(dict) + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "R": + continue + + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + table_fkey = fkeys[(schema, table_name)] + + assert constraint_name is not None - rec = fkeys[cons_name] - rec["name"] = cons_name - local_cols, remote_cols = ( - rec["constrained_columns"], - rec["referred_columns"], + local_column = self.normalize_name(row_dict["local_column"]) + remote_table = self.normalize_name(row_dict["remote_table"]) + remote_column = self.normalize_name(row_dict["remote_column"]) + remote_owner_orig = row_dict["remote_owner"] + remote_owner = self.normalize_name(remote_owner_orig) + if remote_owner_orig is not None: + all_remote_owners.add(remote_owner_orig) + + if remote_table is None: + # ticket 363 + if dblink and not dblink.startswith("@"): + dblink = f"@{dblink}" + util.warn( + "Got 'None' querying 'table_name' from " + f"all_cons_columns{dblink or ''} - does the user have " + "proper rights to the table?" 
) + continue - if not rec["referred_table"]: - if resolve_synonyms: - ( - ref_remote_name, - ref_remote_owner, - ref_dblink, - ref_synonym, - ) = self._resolve_synonym( - connection, - desired_owner=self.denormalize_name(remote_owner), - desired_table=self.denormalize_name(remote_table), - ) - if ref_synonym: - remote_table = self.normalize_name(ref_synonym) - remote_owner = self.normalize_name( - ref_remote_owner - ) + if constraint_name not in table_fkey: + table_fkey[constraint_name] = fkey = { + "name": constraint_name, + "constrained_columns": [], + "referred_schema": None, + "referred_table": remote_table, + "referred_columns": [], + "options": {}, + } - rec["referred_table"] = remote_table + if resolve_synonyms: + # will be removed below + fkey["_ref_schema"] = remote_owner - if ( - requested_schema is not None - or self.denormalize_name(remote_owner) != schema - ): - rec["referred_schema"] = remote_owner + if schema is not None or remote_owner_orig != owner: + fkey["referred_schema"] = remote_owner + + delete_rule = row_dict["delete_rule"] + if delete_rule != "NO ACTION": + fkey["options"]["ondelete"] = delete_rule + + else: + fkey = table_fkey[constraint_name] + + fkey["constrained_columns"].append(local_column) + fkey["referred_columns"].append(remote_column) + + if resolve_synonyms and all_remote_owners: + query = select( + dictionary.all_synonyms.c.owner, + dictionary.all_synonyms.c.table_name, + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.synonym_name, + ).where(dictionary.all_synonyms.c.owner.in_(all_remote_owners)) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).mappings() - if row[9] != "NO ACTION": - rec["options"]["ondelete"] = row[9] + remote_owners_lut = {} + for row in result: + synonym_owner = self.normalize_name(row["owner"]) + table_name = self.normalize_name(row["table_name"]) - local_cols.append(local_column) - remote_cols.append(remote_column) + remote_owners_lut[(synonym_owner, table_name)] = ( + row["table_owner"], + row["synonym_name"], + ) + + empty = (None, None) + for table_fkeys in fkeys.values(): + for table_fkey in table_fkeys.values(): + key = ( + table_fkey.pop("_ref_schema"), + table_fkey["referred_table"], + ) + remote_owner, syn_name = remote_owners_lut.get(key, empty) + if syn_name: + sn = self.normalize_name(syn_name) + table_fkey["referred_table"] = sn + if schema is not None or remote_owner != owner: + ro = self.normalize_name(remote_owner) + table_fkey["referred_schema"] = ro + else: + table_fkey["referred_schema"] = None + default = ReflectionDefaults.foreign_keys - return list(fkeys.values()) + return ( + (key, list(fkeys[key].values()) if key in fkeys else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) @reflection.cache def get_unique_constraints( self, connection, table_name, schema=None, **kw ): - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_unique_constraints( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, 
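# --- illustrative sketch, not part of the original patch ---
# Shape of a single reflected foreign key entry as assembled above; note
# that "referred_schema" stays None for same-schema references when no
# explicit schema was requested, and "ondelete" is recorded only when the
# delete rule differs from the Oracle default of NO ACTION.  All names
# below are hypothetical.
reflected_fk = {
    "name": "fk_child_parent",
    "constrained_columns": ["parent_id"],
    "referred_schema": None,
    "referred_table": "parent",
    "referred_columns": ["id"],
    "options": {"ondelete": "CASCADE"},
}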
schema) - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + @_handle_synonyms_decorator + def get_multi_unique_constraints( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw ) - unique_keys = filter(lambda x: x[1] == "U", constraint_data) - uniques_group = groupby(unique_keys, lambda x: x[0]) + unique_cons = defaultdict(dict) index_names = { - ix["name"] - for ix in self.get_indexes(connection, table_name, schema=schema) + row_dict["index_name"] + for row_dict in self._get_indexes_rows( + connection, schema, dblink, all_objects, **kw + ) } - return [ - { - "name": name, - "column_names": cols, - "duplicates_index": name if name in index_names else None, - } - for name, cols in [ - [ - self.normalize_name(i[0]), - [self.normalize_name(x[2]) for x in i[1]], - ] - for i in uniques_group - ] - ] + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "U": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name_orig = row_dict["constraint_name"] + constraint_name = self.normalize_name(constraint_name_orig) + column_name = self.normalize_name(row_dict["local_column"]) + table_uc = unique_cons[(schema, table_name)] + + assert constraint_name is not None + + if constraint_name not in table_uc: + table_uc[constraint_name] = uc = { + "name": constraint_name, + "column_names": [], + "duplicates_index": constraint_name + if constraint_name_orig in index_names + else None, + } + else: + uc = table_uc[constraint_name] + + uc["column_names"].append(column_name) + + default = ReflectionDefaults.unique_constraints + + return ( + ( + key, + list(unique_cons[key].values()) + if key in unique_cons + else default(), + ) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) @reflection.cache def get_view_definition( @@ -2553,65 +2997,129 @@ class OracleDialect(default.DefaultDialect): connection, view_name, schema=None, - resolve_synonyms=False, - dblink="", + dblink=None, **kw, ): - info_cache = kw.get("info_cache") - (view_name, schema, dblink, synonym) = self._prepare_reflection_args( - connection, - view_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + if kw.get("oracle_resolve_synonyms", False): + synonyms = self._get_synonyms( + connection, schema, filter_names=[view_name], dblink=dblink + ) + if synonyms: + assert len(synonyms) == 1 + row_dict = synonyms[0] + dblink = self.normalize_name(row_dict["db_link"]) + schema = row_dict["table_owner"] + view_name = row_dict["table_name"] + + name = self.denormalize_name(view_name) + owner = self.denormalize_name(schema or self.default_schema_name) + query = ( + select(dictionary.all_views.c.text) + .where( + dictionary.all_views.c.view_name == name, + dictionary.all_views.c.owner == owner, + ) + .union_all( + select(dictionary.all_mviews.c.query).where( + dictionary.all_mviews.c.mview_name == name, + dictionary.all_mviews.c.owner == owner, + ) + ) ) - params = {"view_name": view_name} - text = 
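# --- illustrative sketch, not part of the original patch ---
# Each unique-constraint entry built above carries a "duplicates_index"
# marker when Oracle backs the constraint with an index of the same name;
# callers can use it to avoid treating the constraint and its backing
# index as two separate structures.  Names are hypothetical.
unique_constraints = [
    {
        "name": "uq_user_email",
        "column_names": ["email"],
        "duplicates_index": "uq_user_email",
    },
]
standalone_constraints = [
    uc for uc in unique_constraints if uc["duplicates_index"] is None
]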
"SELECT text FROM all_views WHERE view_name=:view_name" - - if schema is not None: - text += " AND owner = :schema" - params["schema"] = schema - - rp = connection.execute(sql.text(text), params).scalar() - if rp: - return rp + rp = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalar() + if rp is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) else: - return None + return rp @reflection.cache def get_check_constraints( self, connection, table_name, schema=None, include_all=False, **kw ): - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_check_constraints( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + include_all=include_all, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + @_handle_synonyms_decorator + def get_multi_check_constraints( + self, + connection, + *, + schema, + filter_names, + dblink=None, + scope, + kind, + include_all=False, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw ) - check_constraints = filter(lambda x: x[1] == "C", constraint_data) + not_null = re.compile(r"..+?. IS NOT NULL$") - return [ - {"name": self.normalize_name(cons[0]), "sqltext": cons[8]} - for cons in check_constraints - if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8]) - ] + check_constraints = defaultdict(list) + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "C": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + search_condition = row_dict["search_condition"] + + table_checks = check_constraints[(schema, table_name)] + if constraint_name is not None and ( + include_all or not not_null.match(search_condition) + ): + table_checks.append( + {"name": constraint_name, "sqltext": search_condition} + ) + + default = ReflectionDefaults.check_constraints + + return ( + ( + key, + check_constraints[key] + if key in check_constraints + else default(), + ) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + def _list_dblinks(self, connection, dblink=None): + query = select(dictionary.all_db_links.c.db_link) + links = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(link) for link in links] class _OuterJoinColumn(sql.ClauseElement): diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 25e93632cb..d2ee0a96ed 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -431,6 +431,7 @@ from . 
import base as oracle from .base import OracleCompiler from .base import OracleDialect from .base import OracleExecutionContext +from .types import _OracleDateLiteralRender from ... import exc from ... import util from ...engine import cursor as _cursor @@ -573,7 +574,7 @@ class _CXOracleDate(oracle._OracleDate): return process -class _CXOracleTIMESTAMP(oracle._OracleDateLiteralRender, sqltypes.TIMESTAMP): +class _CXOracleTIMESTAMP(_OracleDateLiteralRender, sqltypes.TIMESTAMP): def literal_processor(self, dialect): return self._literal_processor_datetime(dialect) @@ -812,6 +813,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): return None def pre_exec(self): + super().pre_exec() if not getattr(self.compiled, "_oracle_cx_sql_compiler", False): return diff --git a/lib/sqlalchemy/dialects/oracle/dictionary.py b/lib/sqlalchemy/dialects/oracle/dictionary.py new file mode 100644 index 0000000000..ac7a350da3 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/dictionary.py @@ -0,0 +1,495 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .types import DATE +from .types import LONG +from .types import NUMBER +from .types import RAW +from .types import VARCHAR2 +from ... import Column +from ... import MetaData +from ... import Table +from ... import table +from ...sql.sqltypes import CHAR + +# constants +DB_LINK_PLACEHOLDER = "__$sa_dblink$__" +# tables +dual = table("dual") +dictionary_meta = MetaData() + +# NOTE: all the dictionary_meta are aliases because oracle does not like +# using the full table@dblink for every column in query, and complains with +# ORA-00960: ambiguous column naming in select list +all_tables = Table( + "all_tables" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("tablespace_name", VARCHAR2(30)), + Column("cluster_name", VARCHAR2(128)), + Column("iot_name", VARCHAR2(128)), + Column("status", VARCHAR2(8)), + Column("pct_free", NUMBER), + Column("pct_used", NUMBER), + Column("ini_trans", NUMBER), + Column("max_trans", NUMBER), + Column("initial_extent", NUMBER), + Column("next_extent", NUMBER), + Column("min_extents", NUMBER), + Column("max_extents", NUMBER), + Column("pct_increase", NUMBER), + Column("freelists", NUMBER), + Column("freelist_groups", NUMBER), + Column("logging", VARCHAR2(3)), + Column("backed_up", VARCHAR2(1)), + Column("num_rows", NUMBER), + Column("blocks", NUMBER), + Column("empty_blocks", NUMBER), + Column("avg_space", NUMBER), + Column("chain_cnt", NUMBER), + Column("avg_row_len", NUMBER), + Column("avg_space_freelist_blocks", NUMBER), + Column("num_freelist_blocks", NUMBER), + Column("degree", VARCHAR2(10)), + Column("instances", VARCHAR2(10)), + Column("cache", VARCHAR2(5)), + Column("table_lock", VARCHAR2(8)), + Column("sample_size", NUMBER), + Column("last_analyzed", DATE), + Column("partitioned", VARCHAR2(3)), + Column("iot_type", VARCHAR2(12)), + Column("temporary", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("nested", VARCHAR2(3)), + Column("buffer_pool", VARCHAR2(7)), + Column("flash_cache", VARCHAR2(7)), + Column("cell_flash_cache", VARCHAR2(7)), + Column("row_movement", VARCHAR2(8)), + Column("global_stats", VARCHAR2(3)), + Column("user_stats", VARCHAR2(3)), + Column("duration", VARCHAR2(15)), + Column("skip_corrupt", 
VARCHAR2(8)), + Column("monitoring", VARCHAR2(3)), + Column("cluster_owner", VARCHAR2(128)), + Column("dependencies", VARCHAR2(8)), + Column("compression", VARCHAR2(8)), + Column("compress_for", VARCHAR2(30)), + Column("dropped", VARCHAR2(3)), + Column("read_only", VARCHAR2(3)), + Column("segment_created", VARCHAR2(3)), + Column("result_cache", VARCHAR2(7)), + Column("clustering", VARCHAR2(3)), + Column("activity_tracking", VARCHAR2(23)), + Column("dml_timestamp", VARCHAR2(25)), + Column("has_identity", VARCHAR2(3)), + Column("container_data", VARCHAR2(3)), + Column("inmemory", VARCHAR2(8)), + Column("inmemory_priority", VARCHAR2(8)), + Column("inmemory_distribute", VARCHAR2(15)), + Column("inmemory_compression", VARCHAR2(17)), + Column("inmemory_duplicate", VARCHAR2(13)), + Column("default_collation", VARCHAR2(100)), + Column("duplicated", VARCHAR2(1)), + Column("sharded", VARCHAR2(1)), + Column("externally_sharded", VARCHAR2(1)), + Column("externally_duplicated", VARCHAR2(1)), + Column("external", VARCHAR2(3)), + Column("hybrid", VARCHAR2(3)), + Column("cellmemory", VARCHAR2(24)), + Column("containers_default", VARCHAR2(3)), + Column("container_map", VARCHAR2(3)), + Column("extended_data_link", VARCHAR2(3)), + Column("extended_data_link_map", VARCHAR2(3)), + Column("inmemory_service", VARCHAR2(12)), + Column("inmemory_service_name", VARCHAR2(1000)), + Column("container_map_object", VARCHAR2(3)), + Column("memoptimize_read", VARCHAR2(8)), + Column("memoptimize_write", VARCHAR2(8)), + Column("has_sensitive_column", VARCHAR2(3)), + Column("admit_null", VARCHAR2(3)), + Column("data_link_dml_enabled", VARCHAR2(3)), + Column("logical_replication", VARCHAR2(8)), +).alias("a_tables") + +all_views = Table( + "all_views" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("view_name", VARCHAR2(128), nullable=False), + Column("text_length", NUMBER), + Column("text", LONG), + Column("text_vc", VARCHAR2(4000)), + Column("type_text_length", NUMBER), + Column("type_text", VARCHAR2(4000)), + Column("oid_text_length", NUMBER), + Column("oid_text", VARCHAR2(4000)), + Column("view_type_owner", VARCHAR2(128)), + Column("view_type", VARCHAR2(128)), + Column("superview_name", VARCHAR2(128)), + Column("editioning_view", VARCHAR2(1)), + Column("read_only", VARCHAR2(1)), + Column("container_data", VARCHAR2(1)), + Column("bequeath", VARCHAR2(12)), + Column("origin_con_id", VARCHAR2(256)), + Column("default_collation", VARCHAR2(100)), + Column("containers_default", VARCHAR2(3)), + Column("container_map", VARCHAR2(3)), + Column("extended_data_link", VARCHAR2(3)), + Column("extended_data_link_map", VARCHAR2(3)), + Column("has_sensitive_column", VARCHAR2(3)), + Column("admit_null", VARCHAR2(3)), + Column("pdb_local_only", VARCHAR2(3)), +).alias("a_views") + +all_sequences = Table( + "all_sequences" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("sequence_owner", VARCHAR2(128), nullable=False), + Column("sequence_name", VARCHAR2(128), nullable=False), + Column("min_value", NUMBER), + Column("max_value", NUMBER), + Column("increment_by", NUMBER, nullable=False), + Column("cycle_flag", VARCHAR2(1)), + Column("order_flag", VARCHAR2(1)), + Column("cache_size", NUMBER, nullable=False), + Column("last_number", NUMBER, nullable=False), + Column("scale_flag", VARCHAR2(1)), + Column("extend_flag", VARCHAR2(1)), + Column("sharded_flag", VARCHAR2(1)), + Column("session_flag", VARCHAR2(1)), + Column("keep_value", VARCHAR2(1)), +).alias("a_sequences") + +all_users = Table( + 
"all_users" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("username", VARCHAR2(128), nullable=False), + Column("user_id", NUMBER, nullable=False), + Column("created", DATE, nullable=False), + Column("common", VARCHAR2(3)), + Column("oracle_maintained", VARCHAR2(1)), + Column("inherited", VARCHAR2(3)), + Column("default_collation", VARCHAR2(100)), + Column("implicit", VARCHAR2(3)), + Column("all_shard", VARCHAR2(3)), + Column("external_shard", VARCHAR2(3)), +).alias("a_users") + +all_mviews = Table( + "all_mviews" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("mview_name", VARCHAR2(128), nullable=False), + Column("container_name", VARCHAR2(128), nullable=False), + Column("query", LONG), + Column("query_len", NUMBER(38)), + Column("updatable", VARCHAR2(1)), + Column("update_log", VARCHAR2(128)), + Column("master_rollback_seg", VARCHAR2(128)), + Column("master_link", VARCHAR2(128)), + Column("rewrite_enabled", VARCHAR2(1)), + Column("rewrite_capability", VARCHAR2(9)), + Column("refresh_mode", VARCHAR2(6)), + Column("refresh_method", VARCHAR2(8)), + Column("build_mode", VARCHAR2(9)), + Column("fast_refreshable", VARCHAR2(18)), + Column("last_refresh_type", VARCHAR2(8)), + Column("last_refresh_date", DATE), + Column("last_refresh_end_time", DATE), + Column("staleness", VARCHAR2(19)), + Column("after_fast_refresh", VARCHAR2(19)), + Column("unknown_prebuilt", VARCHAR2(1)), + Column("unknown_plsql_func", VARCHAR2(1)), + Column("unknown_external_table", VARCHAR2(1)), + Column("unknown_consider_fresh", VARCHAR2(1)), + Column("unknown_import", VARCHAR2(1)), + Column("unknown_trusted_fd", VARCHAR2(1)), + Column("compile_state", VARCHAR2(19)), + Column("use_no_index", VARCHAR2(1)), + Column("stale_since", DATE), + Column("num_pct_tables", NUMBER), + Column("num_fresh_pct_regions", NUMBER), + Column("num_stale_pct_regions", NUMBER), + Column("segment_created", VARCHAR2(3)), + Column("evaluation_edition", VARCHAR2(128)), + Column("unusable_before", VARCHAR2(128)), + Column("unusable_beginning", VARCHAR2(128)), + Column("default_collation", VARCHAR2(100)), + Column("on_query_computation", VARCHAR2(1)), + Column("auto", VARCHAR2(3)), +).alias("a_mviews") + +all_tab_identity_cols = Table( + "all_tab_identity_cols" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("generation_type", VARCHAR2(10)), + Column("sequence_name", VARCHAR2(128), nullable=False), + Column("identity_options", VARCHAR2(298)), +).alias("a_tab_identity_cols") + +all_tab_cols = Table( + "all_tab_cols" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("data_type", VARCHAR2(128)), + Column("data_type_mod", VARCHAR2(3)), + Column("data_type_owner", VARCHAR2(128)), + Column("data_length", NUMBER, nullable=False), + Column("data_precision", NUMBER), + Column("data_scale", NUMBER), + Column("nullable", VARCHAR2(1)), + Column("column_id", NUMBER), + Column("default_length", NUMBER), + Column("data_default", LONG), + Column("num_distinct", NUMBER), + Column("low_value", RAW(1000)), + Column("high_value", RAW(1000)), + Column("density", NUMBER), + Column("num_nulls", NUMBER), + Column("num_buckets", NUMBER), + Column("last_analyzed", DATE), + Column("sample_size", 
NUMBER), + Column("character_set_name", VARCHAR2(44)), + Column("char_col_decl_length", NUMBER), + Column("global_stats", VARCHAR2(3)), + Column("user_stats", VARCHAR2(3)), + Column("avg_col_len", NUMBER), + Column("char_length", NUMBER), + Column("char_used", VARCHAR2(1)), + Column("v80_fmt_image", VARCHAR2(3)), + Column("data_upgraded", VARCHAR2(3)), + Column("hidden_column", VARCHAR2(3)), + Column("virtual_column", VARCHAR2(3)), + Column("segment_column_id", NUMBER), + Column("internal_column_id", NUMBER, nullable=False), + Column("histogram", VARCHAR2(15)), + Column("qualified_col_name", VARCHAR2(4000)), + Column("user_generated", VARCHAR2(3)), + Column("default_on_null", VARCHAR2(3)), + Column("identity_column", VARCHAR2(3)), + Column("evaluation_edition", VARCHAR2(128)), + Column("unusable_before", VARCHAR2(128)), + Column("unusable_beginning", VARCHAR2(128)), + Column("collation", VARCHAR2(100)), + Column("collated_column_id", NUMBER), +).alias("a_tab_cols") + +all_tab_comments = Table( + "all_tab_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("table_type", VARCHAR2(11)), + Column("comments", VARCHAR2(4000)), + Column("origin_con_id", NUMBER), +).alias("a_tab_comments") + +all_col_comments = Table( + "all_col_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("comments", VARCHAR2(4000)), + Column("origin_con_id", NUMBER), +).alias("a_col_comments") + +all_mview_comments = Table( + "all_mview_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("mview_name", VARCHAR2(128), nullable=False), + Column("comments", VARCHAR2(4000)), +).alias("a_mview_comments") + +all_ind_columns = Table( + "all_ind_columns" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("index_owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(4000)), + Column("column_position", NUMBER, nullable=False), + Column("column_length", NUMBER, nullable=False), + Column("char_length", NUMBER), + Column("descend", VARCHAR2(4)), + Column("collated_column_id", NUMBER), +).alias("a_ind_columns") + +all_indexes = Table( + "all_indexes" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("index_type", VARCHAR2(27)), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("table_type", CHAR(11)), + Column("uniqueness", VARCHAR2(9)), + Column("compression", VARCHAR2(13)), + Column("prefix_length", NUMBER), + Column("tablespace_name", VARCHAR2(30)), + Column("ini_trans", NUMBER), + Column("max_trans", NUMBER), + Column("initial_extent", NUMBER), + Column("next_extent", NUMBER), + Column("min_extents", NUMBER), + Column("max_extents", NUMBER), + Column("pct_increase", NUMBER), + Column("pct_threshold", NUMBER), + Column("include_column", NUMBER), + Column("freelists", NUMBER), + Column("freelist_groups", NUMBER), + Column("pct_free", NUMBER), + Column("logging", VARCHAR2(3)), + Column("blevel", NUMBER), + Column("leaf_blocks", NUMBER), + 
Column("distinct_keys", NUMBER), + Column("avg_leaf_blocks_per_key", NUMBER), + Column("avg_data_blocks_per_key", NUMBER), + Column("clustering_factor", NUMBER), + Column("status", VARCHAR2(8)), + Column("num_rows", NUMBER), + Column("sample_size", NUMBER), + Column("last_analyzed", DATE), + Column("degree", VARCHAR2(40)), + Column("instances", VARCHAR2(40)), + Column("partitioned", VARCHAR2(3)), + Column("temporary", VARCHAR2(1)), + Column("generated", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("buffer_pool", VARCHAR2(7)), + Column("flash_cache", VARCHAR2(7)), + Column("cell_flash_cache", VARCHAR2(7)), + Column("user_stats", VARCHAR2(3)), + Column("duration", VARCHAR2(15)), + Column("pct_direct_access", NUMBER), + Column("ityp_owner", VARCHAR2(128)), + Column("ityp_name", VARCHAR2(128)), + Column("parameters", VARCHAR2(1000)), + Column("global_stats", VARCHAR2(3)), + Column("domidx_status", VARCHAR2(12)), + Column("domidx_opstatus", VARCHAR2(6)), + Column("funcidx_status", VARCHAR2(8)), + Column("join_index", VARCHAR2(3)), + Column("iot_redundant_pkey_elim", VARCHAR2(3)), + Column("dropped", VARCHAR2(3)), + Column("visibility", VARCHAR2(9)), + Column("domidx_management", VARCHAR2(14)), + Column("segment_created", VARCHAR2(3)), + Column("orphaned_entries", VARCHAR2(3)), + Column("indexing", VARCHAR2(7)), + Column("auto", VARCHAR2(3)), +).alias("a_indexes") + +all_constraints = Table( + "all_constraints" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128)), + Column("constraint_name", VARCHAR2(128)), + Column("constraint_type", VARCHAR2(1)), + Column("table_name", VARCHAR2(128)), + Column("search_condition", LONG), + Column("search_condition_vc", VARCHAR2(4000)), + Column("r_owner", VARCHAR2(128)), + Column("r_constraint_name", VARCHAR2(128)), + Column("delete_rule", VARCHAR2(9)), + Column("status", VARCHAR2(8)), + Column("deferrable", VARCHAR2(14)), + Column("deferred", VARCHAR2(9)), + Column("validated", VARCHAR2(13)), + Column("generated", VARCHAR2(14)), + Column("bad", VARCHAR2(3)), + Column("rely", VARCHAR2(4)), + Column("last_change", DATE), + Column("index_owner", VARCHAR2(128)), + Column("index_name", VARCHAR2(128)), + Column("invalid", VARCHAR2(7)), + Column("view_related", VARCHAR2(14)), + Column("origin_con_id", VARCHAR2(256)), +).alias("a_constraints") + +all_cons_columns = Table( + "all_cons_columns" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("constraint_name", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(4000)), + Column("position", NUMBER), +).alias("a_cons_columns") + +# TODO figure out if it's still relevant, since there is no mention from here +# https://docs.oracle.com/en/database/oracle/oracle-database/21/refrn/ALL_DB_LINKS.html +# original note: +# using user_db_links here since all_db_links appears +# to have more restricted permissions. +# https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm +# will need to hear from more users if we are doing +# the right thing here. 
See [ticket:2619] +all_db_links = Table( + "all_db_links" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("db_link", VARCHAR2(128), nullable=False), + Column("username", VARCHAR2(128)), + Column("host", VARCHAR2(2000)), + Column("created", DATE, nullable=False), + Column("hidden", VARCHAR2(3)), + Column("shard_internal", VARCHAR2(3)), + Column("valid", VARCHAR2(3)), + Column("intra_cdb", VARCHAR2(3)), +).alias("a_db_links") + +all_synonyms = Table( + "all_synonyms" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128)), + Column("synonym_name", VARCHAR2(128)), + Column("table_owner", VARCHAR2(128)), + Column("table_name", VARCHAR2(128)), + Column("db_link", VARCHAR2(128)), + Column("origin_con_id", VARCHAR2(256)), +).alias("a_synonyms") + +all_objects = Table( + "all_objects" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("object_name", VARCHAR2(128), nullable=False), + Column("subobject_name", VARCHAR2(128)), + Column("object_id", NUMBER, nullable=False), + Column("data_object_id", NUMBER), + Column("object_type", VARCHAR2(23)), + Column("created", DATE, nullable=False), + Column("last_ddl_time", DATE, nullable=False), + Column("timestamp", VARCHAR2(19)), + Column("status", VARCHAR2(7)), + Column("temporary", VARCHAR2(1)), + Column("generated", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("namespace", NUMBER, nullable=False), + Column("edition_name", VARCHAR2(128)), + Column("sharing", VARCHAR2(13)), + Column("editionable", VARCHAR2(1)), + Column("oracle_maintained", VARCHAR2(1)), + Column("application", VARCHAR2(1)), + Column("default_collation", VARCHAR2(100)), + Column("duplicated", VARCHAR2(1)), + Column("sharded", VARCHAR2(1)), + Column("created_appid", NUMBER), + Column("created_vsnid", NUMBER), + Column("modified_appid", NUMBER), + Column("modified_vsnid", NUMBER), +).alias("a_objects") diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py index cba3b5be47..75b7a7aa99 100644 --- a/lib/sqlalchemy/dialects/oracle/provision.py +++ b/lib/sqlalchemy/dialects/oracle/provision.py @@ -2,9 +2,12 @@ from ... import create_engine from ... import exc +from ... 
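# --- illustrative sketch, not part of the original patch ---
# The Table objects defined in dictionary.py above let the dialect compose
# its reflection queries with the Core expression language rather than
# hand-written SQL strings; the DB_LINK_PLACEHOLDER suffix in each table
# name is assumed to be swapped for the real "@dblink" (or stripped) by
# the dialect just before execution.  A query for the tables owned by a
# given schema might be composed like this:
from sqlalchemy import select
from sqlalchemy.dialects.oracle import dictionary

tables_query = (
    select(dictionary.all_tables.c.table_name)
    .where(dictionary.all_tables.c.owner == "SCOTT")
    .order_by(dictionary.all_tables.c.table_name)
)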
import inspect from ...engine import url as sa_url from ...testing.provision import configure_follower from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_post_tables +from ...testing.provision import drop_all_schema_objects_pre_tables from ...testing.provision import drop_db from ...testing.provision import follower_url_from_main from ...testing.provision import log @@ -28,6 +31,10 @@ def _oracle_create_db(cfg, eng, ident): conn.exec_driver_sql("grant unlimited tablespace to %s" % ident) conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident) conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident) + # these are needed to create materialized views + conn.exec_driver_sql("grant create table to %s" % ident) + conn.exec_driver_sql("grant create table to %s_ts1" % ident) + conn.exec_driver_sql("grant create table to %s_ts2" % ident) @configure_follower.for_db("oracle") @@ -46,6 +53,30 @@ def _ora_drop_ignore(conn, dbname): return False +@drop_all_schema_objects_pre_tables.for_db("oracle") +def _ora_drop_all_schema_objects_pre_tables(cfg, eng): + _purge_recyclebin(eng) + _purge_recyclebin(eng, cfg.test_schema) + + +@drop_all_schema_objects_post_tables.for_db("oracle") +def _ora_drop_all_schema_objects_post_tables(cfg, eng): + + with eng.begin() as conn: + for syn in conn.dialect._get_synonyms(conn, None, None, None): + conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}") + + for syn in conn.dialect._get_synonyms( + conn, cfg.test_schema, None, None + ): + conn.exec_driver_sql( + f"drop synonym {cfg.test_schema}.{syn['synonym_name']}" + ) + + for tmp_table in inspect(conn).get_temp_table_names(): + conn.exec_driver_sql(f"drop table {tmp_table}") + + @drop_db.for_db("oracle") def _oracle_drop_db(cfg, eng, ident): with eng.begin() as conn: @@ -60,13 +91,10 @@ def _oracle_drop_db(cfg, eng, ident): @stop_test_class_outside_fixtures.for_db("oracle") -def stop_test_class_outside_fixtures(config, db, cls): +def _ora_stop_test_class_outside_fixtures(config, db, cls): try: - with db.begin() as conn: - # run magic command to get rid of identity sequences - # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501 - conn.exec_driver_sql("purge recyclebin") + _purge_recyclebin(db) except exc.DatabaseError as err: log.warning("purge recyclebin command failed: %s", err) @@ -85,6 +113,22 @@ def stop_test_class_outside_fixtures(config, db, cls): _all_conns.clear() +def _purge_recyclebin(eng, schema=None): + with eng.begin() as conn: + if schema is None: + # run magic command to get rid of identity sequences + # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501 + conn.exec_driver_sql("purge recyclebin") + else: + # per user: https://community.oracle.com/tech/developers/discussion/2255402/how-to-clear-dba-recyclebin-for-a-particular-user # noqa: E501 + for owner, object_name, type_ in conn.exec_driver_sql( + "select owner, object_name,type from " + "dba_recyclebin where owner=:schema and type='TABLE'", + {"schema": conn.dialect.denormalize_name(schema)}, + ).all(): + conn.exec_driver_sql(f'purge {type_} {owner}."{object_name}"') + + _all_conns = set() diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py new file mode 100644 index 0000000000..60a8ebcb50 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -0,0 +1,233 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This 
module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ...sql import sqltypes +from ...types import NVARCHAR +from ...types import VARCHAR + + +class RAW(sqltypes._Binary): + __visit_name__ = "RAW" + + +OracleRaw = RAW + + +class NCLOB(sqltypes.Text): + __visit_name__ = "NCLOB" + + +class VARCHAR2(VARCHAR): + __visit_name__ = "VARCHAR2" + + +NVARCHAR2 = NVARCHAR + + +class NUMBER(sqltypes.Numeric, sqltypes.Integer): + __visit_name__ = "NUMBER" + + def __init__(self, precision=None, scale=None, asdecimal=None): + if asdecimal is None: + asdecimal = bool(scale and scale > 0) + + super(NUMBER, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal + ) + + def adapt(self, impltype): + ret = super(NUMBER, self).adapt(impltype) + # leave a hint for the DBAPI handler + ret._is_oracle_number = True + return ret + + @property + def _type_affinity(self): + if bool(self.scale and self.scale > 0): + return sqltypes.Numeric + else: + return sqltypes.Integer + + +class FLOAT(sqltypes.FLOAT): + """Oracle FLOAT. + + This is the same as :class:`_sqltypes.FLOAT` except that + an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision` + parameter is accepted, and + the :paramref:`_sqltypes.Float.precision` parameter is not accepted. + + Oracle FLOAT types indicate precision in terms of "binary precision", which + defaults to 126. For a REAL type, the value is 63. This parameter does not + cleanly map to a specific number of decimal places but is roughly + equivalent to the desired number of decimal places divided by 0.3103. + + .. versionadded:: 2.0 + + """ + + __visit_name__ = "FLOAT" + + def __init__( + self, + binary_precision=None, + asdecimal=False, + decimal_return_scale=None, + ): + r""" + Construct a FLOAT + + :param binary_precision: Oracle binary precision value to be rendered + in DDL. This may be approximated to the number of decimal characters + using the formula "decimal precision = 0.30103 * binary precision". + The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126. 
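# --- illustrative sketch, not part of the original patch ---
# Binary precision vs. decimal digits for the Oracle FLOAT type documented
# above: decimal digits ~= 0.30103 * binary precision, so a column meant
# to hold roughly 5 significant decimal digits can be declared with
# binary_precision=16 (5 / 0.30103 ~= 16.6).  Table and column names are
# hypothetical, and FLOAT is assumed to be re-exported by the oracle
# dialect package like the other types in this module.
from sqlalchemy import Column, MetaData, Table
from sqlalchemy.dialects import oracle

metadata = MetaData()
measurements = Table(
    "measurements",
    metadata,
    Column("value", oracle.FLOAT(binary_precision=16)),
)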
+ + :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal` + + :param decimal_return_scale: See + :paramref:`_sqltypes.Float.decimal_return_scale` + + """ + super().__init__( + asdecimal=asdecimal, decimal_return_scale=decimal_return_scale + ) + self.binary_precision = binary_precision + + +class BINARY_DOUBLE(sqltypes.Float): + __visit_name__ = "BINARY_DOUBLE" + + +class BINARY_FLOAT(sqltypes.Float): + __visit_name__ = "BINARY_FLOAT" + + +class BFILE(sqltypes.LargeBinary): + __visit_name__ = "BFILE" + + +class LONG(sqltypes.Text): + __visit_name__ = "LONG" + + +class _OracleDateLiteralRender: + def _literal_processor_datetime(self, dialect): + def process(value): + if value is not None: + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS.FF')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS')""" + ) + return value + + return process + + def _literal_processor_date(self, dialect): + def process(value): + if value is not None: + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + return value + + return process + + +class DATE(_OracleDateLiteralRender, sqltypes.DateTime): + """Provide the oracle DATE type. + + This type has no special Python behavior, except that it subclasses + :class:`_types.DateTime`; this is to suit the fact that the Oracle + ``DATE`` type supports a time value. + + .. versionadded:: 0.9.4 + + """ + + __visit_name__ = "DATE" + + def literal_processor(self, dialect): + return self._literal_processor_datetime(dialect) + + def _compare_type_affinity(self, other): + return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) + + +class _OracleDate(_OracleDateLiteralRender, sqltypes.Date): + def literal_processor(self, dialect): + return self._literal_processor_date(dialect) + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + __visit_name__ = "INTERVAL" + + def __init__(self, day_precision=None, second_precision=None): + """Construct an INTERVAL. + + Note that only DAY TO SECOND intervals are currently supported. + This is due to a lack of support for YEAR TO MONTH intervals + within available DBAPIs. + + :param day_precision: the day precision value. this is the number of + digits to store for the day field. Defaults to "2" + :param second_precision: the second precision value. this is the + number of digits to store for the fractional seconds field. + Defaults to "6". + + """ + self.day_precision = day_precision + self.second_precision = second_precision + + @classmethod + def _adapt_from_generic_interval(cls, interval): + return INTERVAL( + day_precision=interval.day_precision, + second_precision=interval.second_precision, + ) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval( + native=True, + second_precision=self.second_precision, + day_precision=self.day_precision, + ) + + +class ROWID(sqltypes.TypeEngine): + """Oracle ROWID type. + + When used in a cast() or similar, generates ROWID. 
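# --- illustrative sketch, not part of the original patch ---
# What the literal rendering above produces when a statement is compiled
# with literal_binds: a datetime carrying microseconds renders as
# TO_TIMESTAMP(...), otherwise as TO_DATE(...).  This assumes DATE remains
# importable from the dialect package; the printed output is approximate.
import datetime

from sqlalchemy import literal
from sqlalchemy.dialects.oracle import DATE
from sqlalchemy.dialects.oracle.base import OracleDialect

expr = literal(datetime.datetime(2021, 10, 14, 21, 45, 57), DATE)
print(
    expr.compile(
        dialect=OracleDialect(), compile_kwargs={"literal_binds": True}
    )
)
# roughly: TO_DATE('2021-10-14 21:45:57', 'YYYY-MM-DD HH24:MI:SS')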
+ + """ + + __visit_name__ = "ROWID" + + +class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index c2472fb559..85bbf8c5bb 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -19,31 +19,16 @@ from .array import Any from .array import ARRAY from .array import array from .base import BIGINT -from .base import BIT from .base import BOOLEAN -from .base import BYTEA from .base import CHAR -from .base import CIDR -from .base import CreateEnumType from .base import DATE from .base import DOUBLE_PRECISION -from .base import DropEnumType -from .base import ENUM from .base import FLOAT -from .base import INET from .base import INTEGER -from .base import INTERVAL -from .base import MACADDR -from .base import MONEY from .base import NUMERIC -from .base import OID from .base import REAL -from .base import REGCLASS from .base import SMALLINT from .base import TEXT -from .base import TIME -from .base import TIMESTAMP -from .base import TSVECTOR from .base import UUID from .base import VARCHAR from .dml import Insert @@ -61,7 +46,21 @@ from .ranges import INT8RANGE from .ranges import NUMRANGE from .ranges import TSRANGE from .ranges import TSTZRANGE -from ...util import compat +from .types import BIT +from .types import BYTEA +from .types import CIDR +from .types import CreateEnumType +from .types import DropEnumType +from .types import ENUM +from .types import INET +from .types import INTERVAL +from .types import MACADDR +from .types import MONEY +from .types import OID +from .types import REGCLASS +from .types import TIME +from .types import TIMESTAMP +from .types import TSVECTOR # Alias psycopg also as psycopg_async psycopg_async = type( diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py index e831f2ed9d..8dcd36c6d1 100644 --- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -1,3 +1,8 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors import decimal @@ -9,6 +14,9 @@ from .base import _INT_TYPES from .base import PGDialect from .base import PGExecutionContext from .hstore import HSTORE +from .pg_catalog import _SpaceVector +from .pg_catalog import INT2VECTOR +from .pg_catalog import OIDVECTOR from ... import exc from ... import types as sqltypes from ... 
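# --- illustrative sketch, not part of the original patch ---
# The reorganization above moves the PostgreSQL type classes out of
# base.py and into types.py, but the public import location is unchanged:
# the dialect package keeps re-exporting them, so existing imports such as
# the following continue to work.
from sqlalchemy.dialects.postgresql import (
    BYTEA,
    CIDR,
    ENUM,
    INTERVAL,
    TSVECTOR,
)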
import util @@ -66,6 +74,14 @@ class _PsycopgARRAY(PGARRAY): render_bind_cast = True +class _PsycopgINT2VECTOR(_SpaceVector, INT2VECTOR): + pass + + +class _PsycopgOIDVECTOR(_SpaceVector, OIDVECTOR): + pass + + class _PGExecutionContext_common_psycopg(PGExecutionContext): def create_server_side_cursor(self): # use server-side cursors: @@ -91,6 +107,8 @@ class _PGDialect_common_psycopg(PGDialect): sqltypes.Numeric: _PsycopgNumeric, HSTORE: _PsycopgHStore, sqltypes.ARRAY: _PsycopgARRAY, + INT2VECTOR: _PsycopgINT2VECTOR, + OIDVECTOR: _PsycopgOIDVECTOR, }, ) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 1ec787e1fc..d6385a5d61 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -274,6 +274,10 @@ class AsyncpgOID(OID): render_bind_cast = True +class AsyncpgCHAR(sqltypes.CHAR): + render_bind_cast = True + + class PGExecutionContext_asyncpg(PGExecutionContext): def handle_dbapi_exception(self, e): if isinstance( @@ -823,6 +827,7 @@ class PGDialect_asyncpg(PGDialect): sqltypes.Enum: AsyncPgEnum, OID: AsyncpgOID, REGCLASS: AsyncpgREGCLASS, + sqltypes.CHAR: AsyncpgCHAR, }, ) is_async = True diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 83e46151f0..36de76e0db 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -11,7 +11,7 @@ r""" :name: PostgreSQL :full_support: 9.6, 10, 11, 12, 13, 14 :normal_support: 9.6+ - :best_effort: 8+ + :best_effort: 9+ .. _postgresql_sequences: @@ -1448,23 +1448,52 @@ E.g.:: from __future__ import annotations from collections import defaultdict -import datetime as dt +from functools import lru_cache import re -from typing import Any from . import array as _array from . import dml from . import hstore as _hstore from . import json as _json +from . import pg_catalog from . import ranges as _ranges +from .types import _DECIMAL_TYPES # noqa +from .types import _FLOAT_TYPES # noqa +from .types import _INT_TYPES # noqa +from .types import BIT +from .types import BYTEA +from .types import CIDR +from .types import CreateEnumType # noqa +from .types import DropEnumType # noqa +from .types import ENUM +from .types import INET +from .types import INTERVAL +from .types import MACADDR +from .types import MONEY +from .types import OID +from .types import PGBit # noqa +from .types import PGCidr # noqa +from .types import PGInet # noqa +from .types import PGInterval # noqa +from .types import PGMacAddr # noqa +from .types import PGUuid +from .types import REGCLASS +from .types import TIME +from .types import TIMESTAMP +from .types import TSVECTOR from ... import exc from ... import schema +from ... import select from ... import sql from ... 
import util from ...engine import characteristics from ...engine import default from ...engine import interfaces +from ...engine import ObjectKind +from ...engine import ObjectScope from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import bindparam from ...sql import coercions from ...sql import compiler from ...sql import elements @@ -1472,7 +1501,7 @@ from ...sql import expression from ...sql import roles from ...sql import sqltypes from ...sql import util as sql_util -from ...sql.ddl import InvokeDDLBase +from ...sql.visitors import InternalTraversal from ...types import BIGINT from ...types import BOOLEAN from ...types import CHAR @@ -1596,469 +1625,6 @@ RESERVED_WORDS = set( ] ) -_DECIMAL_TYPES = (1231, 1700) -_FLOAT_TYPES = (700, 701, 1021, 1022) -_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) - - -class PGUuid(UUID): - render_bind_cast = True - render_literal_cast = True - - -class BYTEA(sqltypes.LargeBinary[bytes]): - __visit_name__ = "BYTEA" - - -class INET(sqltypes.TypeEngine[str]): - __visit_name__ = "INET" - - -PGInet = INET - - -class CIDR(sqltypes.TypeEngine[str]): - __visit_name__ = "CIDR" - - -PGCidr = CIDR - - -class MACADDR(sqltypes.TypeEngine[str]): - __visit_name__ = "MACADDR" - - -PGMacAddr = MACADDR - - -class MONEY(sqltypes.TypeEngine[str]): - - r"""Provide the PostgreSQL MONEY type. - - Depending on driver, result rows using this type may return a - string value which includes currency symbols. - - For this reason, it may be preferable to provide conversion to a - numerically-based currency datatype using :class:`_types.TypeDecorator`:: - - import re - import decimal - from sqlalchemy import TypeDecorator - - class NumericMoney(TypeDecorator): - impl = MONEY - - def process_result_value(self, value: Any, dialect: Any) -> None: - if value is not None: - # adjust this for the currency and numeric - m = re.match(r"\$([\d.]+)", value) - if m: - value = decimal.Decimal(m.group(1)) - return value - - Alternatively, the conversion may be applied as a CAST using - the :meth:`_types.TypeDecorator.column_expression` method as follows:: - - import decimal - from sqlalchemy import cast - from sqlalchemy import TypeDecorator - - class NumericMoney(TypeDecorator): - impl = MONEY - - def column_expression(self, column: Any): - return cast(column, Numeric()) - - .. versionadded:: 1.2 - - """ - - __visit_name__ = "MONEY" - - -class OID(sqltypes.TypeEngine[int]): - - """Provide the PostgreSQL OID type. - - .. versionadded:: 0.9.5 - - """ - - __visit_name__ = "OID" - - -class REGCLASS(sqltypes.TypeEngine[str]): - - """Provide the PostgreSQL REGCLASS type. - - .. versionadded:: 1.2.7 - - """ - - __visit_name__ = "REGCLASS" - - -class TIMESTAMP(sqltypes.TIMESTAMP): - def __init__(self, timezone=False, precision=None): - super(TIMESTAMP, self).__init__(timezone=timezone) - self.precision = precision - - -class TIME(sqltypes.TIME): - def __init__(self, timezone=False, precision=None): - super(TIME, self).__init__(timezone=timezone) - self.precision = precision - - -class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): - - """PostgreSQL INTERVAL type.""" - - __visit_name__ = "INTERVAL" - native = True - - def __init__(self, precision=None, fields=None): - """Construct an INTERVAL. - - :param precision: optional integer precision value - :param fields: string fields specifier. allows storage of fields - to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``, - etc. - - .. 
versionadded:: 1.2 - - """ - self.precision = precision - self.fields = fields - - @classmethod - def adapt_emulated_to_native(cls, interval, **kw): - return INTERVAL(precision=interval.second_precision) - - @property - def _type_affinity(self): - return sqltypes.Interval - - def as_generic(self, allow_nulltype=False): - return sqltypes.Interval(native=True, second_precision=self.precision) - - @property - def python_type(self): - return dt.timedelta - - -PGInterval = INTERVAL - - -class BIT(sqltypes.TypeEngine[int]): - __visit_name__ = "BIT" - - def __init__(self, length=None, varying=False): - if not varying: - # BIT without VARYING defaults to length 1 - self.length = length or 1 - else: - # but BIT VARYING can be unlimited-length, so no default - self.length = length - self.varying = varying - - -PGBit = BIT - - -class TSVECTOR(sqltypes.TypeEngine[Any]): - - """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL - text search type TSVECTOR. - - It can be used to do full text queries on natural language - documents. - - .. versionadded:: 0.9.0 - - .. seealso:: - - :ref:`postgresql_match` - - """ - - __visit_name__ = "TSVECTOR" - - -class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): - - """PostgreSQL ENUM type. - - This is a subclass of :class:`_types.Enum` which includes - support for PG's ``CREATE TYPE`` and ``DROP TYPE``. - - When the builtin type :class:`_types.Enum` is used and the - :paramref:`.Enum.native_enum` flag is left at its default of - True, the PostgreSQL backend will use a :class:`_postgresql.ENUM` - type as the implementation, so the special create/drop rules - will be used. - - The create/drop behavior of ENUM is necessarily intricate, due to the - awkward relationship the ENUM type has in relationship to the - parent table, in that it may be "owned" by just a single table, or - may be shared among many tables. - - When using :class:`_types.Enum` or :class:`_postgresql.ENUM` - in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted - corresponding to when the :meth:`_schema.Table.create` and - :meth:`_schema.Table.drop` - methods are called:: - - table = Table('sometable', metadata, - Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) - ) - - table.create(engine) # will emit CREATE ENUM and CREATE TABLE - table.drop(engine) # will emit DROP TABLE and DROP ENUM - - To use a common enumerated type between multiple tables, the best - practice is to declare the :class:`_types.Enum` or - :class:`_postgresql.ENUM` independently, and associate it with the - :class:`_schema.MetaData` object itself:: - - my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) - - t1 = Table('sometable_one', metadata, - Column('some_enum', myenum) - ) - - t2 = Table('sometable_two', metadata, - Column('some_enum', myenum) - ) - - When this pattern is used, care must still be taken at the level - of individual table creates. 
Emitting CREATE TABLE without also - specifying ``checkfirst=True`` will still cause issues:: - - t1.create(engine) # will fail: no such type 'myenum' - - If we specify ``checkfirst=True``, the individual table-level create - operation will check for the ``ENUM`` and create if not exists:: - - # will check if enum exists, and emit CREATE TYPE if not - t1.create(engine, checkfirst=True) - - When using a metadata-level ENUM type, the type will always be created - and dropped if either the metadata-wide create/drop is called:: - - metadata.create_all(engine) # will emit CREATE TYPE - metadata.drop_all(engine) # will emit DROP TYPE - - The type can also be created and dropped directly:: - - my_enum.create(engine) - my_enum.drop(engine) - - .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type - now behaves more strictly with regards to CREATE/DROP. A metadata-level - ENUM type will only be created and dropped at the metadata level, - not the table level, with the exception of - ``table.create(checkfirst=True)``. - The ``table.drop()`` call will now emit a DROP TYPE for a table-level - enumerated type. - - """ - - native_enum = True - - def __init__(self, *enums, **kw): - """Construct an :class:`_postgresql.ENUM`. - - Arguments are the same as that of - :class:`_types.Enum`, but also including - the following parameters. - - :param create_type: Defaults to True. - Indicates that ``CREATE TYPE`` should be - emitted, after optionally checking for the - presence of the type, when the parent - table is being created; and additionally - that ``DROP TYPE`` is called when the table - is dropped. When ``False``, no check - will be performed and no ``CREATE TYPE`` - or ``DROP TYPE`` is emitted, unless - :meth:`~.postgresql.ENUM.create` - or :meth:`~.postgresql.ENUM.drop` - are called directly. - Setting to ``False`` is helpful - when invoking a creation scheme to a SQL file - without access to the actual database - - the :meth:`~.postgresql.ENUM.create` and - :meth:`~.postgresql.ENUM.drop` methods can - be used to emit SQL to a target bind. - - """ - native_enum = kw.pop("native_enum", None) - if native_enum is False: - util.warn( - "the native_enum flag does not apply to the " - "sqlalchemy.dialects.postgresql.ENUM datatype; this type " - "always refers to ENUM. Use sqlalchemy.types.Enum for " - "non-native enum." - ) - self.create_type = kw.pop("create_type", True) - super(ENUM, self).__init__(*enums, **kw) - - @classmethod - def adapt_emulated_to_native(cls, impl, **kw): - """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain - :class:`.Enum`. - - """ - kw.setdefault("validate_strings", impl.validate_strings) - kw.setdefault("name", impl.name) - kw.setdefault("schema", impl.schema) - kw.setdefault("inherit_schema", impl.inherit_schema) - kw.setdefault("metadata", impl.metadata) - kw.setdefault("_create_events", False) - kw.setdefault("values_callable", impl.values_callable) - kw.setdefault("omit_aliases", impl._omit_aliases) - return cls(**kw) - - def create(self, bind=None, checkfirst=True): - """Emit ``CREATE TYPE`` for this - :class:`_postgresql.ENUM`. - - If the underlying dialect does not support - PostgreSQL CREATE TYPE, no action is taken. - - :param bind: a connectable :class:`_engine.Engine`, - :class:`_engine.Connection`, or similar object to emit - SQL. - :param checkfirst: if ``True``, a query against - the PG catalog will be first performed to see - if the type does not exist already before - creating. 
- - """ - if not bind.dialect.supports_native_enum: - return - - bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) - - def drop(self, bind=None, checkfirst=True): - """Emit ``DROP TYPE`` for this - :class:`_postgresql.ENUM`. - - If the underlying dialect does not support - PostgreSQL DROP TYPE, no action is taken. - - :param bind: a connectable :class:`_engine.Engine`, - :class:`_engine.Connection`, or similar object to emit - SQL. - :param checkfirst: if ``True``, a query against - the PG catalog will be first performed to see - if the type actually exists before dropping. - - """ - if not bind.dialect.supports_native_enum: - return - - bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) - - class EnumGenerator(InvokeDDLBase): - def __init__(self, dialect, connection, checkfirst=False, **kwargs): - super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - - def _can_create_enum(self, enum): - if not self.checkfirst: - return True - - effective_schema = self.connection.schema_for_object(enum) - - return not self.connection.dialect.has_type( - self.connection, enum.name, schema=effective_schema - ) - - def visit_enum(self, enum): - if not self._can_create_enum(enum): - return - - self.connection.execute(CreateEnumType(enum)) - - class EnumDropper(InvokeDDLBase): - def __init__(self, dialect, connection, checkfirst=False, **kwargs): - super(ENUM.EnumDropper, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - - def _can_drop_enum(self, enum): - if not self.checkfirst: - return True - - effective_schema = self.connection.schema_for_object(enum) - - return self.connection.dialect.has_type( - self.connection, enum.name, schema=effective_schema - ) - - def visit_enum(self, enum): - if not self._can_drop_enum(enum): - return - - self.connection.execute(DropEnumType(enum)) - - def get_dbapi_type(self, dbapi): - """dont return dbapi.STRING for ENUM in PostgreSQL, since that's - a different type""" - - return None - - def _check_for_name_in_memos(self, checkfirst, kw): - """Look in the 'ddl runner' for 'memos', then - note our name in that collection. - - This to ensure a particular named enum is operated - upon only once within any kind of create/drop - sequence without relying upon "checkfirst". 
- - """ - if not self.create_type: - return True - if "_ddl_runner" in kw: - ddl_runner = kw["_ddl_runner"] - if "_pg_enums" in ddl_runner.memo: - pg_enums = ddl_runner.memo["_pg_enums"] - else: - pg_enums = ddl_runner.memo["_pg_enums"] = set() - present = (self.schema, self.name) in pg_enums - pg_enums.add((self.schema, self.name)) - return present - else: - return False - - def _on_table_create(self, target, bind, checkfirst=False, **kw): - if ( - checkfirst - or ( - not self.metadata - and not kw.get("_is_metadata_operation", False) - ) - ) and not self._check_for_name_in_memos(checkfirst, kw): - self.create(bind=bind, checkfirst=checkfirst) - - def _on_table_drop(self, target, bind, checkfirst=False, **kw): - if ( - not self.metadata - and not kw.get("_is_metadata_operation", False) - and not self._check_for_name_in_memos(checkfirst, kw) - ): - self.drop(bind=bind, checkfirst=checkfirst) - - def _on_metadata_create(self, target, bind, checkfirst=False, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): - self.create(bind=bind, checkfirst=checkfirst) - - def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): - self.drop(bind=bind, checkfirst=checkfirst) - - colspecs = { sqltypes.ARRAY: _array.ARRAY, sqltypes.Interval: INTERVAL, @@ -2997,8 +2563,19 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): class PGInspector(reflection.Inspector): + dialect: PGDialect + def get_table_oid(self, table_name, schema=None): - """Return the OID for the given table name.""" + """Return the OID for the given table name. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ with self._operation_context() as conn: return self.dialect.get_table_oid( @@ -3023,9 +2600,10 @@ class PGInspector(reflection.Inspector): .. versionadded:: 1.0.0 """ - schema = schema or self.default_schema_name with self._operation_context() as conn: - return self.dialect._load_enums(conn, schema) + return self.dialect._load_enums( + conn, schema, info_cache=self.info_cache + ) def get_foreign_table_names(self, schema=None): """Return a list of FOREIGN TABLE names. @@ -3038,38 +2616,29 @@ class PGInspector(reflection.Inspector): .. versionadded:: 1.0.0 """ - schema = schema or self.default_schema_name with self._operation_context() as conn: - return self.dialect._get_foreign_table_names(conn, schema) - - def get_view_names(self, schema=None, include=("plain", "materialized")): - """Return all view names in `schema`. + return self.dialect._get_foreign_table_names( + conn, schema, info_cache=self.info_cache + ) - :param schema: Optional, retrieve names from a non-default schema. - For special quoting, use :class:`.quoted_name`. + def has_type(self, type_name, schema=None, **kw): + """Return if the database has the specified type in the provided + schema. - :param include: specify which types of views to return. Passed - as a string value (for a single type) or a tuple (for any number - of types). Defaults to ``('plain', 'materialized')``. + :param type_name: the type to check. + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to '*' to + check in all schemas. - .. versionadded:: 1.1 + .. 
versionadded:: 2.0 """ - with self._operation_context() as conn: - return self.dialect.get_view_names( - conn, schema, info_cache=self.info_cache, include=include + return self.dialect.has_type( + conn, type_name, schema, info_cache=self.info_cache ) -class CreateEnumType(schema._CreateDropBase): - __visit_name__ = "create_enum_type" - - -class DropEnumType(schema._CreateDropBase): - __visit_name__ = "drop_enum_type" - - class PGExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar( @@ -3262,35 +2831,14 @@ class PGDialect(default.DefaultDialect): def initialize(self, connection): super(PGDialect, self).initialize(connection) - if self.server_version_info <= (8, 2): - self.delete_returning = ( - self.update_returning - ) = self.insert_returning = False - - self.supports_native_enum = self.server_version_info >= (8, 3) - if not self.supports_native_enum: - self.colspecs = self.colspecs.copy() - # pop base Enum type - self.colspecs.pop(sqltypes.Enum, None) - # psycopg2, others may have placed ENUM here as well - self.colspecs.pop(ENUM, None) - # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 self.supports_smallserial = self.server_version_info >= (9, 2) - if self.server_version_info < (8, 2): - self._backslash_escapes = False - else: - # ensure this query is not emitted on server version < 8.2 - # as it will fail - std_string = connection.exec_driver_sql( - "show standard_conforming_strings" - ).scalar() - self._backslash_escapes = std_string == "off" - - self._supports_create_index_concurrently = ( - self.server_version_info >= (8, 2) - ) + std_string = connection.exec_driver_sql( + "show standard_conforming_strings" + ).scalar() + self._backslash_escapes = std_string == "off" + self._supports_drop_index_concurrently = self.server_version_info >= ( 9, 2, @@ -3370,122 +2918,100 @@ class PGDialect(default.DefaultDialect): self.do_commit(connection.connection) def do_recover_twophase(self, connection): - resultset = connection.execute( + return connection.scalars( sql.text("SELECT gid FROM pg_prepared_xacts") - ) - return [row[0] for row in resultset] + ).all() def _get_default_schema_name(self, connection): return connection.exec_driver_sql("select current_schema()").scalar() - def has_schema(self, connection, schema): - query = ( - "select nspname from pg_namespace " "where lower(nspname)=:schema" - ) - cursor = connection.execute( - sql.text(query).bindparams( - sql.bindparam( - "schema", - str(schema.lower()), - type_=sqltypes.Unicode, - ) - ) + @reflection.cache + def has_schema(self, connection, schema, **kw): + query = select(pg_catalog.pg_namespace.c.nspname).where( + pg_catalog.pg_namespace.c.nspname == schema ) + return bool(connection.scalar(query)) - return bool(cursor.first()) - - def has_table(self, connection, table_name, schema=None): - self._ensure_has_table_connection(connection) - # seems like case gets folded in pg_class... 
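# --- illustrative usage sketch (not part of the patch) ----------------------
# A minimal example of how the ``has_type()`` inspection method introduced by
# this patch might be called from user code.  The engine URL, the "mood" type
# name and the "reporting" schema are assumptions made only for this example.
from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
insp = inspect(engine)

# check the schemas visible on the search path (pg_catalog is excluded)
print(insp.has_type("mood"))
# check one specific schema, or every schema with the "*" wildcard
print(insp.has_type("mood", schema="reporting"))
print(insp.has_type("mood", schema="*"))
# -----------------------------------------------------------------------------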
+ def _pg_class_filter_scope_schema( + self, query, schema, scope, pg_class_table=None + ): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + query = query.join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace, + ) + if scope is ObjectScope.DEFAULT: + query = query.where(pg_class_table.c.relpersistence != "t") + elif scope is ObjectScope.TEMPORARY: + query = query.where(pg_class_table.c.relpersistence == "t") if schema is None: - cursor = connection.execute( - sql.text( - "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where " - "pg_catalog.pg_table_is_visible(c.oid) " - "and relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(table_name), - type_=sqltypes.Unicode, - ) - ) + query = query.where( + pg_catalog.pg_table_is_visible(pg_class_table.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", ) else: - cursor = connection.execute( - sql.text( - "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=:schema and " - "relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(table_name), - type_=sqltypes.Unicode, - ), - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) - ) - return bool(cursor.first()) - - def has_sequence(self, connection, sequence_name, schema=None): - if schema is None: - schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=:schema and relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(sequence_name), - type_=sqltypes.Unicode, - ), - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + def _pg_class_relkind_condition(self, relkinds, pg_class_table=None): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + # uses the any form instead of in otherwise postgresql complaings + # that 'IN could not convert type character to "char"' + return pg_class_table.c.relkind == sql.any_(_array.array(relkinds)) + + @lru_cache() + def _has_table_query(self, schema): + query = select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relname == bindparam("table_name"), + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), + ) + return self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY ) - return bool(cursor.first()) + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + self._ensure_has_table_connection(connection) + query = self._has_table_query(schema) + return bool(connection.scalar(query, {"table_name": table_name})) - def has_type(self, connection, type_name, schema=None): - if schema is not None: - query = """ - SELECT EXISTS ( - SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n - WHERE t.typnamespace = n.oid - AND t.typname = :typname - AND n.nspname = :nspname - ) - """ - query = sql.text(query) - else: - query = """ - SELECT EXISTS ( - SELECT * FROM pg_catalog.pg_type t - WHERE t.typname = :typname - AND pg_type_is_visible(t.oid) - ) - """ - query = sql.text(query) - query = query.bindparams( - sql.bindparam("typname", str(type_name), type_=sqltypes.Unicode) + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): + query = 
select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relkind == "S", + pg_catalog.pg_class.c.relname == sequence_name, ) - if schema is not None: - query = query.bindparams( - sql.bindparam("nspname", str(schema), type_=sqltypes.Unicode) + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + return bool(connection.scalar(query)) + + @reflection.cache + def has_type(self, connection, type_name, schema=None, **kw): + query = ( + select(pg_catalog.pg_type.c.typname) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, ) - cursor = connection.execute(query) - return bool(cursor.scalar()) + .where(pg_catalog.pg_type.c.typname == type_name) + ) + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + + return bool(connection.scalar(query)) def _get_server_version_info(self, connection): v = connection.exec_driver_sql("select pg_catalog.version()").scalar() @@ -3502,229 +3028,300 @@ class PGDialect(default.DefaultDialect): @reflection.cache def get_table_oid(self, connection, table_name, schema=None, **kw): - """Fetch the oid for schema.table_name. - - Several reflection methods require the table oid. The idea for using - this method is that it can be fetched one time and cached for - subsequent calls. - - """ - table_oid = None - if schema is not None: - schema_where_clause = "n.nspname = :schema" - else: - schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - query = ( - """ - SELECT c.oid - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE (%s) - AND c.relname = :table_name AND c.relkind in - ('r', 'v', 'm', 'f', 'p') - """ - % schema_where_clause + """Fetch the oid for schema.table_name.""" + query = select(pg_catalog.pg_class.c.oid).where( + pg_catalog.pg_class.c.relname == table_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), ) - # Since we're binding to unicode, table_name and schema_name must be - # unicode. 
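# --- illustrative sketch (not part of the patch): the statement-caching
# pattern used by the new reflection queries.  The structural part of a query
# (for example which schema filter applies) is baked into a SELECT built once
# per distinct argument combination via functools.lru_cache, while per-call
# values such as the table name remain bindparams, so the cached statement can
# be re-executed with new parameters.  The table and column names below are
# invented for the example; they are not the pg_catalog definitions the
# dialect actually uses.
from functools import lru_cache

from sqlalchemy import Column, MetaData, Table, Text, bindparam, select

sketch_metadata = MetaData()
sketch_pg_class = Table(
    "pg_class",
    sketch_metadata,
    Column("relname", Text),
    Column("relkind", Text),
    schema="pg_catalog",
)

@lru_cache()
def _sketch_has_table_query(relkind):
    # one cached SELECT per relkind; the table name stays a runtime parameter
    return select(sketch_pg_class.c.relname).where(
        sketch_pg_class.c.relkind == relkind,
        sketch_pg_class.c.relname == bindparam("table_name"),
    )

# connection.scalar(_sketch_has_table_query("r"), {"table_name": "my_table"})
# -----------------------------------------------------------------------------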
- table_name = str(table_name) - if schema is not None: - schema = str(schema) - s = sql.text(query).bindparams(table_name=sqltypes.Unicode) - s = s.columns(oid=sqltypes.Integer) - if schema: - s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode)) - c = connection.execute(s, dict(table_name=table_name, schema=schema)) - table_oid = c.scalar() + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + table_oid = connection.scalar(query) if table_oid is None: - raise exc.NoSuchTableError(table_name) + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) return table_oid @reflection.cache def get_schema_names(self, connection, **kw): - result = connection.execute( - sql.text( - "SELECT nspname FROM pg_namespace " - "WHERE nspname NOT LIKE 'pg_%' " - "ORDER BY nspname" - ).columns(nspname=sqltypes.Unicode) + query = ( + select(pg_catalog.pg_namespace.c.nspname) + .where(pg_catalog.pg_namespace.c.nspname.not_like("pg_%")) + .order_by(pg_catalog.pg_namespace.c.nspname) + ) + return connection.scalars(query).all() + + def _get_relnames_for_relkinds(self, connection, schema, relkinds, scope): + query = select(pg_catalog.pg_class.c.relname).where( + self._pg_class_relkind_condition(relkinds) ) - return [name for name, in result] + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + return connection.scalars(query).all() @reflection.cache def get_table_names(self, connection, schema=None, **kw): - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.DEFAULT, + ) + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + return self._get_relnames_for_relkinds( + connection, + schema=None, + relkinds=pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.TEMPORARY, ) - return [name for name, in result] @reflection.cache def _get_foreign_table_names(self, connection, schema=None, **kw): - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind = 'f'" - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("f",), scope=ObjectScope.ANY ) - return [name for name, in result] @reflection.cache - def get_view_names( - self, connection, schema=None, include=("plain", "materialized"), **kw - ): + def get_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.DEFAULT, + ) - include_kind = {"plain": "v", "materialized": "m"} - try: - kinds = [include_kind[i] for i in util.to_list(include)] - except KeyError: - raise ValueError( - "include %r unknown, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'" % (include,) - ) - if not kinds: - raise ValueError( - "empty include, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'" - ) + @reflection.cache + def get_materialized_view_names(self, connection, 
schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_MAT_VIEW, + scope=ObjectScope.DEFAULT, + ) - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind IN (%s)" - % (", ".join("'%s'" % elem for elem in kinds)) - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + @reflection.cache + def get_temp_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + # NOTE: do not include temp materialzied views (that do not + # seem to be a thing at least up to version 14) + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.TEMPORARY, ) - return [name for name, in result] @reflection.cache def get_sequence_names(self, connection, schema=None, **kw): - if not schema: - schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=:schema" - ).bindparams( - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("S",), scope=ObjectScope.ANY ) - return [row[0] for row in cursor] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): - view_def = connection.scalar( - sql.text( - "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relname = :view_name " - "AND c.relkind IN ('v', 'm')" - ).columns(view_def=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name, - view_name=view_name, - ), + query = ( + select(pg_catalog.pg_get_viewdef(pg_catalog.pg_class.c.oid)) + .select_from(pg_catalog.pg_class) + .where( + pg_catalog.pg_class.c.relname == view_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_VIEW + pg_catalog.RELKINDS_MAT_VIEW + ), + ) ) - return view_def + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + res = connection.scalar(query) + if res is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) + else: + return res + + def _value_or_raise(self, data, table, schema): + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + return True, {"filter_names": filter_names} + else: + return False, {} + + def _kind_to_relkinds(self, kind: ObjectKind) -> tuple[str, ...]: + if kind is ObjectKind.ANY: + return pg_catalog.RELKINDS_ALL_TABLE_LIKE + relkinds = () + if ObjectKind.TABLE in kind: + relkinds += pg_catalog.RELKINDS_TABLE + if ObjectKind.VIEW in kind: + relkinds += pg_catalog.RELKINDS_VIEW + if ObjectKind.MATERIALIZED_VIEW in kind: + relkinds += pg_catalog.RELKINDS_MAT_VIEW + return relkinds @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, 
table_name, schema) + @lru_cache() + def _columns_query(self, schema, has_filter_names, scope, kind): + # NOTE: the query with the default and identity options scalar + # subquery is faster than trying to use outer joins for them generated = ( - "a.attgenerated as generated" + pg_catalog.pg_attribute.c.attgenerated.label("generated") if self.server_version_info >= (12,) - else "NULL as generated" + else sql.null().label("generated") ) if self.server_version_info >= (10,): - # a.attidentity != '' is required or it will reflect also - # serial columns as identity. - identity = """\ - (SELECT json_build_object( - 'always', a.attidentity = 'a', - 'start', s.seqstart, - 'increment', s.seqincrement, - 'minvalue', s.seqmin, - 'maxvalue', s.seqmax, - 'cache', s.seqcache, - 'cycle', s.seqcycle) - FROM pg_catalog.pg_sequence s - JOIN pg_catalog.pg_class c on s.seqrelid = c."oid" - WHERE c.relkind = 'S' - AND a.attidentity != '' - AND s.seqrelid = pg_catalog.pg_get_serial_sequence( - a.attrelid::regclass::text, a.attname - )::regclass::oid - ) as identity_options\ - """ + # join lateral performs worse (~2x slower) than a scalar_subquery + identity = ( + select( + sql.func.json_build_object( + "always", + pg_catalog.pg_attribute.c.attidentity == "a", + "start", + pg_catalog.pg_sequence.c.seqstart, + "increment", + pg_catalog.pg_sequence.c.seqincrement, + "minvalue", + pg_catalog.pg_sequence.c.seqmin, + "maxvalue", + pg_catalog.pg_sequence.c.seqmax, + "cache", + pg_catalog.pg_sequence.c.seqcache, + "cycle", + pg_catalog.pg_sequence.c.seqcycle, + ) + ) + .select_from(pg_catalog.pg_sequence) + .where( + # attidentity != '' is required or it will reflect also + # serial columns as identity. + pg_catalog.pg_attribute.c.attidentity != "", + pg_catalog.pg_sequence.c.seqrelid + == sql.cast( + sql.cast( + pg_catalog.pg_get_serial_sequence( + sql.cast( + sql.cast( + pg_catalog.pg_attribute.c.attrelid, + REGCLASS, + ), + TEXT, + ), + pg_catalog.pg_attribute.c.attname, + ), + REGCLASS, + ), + OID, + ), + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("identity_options") + ) else: - identity = "NULL as identity_options" - - SQL_COLS = """ - SELECT a.attname, - pg_catalog.format_type(a.atttypid, a.atttypmod), - ( - SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid) - FROM pg_catalog.pg_attrdef d - WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum - AND a.atthasdef - ) AS DEFAULT, - a.attnotnull, - a.attrelid as table_oid, - pgd.description as comment, - %s, - %s - FROM pg_catalog.pg_attribute a - LEFT JOIN pg_catalog.pg_description pgd ON ( - pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum) - WHERE a.attrelid = :table_oid - AND a.attnum > 0 AND NOT a.attisdropped - ORDER BY a.attnum - """ % ( - generated, - identity, + identity = sql.null().label("identity_options") + + # join lateral performs the same as scalar_subquery here + default = ( + select( + pg_catalog.pg_get_expr( + pg_catalog.pg_attrdef.c.adbin, + pg_catalog.pg_attrdef.c.adrelid, + ) + ) + .select_from(pg_catalog.pg_attrdef) + .where( + pg_catalog.pg_attrdef.c.adrelid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attrdef.c.adnum + == pg_catalog.pg_attribute.c.attnum, + pg_catalog.pg_attribute.c.atthasdef, + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("default") ) - s = ( - sql.text(SQL_COLS) - .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer)) - .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode) + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + 
pg_catalog.pg_attribute.c.attname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_attribute.c.atttypid, + pg_catalog.pg_attribute.c.atttypmod, + ).label("format_type"), + default, + pg_catalog.pg_attribute.c.attnotnull.label("not_null"), + pg_catalog.pg_class.c.relname.label("table_name"), + pg_catalog.pg_description.c.description.label("comment"), + generated, + identity, + ) + .select_from(pg_catalog.pg_class) + # NOTE: postgresql support table with no user column, meaning + # there is no row with pg_attribute.attnum > 0. use a left outer + # join to avoid filtering these tables. + .outerjoin( + pg_catalog.pg_attribute, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attribute.c.attnum > 0, + ~pg_catalog.pg_attribute.c.attisdropped, + ), + ) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_description.c.objsubid + == pg_catalog.pg_attribute.c.attnum, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + .order_by( + pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum + ) ) - c = connection.execute(s, dict(table_oid=table_oid)) - rows = c.fetchall() + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + def get_multi_columns( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._columns_query(schema, has_filter_names, scope, kind) + rows = connection.execute(query, params).mappings() # dictionary with (name, ) if default search path or (schema, name) # as keys - domains = self._load_domains(connection) + domains = self._load_domains( + connection, info_cache=kw.get("info_cache") + ) # dictionary with (name, ) if default search path or (schema, name) # as keys @@ -3732,257 +3329,340 @@ class PGDialect(default.DefaultDialect): ((rec["name"],), rec) if rec["visible"] else ((rec["schema"], rec["name"]), rec) - for rec in self._load_enums(connection, schema="*") + for rec in self._load_enums( + connection, schema="*", info_cache=kw.get("info_cache") + ) ) - # format columns - columns = [] - - for ( - name, - format_type, - default_, - notnull, - table_oid, - comment, - generated, - identity, - ) in rows: - column_info = self._get_column_info( - name, - format_type, - default_, - notnull, - domains, - enums, - schema, - comment, - generated, - identity, - ) - columns.append(column_info) - return columns + columns = self._get_columns_info(rows, domains, enums, schema) + + return columns.items() + + def _get_columns_info(self, rows, domains, enums, schema): + array_type_pattern = re.compile(r"\[\]$") + attype_pattern = re.compile(r"\(.*\)") + charlen_pattern = re.compile(r"\(([\d,]+)\)") + args_pattern = re.compile(r"\((.*)\)") + args_split_pattern = re.compile(r"\s*,\s*") - def _get_column_info( - self, - name, - format_type, - default, - notnull, - domains, - enums, - schema, - comment, - generated, - identity, - ): def _handle_array_type(attype): return ( # strip '[]' from integer[], etc. - re.sub(r"\[\]$", "", attype), + array_type_pattern.sub("", attype), attype.endswith("[]"), ) - # strip (*) from character varying(5), timestamp(5) - # with time zone, geometry(POLYGON), etc. 
- attype = re.sub(r"\(.*\)", "", format_type) + columns = defaultdict(list) + for row_dict in rows: + # ensure that each table has an entry, even if it has no columns + if row_dict["name"] is None: + columns[ + (schema, row_dict["table_name"]) + ] = ReflectionDefaults.columns() + continue + table_cols = columns[(schema, row_dict["table_name"])] - # strip '[]' from integer[], etc. and check if an array - attype, is_array = _handle_array_type(attype) + format_type = row_dict["format_type"] + default = row_dict["default"] + name = row_dict["name"] + generated = row_dict["generated"] + identity = row_dict["identity_options"] - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + # strip (*) from character varying(5), timestamp(5) + # with time zone, geometry(POLYGON), etc. + attype = attype_pattern.sub("", format_type) - nullable = not notnull + # strip '[]' from integer[], etc. and check if an array + attype, is_array = _handle_array_type(attype) - charlen = re.search(r"\(([\d,]+)\)", format_type) - if charlen: - charlen = charlen.group(1) - args = re.search(r"\((.*)\)", format_type) - if args and args.group(1): - args = tuple(re.split(r"\s*,\s*", args.group(1))) - else: - args = () - kwargs = {} + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + + nullable = not row_dict["not_null"] - if attype == "numeric": + charlen = charlen_pattern.search(format_type) if charlen: - prec, scale = charlen.split(",") - args = (int(prec), int(scale)) + charlen = charlen.group(1) + args = args_pattern.search(format_type) + if args and args.group(1): + args = tuple(args_split_pattern.split(args.group(1))) else: args = () - elif attype == "double precision": - args = (53,) - elif attype == "integer": - args = () - elif attype in ("timestamp with time zone", "time with time zone"): - kwargs["timezone"] = True - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype in ( - "timestamp without time zone", - "time without time zone", - "time", - ): - kwargs["timezone"] = False - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype == "bit varying": - kwargs["varying"] = True - if charlen: + kwargs = {} + + if attype == "numeric": + if charlen: + prec, scale = charlen.split(",") + args = (int(prec), int(scale)) + else: + args = () + elif attype == "double precision": + args = (53,) + elif attype == "integer": + args = () + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype == "bit varying": + kwargs["varying"] = True + if charlen: + args = (int(charlen),) + else: + args = () + elif attype.startswith("interval"): + field_match = re.match(r"interval (.+)", attype, re.I) + if charlen: + kwargs["precision"] = int(charlen) + if field_match: + kwargs["fields"] = field_match.group(1) + attype = "interval" + args = () + elif charlen: args = (int(charlen),) + + while True: + # looping here to suit nested domains + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + break + elif enum_or_domain_key in enums: + enum = enums[enum_or_domain_key] + coltype = ENUM + kwargs["name"] = enum["name"] + if not enum["visible"]: + 
kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) + break + elif enum_or_domain_key in domains: + domain = domains[enum_or_domain_key] + attype = domain["attype"] + attype, is_array = _handle_array_type(attype) + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple( + util.quoted_token_parser(attype) + ) + # A table can't override a not null on the domain, + # but can override nullable + nullable = nullable and domain["nullable"] + if domain["default"] and not default: + # It can, however, override the default + # value, but can't set it to null. + default = domain["default"] + continue + else: + coltype = None + break + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = self.ischema_names["_array"](coltype) else: - args = () - elif attype.startswith("interval"): - field_match = re.match(r"interval (.+)", attype, re.I) - if charlen: - kwargs["precision"] = int(charlen) - if field_match: - kwargs["fields"] = field_match.group(1) - attype = "interval" - args = () - elif charlen: - args = (int(charlen),) - - while True: - # looping here to suit nested domains - if attype in self.ischema_names: - coltype = self.ischema_names[attype] - break - elif enum_or_domain_key in enums: - enum = enums[enum_or_domain_key] - coltype = ENUM - kwargs["name"] = enum["name"] - if not enum["visible"]: - kwargs["schema"] = enum["schema"] - args = tuple(enum["labels"]) - break - elif enum_or_domain_key in domains: - domain = domains[enum_or_domain_key] - attype = domain["attype"] - attype, is_array = _handle_array_type(attype) - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) - # A table can't override a not null on the domain, - # but can override nullable - nullable = nullable and domain["nullable"] - if domain["default"] and not default: - # It can, however, override the default - # value, but can't set it to null. - default = domain["default"] - continue + util.warn( + "Did not recognize type '%s' of column '%s'" + % (attype, name) + ) + coltype = sqltypes.NULLTYPE + + # If a zero byte or blank string depending on driver (is also + # absent for older PG versions), then not a generated column. + # Otherwise, s = stored. (Other values might be added in the + # future.) + if generated not in (None, "", b"\x00"): + computed = dict( + sqltext=default, persisted=generated in ("s", b"s") + ) + default = None else: - coltype = None - break + computed = None - if coltype: - coltype = coltype(*args, **kwargs) - if is_array: - coltype = self.ischema_names["_array"](coltype) - else: - util.warn( - "Did not recognize type '%s' of column '%s'" % (attype, name) + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + if issubclass(coltype._type_affinity, sqltypes.Integer): + autoincrement = True + # the default is related to a Sequence + if "." not in match.group(2) and schema is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / + # "quote schema" + default = ( + match.group(1) + + ('"%s"' % schema) + + "." 
+ + match.group(2) + + match.group(3) + ) + + column_info = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": autoincrement or identity is not None, + "comment": row_dict["comment"], + } + if computed is not None: + column_info["computed"] = computed + if identity is not None: + column_info["identity"] = identity + + table_cols.append(column_info) + + return columns + + @lru_cache() + def _table_oids_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + oid_q = select( + pg_catalog.pg_class.c.oid, pg_catalog.pg_class.c.relname + ).where(self._pg_class_relkind_condition(relkinds)) + oid_q = self._pg_class_filter_scope_schema(oid_q, schema, scope=scope) + + if has_filter_names: + oid_q = oid_q.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) ) - coltype = sqltypes.NULLTYPE - - # If a zero byte or blank string depending on driver (is also absent - # for older PG versions), then not a generated column. Otherwise, s = - # stored. (Other values might be added in the future.) - if generated not in (None, "", b"\x00"): - computed = dict( - sqltext=default, persisted=generated in ("s", b"s") + return oid_q + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("kind", InternalTraversal.dp_plain_obj), + ("scope", InternalTraversal.dp_plain_obj), + ) + def _get_table_oids( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + oid_q = self._table_oids_query(schema, has_filter_names, scope, kind) + result = connection.execute(oid_q, params) + return result.all() + + @util.memoized_property + def _constraint_query(self): + con_sq = ( + select( + pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.conname, + sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label( + "attnum" + ), + sql.func.generate_subscripts( + pg_catalog.pg_constraint.c.conkey, 1 + ).label("ord"), ) - default = None - else: - computed = None - - # adjust the default value - autoincrement = False - if default is not None: - match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) - if match is not None: - if issubclass(coltype._type_affinity, sqltypes.Integer): - autoincrement = True - # the default is related to a Sequence - sch = schema - if "." not in match.group(2) and sch is not None: - # unconditionally quote the schema name. this could - # later be enhanced to obey quoting rules / - # "quote schema" - default = ( - match.group(1) - + ('"%s"' % sch) - + "." 
- + match.group(2) - + match.group(3) - ) + .where( + pg_catalog.pg_constraint.c.contype == bindparam("contype"), + pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")), + ) + .subquery("con") + ) - column_info = dict( - name=name, - type=coltype, - nullable=nullable, - default=default, - autoincrement=autoincrement or identity is not None, - comment=comment, + attr_sq = ( + select( + con_sq.c.conrelid, + con_sq.c.conname, + pg_catalog.pg_attribute.c.attname, + ) + .select_from(pg_catalog.pg_attribute) + .join( + con_sq, + sql.and_( + pg_catalog.pg_attribute.c.attnum == con_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid, + ), + ) + .order_by(con_sq.c.conname, con_sq.c.ord) + .subquery("attr") ) - if computed is not None: - column_info["computed"] = computed - if identity is not None: - column_info["identity"] = identity - return column_info - @reflection.cache - def get_pk_constraint(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + return ( + select( + attr_sq.c.conrelid, + sql.func.array_agg(attr_sq.c.attname).label("cols"), + attr_sq.c.conname, + ) + .group_by(attr_sq.c.conrelid, attr_sq.c.conname) + .order_by(attr_sq.c.conrelid, attr_sq.c.conname) ) - if self.server_version_info < (8, 4): - PK_SQL = """ - SELECT a.attname - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_attribute a - on t.oid=a.attrelid AND %s - WHERE - t.oid = :table_oid and ix.indisprimary = 't' - ORDER BY a.attnum - """ % self._pg_index_any( - "a.attnum", "ix.indkey" + def _reflect_constraint( + self, connection, contype, schema, filter_names, scope, kind, **kw + ): + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) + batches = list(table_oids) + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._constraint_query, + {"oids": [r[0] for r in batch], "contype": contype}, ) - else: - # unnest() and generate_subscripts() both introduced in - # version 8.4 - PK_SQL = """ - SELECT a.attname - FROM pg_attribute a JOIN ( - SELECT unnest(ix.indkey) attnum, - generate_subscripts(ix.indkey, 1) ord - FROM pg_index ix - WHERE ix.indrelid = :table_oid AND ix.indisprimary - ) k ON a.attnum=k.attnum - WHERE a.attrelid = :table_oid - ORDER BY k.ord - """ - t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) - cols = [r[0] for r in c.fetchall()] - - PK_CONS_SQL = """ - SELECT conname - FROM pg_catalog.pg_constraint r - WHERE r.conrelid = :table_oid AND r.contype = 'p' - ORDER BY 1 - """ - t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) - name = c.scalar() + result_by_oid = defaultdict(list) + for oid, cols, constraint_name in result: + result_by_oid[oid].append((cols, constraint_name)) + + for oid, tablename in batch: + for_oid = result_by_oid.get(oid, ()) + if for_oid: + for cols, constraint in for_oid: + yield tablename, cols, constraint + else: + yield tablename, None, None + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + data = self.get_multi_pk_constraint( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def get_multi_pk_constraint( + self, connection, schema, filter_names, scope, 
kind, **kw + ): + result = self._reflect_constraint( + connection, "p", schema, filter_names, scope, kind, **kw + ) - return {"constrained_columns": cols, "name": name} + # only a single pk can be present for each table. Return an entry + # even if a table has no primary key + default = ReflectionDefaults.pk_constraint + return ( + ( + (schema, table_name), + { + "constrained_columns": [] if cols is None else cols, + "name": pk_name, + } + if pk_name is not None + else default(), + ) + for (table_name, cols, pk_name) in result + ) @reflection.cache def get_foreign_keys( @@ -3993,27 +3673,71 @@ class PGDialect(default.DefaultDialect): postgresql_ignore_search_path=False, **kw, ): - preparer = self.identifier_preparer - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_foreign_keys( + connection, + schema=schema, + filter_names=[table_name], + postgresql_ignore_search_path=postgresql_ignore_search_path, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - FK_SQL = """ - SELECT r.conname, - pg_catalog.pg_get_constraintdef(r.oid, true) as condef, - n.nspname as conschema - FROM pg_catalog.pg_constraint r, - pg_namespace n, - pg_class c - - WHERE r.conrelid = :table AND - r.contype = 'f' AND - c.oid = confrelid AND - n.oid = c.relnamespace - ORDER BY 1 - """ - # https://www.postgresql.org/docs/9.0/static/sql-createtable.html - FK_REGEX = re.compile( + @lru_cache() + def _foreing_key_query(self, schema, has_filter_names, scope, kind): + pg_class_ref = pg_catalog.pg_class.alias("cls_ref") + pg_namespace_ref = pg_catalog.pg_namespace.alias("nsp_ref") + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ), + ), + else_=None, + ), + pg_namespace_ref.c.nspname, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "f", + ), + ) + .outerjoin( + pg_class_ref, + pg_class_ref.c.oid == pg_catalog.pg_constraint.c.confrelid, + ) + .outerjoin( + pg_namespace_ref, + pg_class_ref.c.relnamespace == pg_namespace_ref.c.oid, + ) + .order_by( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + @util.memoized_property + def _fk_regex_pattern(self): + # https://www.postgresql.org/docs/14.0/static/sql-createtable.html + return re.compile( r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)" r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?" r"[\s]?(ON UPDATE " @@ -4024,12 +3748,33 @@ class PGDialect(default.DefaultDialect): r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" 
) - t = sql.text(FK_SQL).columns( - conname=sqltypes.Unicode, condef=sqltypes.Unicode - ) - c = connection.execute(t, dict(table=table_oid)) - fkeys = [] - for conname, condef, conschema in c.fetchall(): + def get_multi_foreign_keys( + self, + connection, + schema, + filter_names, + scope, + kind, + postgresql_ignore_search_path=False, + **kw, + ): + preparer = self.identifier_preparer + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._foreing_key_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + FK_REGEX = self._fk_regex_pattern + + fkeys = defaultdict(list) + default = ReflectionDefaults.foreign_keys + for table_name, conname, condef, conschema in result: + # ensure that each table has an entry, even if it has + # no foreign keys + if conname is None: + fkeys[(schema, table_name)] = default() + continue + table_fks = fkeys[(schema, table_name)] m = re.search(FK_REGEX, condef).groups() ( @@ -4096,317 +3841,406 @@ class PGDialect(default.DefaultDialect): "referred_columns": referred_columns, "options": options, } - fkeys.append(fkey_d) - return fkeys - - def _pg_index_any(self, col, compare_to): - if self.server_version_info < (8, 1): - # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us - # "In CVS tip you could replace this with "attnum = ANY (indkey)". - # Unfortunately, most array support doesn't work on int2vector in - # pre-8.1 releases, so I think you're kinda stuck with the above - # for now. - # regards, tom lane" - return "(%s)" % " OR ".join( - "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10) - ) - else: - return "%s = ANY(%s)" % (col, compare_to) + table_fks.append(fkey_d) + return fkeys.items() @reflection.cache - def get_indexes(self, connection, table_name, schema, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + def get_indexes(self, connection, table_name, schema=None, **kw): + data = self.get_multi_indexes( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - # cast indkey as varchar since it's an int2vector, - # returned as a list by some drivers such as pypostgresql - - if self.server_version_info < (8, 5): - IDX_SQL = """ - SELECT - i.relname as relname, - ix.indisunique, ix.indexprs, ix.indpred, - a.attname, a.attnum, NULL, ix.indkey%s, - %s, %s, am.amname, - NULL as indnkeyatts - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_class i on i.oid = ix.indexrelid - left outer join - pg_attribute a - on t.oid = a.attrelid and %s - left outer join - pg_am am - on i.relam = am.oid - WHERE - t.relkind IN ('r', 'v', 'f', 'm') - and t.oid = :table_oid - and ix.indisprimary = 'f' - ORDER BY - t.relname, - i.relname - """ % ( - # version 8.3 here was based on observing the - # cast does not work in PG 8.2.4, does work in 8.3.0. - # nothing in PG changelogs regarding this. 
- "::varchar" if self.server_version_info >= (8, 3) else "", - "ix.indoption::varchar" - if self.server_version_info >= (8, 3) - else "NULL", - "i.reloptions" - if self.server_version_info >= (8, 2) - else "NULL", - self._pg_index_any("a.attnum", "ix.indkey"), + @util.memoized_property + def _index_query(self): + pg_class_index = pg_catalog.pg_class.alias("cls_idx") + # NOTE: repeating oids clause improve query performance + + # subquery to get the columns + idx_sq = ( + select( + pg_catalog.pg_index.c.indexrelid, + pg_catalog.pg_index.c.indrelid, + sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"), + sql.func.generate_subscripts( + pg_catalog.pg_index.c.indkey, 1 + ).label("ord"), ) - else: - IDX_SQL = """ - SELECT - i.relname as relname, - ix.indisunique, ix.indexprs, - a.attname, a.attnum, c.conrelid, ix.indkey::varchar, - ix.indoption::varchar, i.reloptions, am.amname, - pg_get_expr(ix.indpred, ix.indrelid), - %s as indnkeyatts - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_class i on i.oid = ix.indexrelid - left outer join - pg_attribute a - on t.oid = a.attrelid and a.attnum = ANY(ix.indkey) - left outer join - pg_constraint c - on (ix.indrelid = c.conrelid and - ix.indexrelid = c.conindid and - c.contype in ('p', 'u', 'x')) - left outer join - pg_am am - on i.relam = am.oid - WHERE - t.relkind IN ('r', 'v', 'f', 'm', 'p') - and t.oid = :table_oid - and ix.indisprimary = 'f' - ORDER BY - t.relname, - i.relname - """ % ( - "ix.indnkeyatts" - if self.server_version_info >= (11, 0) - else "NULL", + .where( + ~pg_catalog.pg_index.c.indisprimary, + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), ) + .subquery("idx") + ) - t = sql.text(IDX_SQL).columns( - relname=sqltypes.Unicode, attname=sqltypes.Unicode + attr_sq = ( + select( + idx_sq.c.indexrelid, + idx_sq.c.indrelid, + pg_catalog.pg_attribute.c.attname, + ) + .select_from(pg_catalog.pg_attribute) + .join( + idx_sq, + sql.and_( + pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid, + ), + ) + .where(idx_sq.c.indrelid.in_(bindparam("oids"))) + .order_by(idx_sq.c.indexrelid, idx_sq.c.ord) + .subquery("idx_attr") ) - c = connection.execute(t, dict(table_oid=table_oid)) - indexes = defaultdict(lambda: defaultdict(dict)) + cols_sq = ( + select( + attr_sq.c.indexrelid, + attr_sq.c.indrelid, + sql.func.array_agg(attr_sq.c.attname).label("cols"), + ) + .group_by(attr_sq.c.indexrelid, attr_sq.c.indrelid) + .subquery("idx_cols") + ) - sv_idx_name = None - for row in c.fetchall(): - ( - idx_name, - unique, - expr, - col, - col_num, - conrelid, - idx_key, - idx_option, - options, - amname, - filter_definition, - indnkeyatts, - ) = row + if self.server_version_info >= (11, 0): + indnkeyatts = pg_catalog.pg_index.c.indnkeyatts + else: + indnkeyatts = sql.null().label("indnkeyatts") - if expr: - if idx_name != sv_idx_name: - util.warn( - "Skipped unsupported reflection of " - "expression-based index %s" % idx_name - ) - sv_idx_name = idx_name - continue + query = ( + select( + pg_catalog.pg_index.c.indrelid, + pg_class_index.c.relname.label("relname_index"), + pg_catalog.pg_index.c.indisunique, + pg_catalog.pg_index.c.indexprs, + pg_catalog.pg_constraint.c.conrelid.is_not(None).label( + "has_constraint" + ), + pg_catalog.pg_index.c.indoption, + pg_class_index.c.reloptions, + pg_catalog.pg_am.c.amname, + pg_catalog.pg_get_expr( + pg_catalog.pg_index.c.indpred, + pg_catalog.pg_index.c.indrelid, + ).label("filter_definition"), + indnkeyatts, + 
cols_sq.c.cols.label("index_cols"), + ) + .select_from(pg_catalog.pg_index) + .where( + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), + ~pg_catalog.pg_index.c.indisprimary, + ) + .join( + pg_class_index, + pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid, + ) + .join( + pg_catalog.pg_am, + pg_class_index.c.relam == pg_catalog.pg_am.c.oid, + ) + .outerjoin( + cols_sq, + pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid, + ) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_index.c.indrelid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_index.c.indexrelid + == pg_catalog.pg_constraint.c.conindid, + pg_catalog.pg_constraint.c.contype + == sql.any_(_array.array(("p", "u", "x"))), + ), + ) + .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname) + ) + return query - has_idx = idx_name in indexes - index = indexes[idx_name] - if col is not None: - index["cols"][col_num] = col - if not has_idx: - idx_keys = idx_key.split() - # "The number of key columns in the index, not counting any - # included columns, which are merely stored and do not - # participate in the index semantics" - if indnkeyatts and idx_keys[indnkeyatts:]: - # this is a "covering index" which has INCLUDE columns - # as well as regular index columns - inc_keys = idx_keys[indnkeyatts:] - idx_keys = idx_keys[:indnkeyatts] - else: - inc_keys = [] + def get_multi_indexes( + self, connection, schema, filter_names, scope, kind, **kw + ): - index["key"] = [int(k.strip()) for k in idx_keys] - index["inc"] = [int(k.strip()) for k in inc_keys] + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) - # (new in pg 8.3) - # "pg_index.indoption" is list of ints, one per column/expr. - # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST - sorting = {} - for col_idx, col_flags in enumerate( - (idx_option or "").split() - ): - col_flags = int(col_flags.strip()) - col_sorting = () - # try to set flags only if they differ from PG defaults... 
- if col_flags & 0x01: - col_sorting += ("desc",) - if not (col_flags & 0x02): - col_sorting += ("nulls_last",) + indexes = defaultdict(list) + default = ReflectionDefaults.indexes + + batches = list(table_oids) + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._index_query, {"oids": [r[0] for r in batch]} + ).mappings() + + result_by_oid = defaultdict(list) + for row_dict in result: + result_by_oid[row_dict["indrelid"]].append(row_dict) + + for oid, table_name in batch: + if oid not in result_by_oid: + # ensure that each table has an entry, even if reflection + # is skipped because not supported + indexes[(schema, table_name)] = default() + continue + + for row in result_by_oid[oid]: + index_name = row["relname_index"] + + table_indexes = indexes[(schema, table_name)] + + if row["indexprs"]: + tn = ( + table_name + if schema is None + else f"{schema}.{table_name}" + ) + util.warn( + "Skipped unsupported reflection of " + f"expression-based index {index_name} of " + f"table {tn}" + ) + continue + + all_cols = row["index_cols"] + indnkeyatts = row["indnkeyatts"] + # "The number of key columns in the index, not counting any + # included columns, which are merely stored and do not + # participate in the index semantics" + if indnkeyatts and all_cols[indnkeyatts:]: + # this is a "covering index" which has INCLUDE columns + # as well as regular index columns + inc_cols = all_cols[indnkeyatts:] + idx_cols = all_cols[:indnkeyatts] else: - if col_flags & 0x02: - col_sorting += ("nulls_first",) - if col_sorting: - sorting[col_idx] = col_sorting - if sorting: - index["sorting"] = sorting - - index["unique"] = unique - if conrelid is not None: - index["duplicates_constraint"] = idx_name - if options: - index["options"] = dict( - [option.split("=") for option in options] - ) - - # it *might* be nice to include that this is 'btree' in the - # reflection info. But we don't want an Index object - # to have a ``postgresql_using`` in it that is just the - # default, so for the moment leaving this out. - if amname and amname != "btree": - index["amname"] = amname - - if filter_definition: - index["postgresql_where"] = filter_definition + idx_cols = all_cols + inc_cols = [] + + index = { + "name": index_name, + "unique": row["indisunique"], + "column_names": idx_cols, + } + + sorting = {} + for col_index, col_flags in enumerate(row["indoption"]): + col_sorting = () + # try to set flags only if they differ from PG + # defaults... + if col_flags & 0x01: + col_sorting += ("desc",) + if not (col_flags & 0x02): + col_sorting += ("nulls_last",) + else: + if col_flags & 0x02: + col_sorting += ("nulls_first",) + if col_sorting: + sorting[idx_cols[col_index]] = col_sorting + if sorting: + index["column_sorting"] = sorting + if row["has_constraint"]: + index["duplicates_constraint"] = index_name + + dialect_options = {} + if row["reloptions"]: + dialect_options["postgresql_with"] = dict( + [option.split("=") for option in row["reloptions"]] + ) + # it *might* be nice to include that this is 'btree' in the + # reflection info. But we don't want an Index object + # to have a ``postgresql_using`` in it that is just the + # default, so for the moment leaving this out. 
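# --- illustrative sketch (not part of the patch): decoding pg_index.indoption
# Each per-column flag word is a bitmask where 0x01 means DESC and 0x02 means
# NULLS FIRST; the reflection code only records values that differ from the
# PostgreSQL defaults (ASC NULLS LAST, or NULLS FIRST when DESC is in effect).
def _sketch_decode_indoption(col_flags):
    sorting = ()
    if col_flags & 0x01:
        sorting += ("desc",)
        if not (col_flags & 0x02):
            sorting += ("nulls_last",)
    elif col_flags & 0x02:
        sorting += ("nulls_first",)
    return sorting

assert _sketch_decode_indoption(0) == ()                      # ASC NULLS LAST
assert _sketch_decode_indoption(3) == ("desc",)               # DESC NULLS FIRST
assert _sketch_decode_indoption(1) == ("desc", "nulls_last")  # DESC NULLS LAST
assert _sketch_decode_indoption(2) == ("nulls_first",)        # ASC NULLS FIRST
# -----------------------------------------------------------------------------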
+ amname = row["amname"] + if amname != "btree": + dialect_options["postgresql_using"] = row["amname"] + if row["filter_definition"]: + dialect_options["postgresql_where"] = row[ + "filter_definition" + ] + if self.server_version_info >= (11, 0): + # NOTE: this is legacy, this is part of + # dialect_options now as of #7382 + index["include_columns"] = inc_cols + dialect_options["postgresql_include"] = inc_cols + if dialect_options: + index["dialect_options"] = dialect_options - result = [] - for name, idx in indexes.items(): - entry = { - "name": name, - "unique": idx["unique"], - "column_names": [idx["cols"][i] for i in idx["key"]], - } - if self.server_version_info >= (11, 0): - # NOTE: this is legacy, this is part of dialect_options now - # as of #7382 - entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]] - if "duplicates_constraint" in idx: - entry["duplicates_constraint"] = idx["duplicates_constraint"] - if "sorting" in idx: - entry["column_sorting"] = dict( - (idx["cols"][idx["key"][i]], value) - for i, value in idx["sorting"].items() - ) - if "include_columns" in entry: - entry.setdefault("dialect_options", {})[ - "postgresql_include" - ] = entry["include_columns"] - if "options" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_with" - ] = idx["options"] - if "amname" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_using" - ] = idx["amname"] - if "postgresql_where" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_where" - ] = idx["postgresql_where"] - result.append(entry) - return result + table_indexes.append(index) + return indexes.items() @reflection.cache def get_unique_constraints( self, connection, table_name, schema=None, **kw ): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_unique_constraints( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - UNIQUE_SQL = """ - SELECT - cons.conname as name, - cons.conkey as key, - a.attnum as col_num, - a.attname as col_name - FROM - pg_catalog.pg_constraint cons - join pg_attribute a - on cons.conrelid = a.attrelid AND - a.attnum = ANY(cons.conkey) - WHERE - cons.conrelid = :table_oid AND - cons.contype = 'u' - """ - - t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) + def get_multi_unique_constraints( + self, + connection, + schema, + filter_names, + scope, + kind, + **kw, + ): + result = self._reflect_constraint( + connection, "u", schema, filter_names, scope, kind, **kw + ) - uniques = defaultdict(lambda: defaultdict(dict)) - for row in c.fetchall(): - uc = uniques[row.name] - uc["key"] = row.key - uc["cols"][row.col_num] = row.col_name + # each table can have multiple unique constraints + uniques = defaultdict(list) + default = ReflectionDefaults.unique_constraints + for (table_name, cols, con_name) in result: + # ensure a list is created for each table. 
leave it empty if + # the table has no unique cosntraint + if con_name is None: + uniques[(schema, table_name)] = default() + continue - return [ - {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]} - for name, uc in uniques.items() - ] + uniques[(schema, table_name)].append( + { + "column_names": cols, + "name": con_name, + } + ) + return uniques.items() @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_table_comment( + connection, + schema, + [table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - COMMENT_SQL = """ - SELECT - pgd.description as table_comment - FROM - pg_catalog.pg_description pgd - WHERE - pgd.objsubid = 0 AND - pgd.objoid = :table_oid - """ + @lru_cache() + def _comment_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_description.c.description, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_description.c.objoid, + pg_catalog.pg_description.c.objsubid == 0, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query - c = connection.execute( - sql.text(COMMENT_SQL), dict(table_oid=table_oid) + def get_multi_table_comment( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._comment_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + default = ReflectionDefaults.table_comment + return ( + ( + (schema, table), + {"text": comment} if comment is not None else default(), + ) + for table, comment in result ) - return {"text": c.scalar()} @reflection.cache def get_check_constraints(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_check_constraints( + connection, + schema, + [table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - CHECK_SQL = """ - SELECT - cons.conname as name, - pg_get_constraintdef(cons.oid) as src - FROM - pg_catalog.pg_constraint cons - WHERE - cons.conrelid = :table_oid AND - cons.contype = 'c' - """ - - c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid)) + @lru_cache() + def _check_constraint_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid + ), + ), + else_=None, + ), + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "c", + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = 
self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query - ret = [] - for name, src in c: + def get_multi_check_constraints( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._check_constraint_query( + schema, has_filter_names, scope, kind + ) + result = connection.execute(query, params) + + check_constraints = defaultdict(list) + default = ReflectionDefaults.check_constraints + for table_name, check_name, src in result: + # only two cases for check_name and src: both null or both defined + if check_name is None and src is None: + check_constraints[(schema, table_name)] = default() + continue # samples: # "CHECK (((a > 1) AND (a < 5)))" # "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))" @@ -4424,84 +4258,118 @@ class PGDialect(default.DefaultDialect): sqltext = re.compile( r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL ).sub(r"\1", m.group(1)) - entry = {"name": name, "sqltext": sqltext} + entry = {"name": check_name, "sqltext": sqltext} if m and m.group(2): entry["dialect_options"] = {"not_valid": True} - ret.append(entry) - return ret - - def _load_enums(self, connection, schema=None): - schema = schema or self.default_schema_name - if not self.supports_native_enum: - return {} - - # Load data types for enums: - SQL_ENUMS = """ - SELECT t.typname as "name", - -- no enum defaults in 8.4 at least - -- t.typdefault as "default", - pg_catalog.pg_type_is_visible(t.oid) as "visible", - n.nspname as "schema", - e.enumlabel as "label" - FROM pg_catalog.pg_type t - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace - LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid - WHERE t.typtype = 'e' - """ + check_constraints[(schema, table_name)].append(entry) + return check_constraints.items() - if schema != "*": - SQL_ENUMS += "AND n.nspname = :schema " + @lru_cache() + def _enum_query(self, schema): + lbl_sq = ( + select( + pg_catalog.pg_enum.c.enumtypid, pg_catalog.pg_enum.c.enumlabel + ) + .order_by( + pg_catalog.pg_enum.c.enumtypid, + pg_catalog.pg_enum.c.enumsortorder, + ) + .subquery("lbl") + ) - # e.oid gives us label order within an enum - SQL_ENUMS += 'ORDER BY "schema", "name", e.oid' + lbl_agg_sq = ( + select( + lbl_sq.c.enumtypid, + sql.func.array_agg(lbl_sq.c.enumlabel).label("labels"), + ) + .group_by(lbl_sq.c.enumtypid) + .subquery("lbl_agg") + ) - s = sql.text(SQL_ENUMS).columns( - attname=sqltypes.Unicode, label=sqltypes.Unicode + query = ( + select( + pg_catalog.pg_type.c.typname.label("name"), + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label( + "visible" + ), + pg_catalog.pg_namespace.c.nspname.label("schema"), + lbl_agg_sq.c.labels.label("labels"), + ) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .outerjoin( + lbl_agg_sq, pg_catalog.pg_type.c.oid == lbl_agg_sq.c.enumtypid + ) + .where(pg_catalog.pg_type.c.typtype == "e") + .order_by( + pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname + ) ) - if schema != "*": - s = s.bindparams(schema=schema) + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + @reflection.cache + 
def _load_enums(self, connection, schema=None, **kw): + if not self.supports_native_enum: + return [] - c = connection.execute(s) + result = connection.execute(self._enum_query(schema)) enums = [] - enum_by_name = {} - for enum in c.fetchall(): - key = (enum.schema, enum.name) - if key in enum_by_name: - enum_by_name[key]["labels"].append(enum.label) - else: - enum_by_name[key] = enum_rec = { - "name": enum.name, - "schema": enum.schema, - "visible": enum.visible, - "labels": [], + for name, visible, schema, labels in result: + enums.append( + { + "name": name, + "schema": schema, + "visible": visible, + "labels": [] if labels is None else labels, } - if enum.label is not None: - enum_rec["labels"].append(enum.label) - enums.append(enum_rec) + ) return enums - def _load_domains(self, connection): - # Load data types for domains: - SQL_DOMAINS = """ - SELECT t.typname as "name", - pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", - not t.typnotnull as "nullable", - t.typdefault as "default", - pg_catalog.pg_type_is_visible(t.oid) as "visible", - n.nspname as "schema" - FROM pg_catalog.pg_type t - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace - WHERE t.typtype = 'd' - """ + @util.memoized_property + def _domain_query(self): + return ( + select( + pg_catalog.pg_type.c.typname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_type.c.typbasetype, + pg_catalog.pg_type.c.typtypmod, + ).label("attype"), + (~pg_catalog.pg_type.c.typnotnull).label("nullable"), + pg_catalog.pg_type.c.typdefault.label("default"), + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label( + "visible" + ), + pg_catalog.pg_namespace.c.nspname.label("schema"), + ) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .where(pg_catalog.pg_type.c.typtype == "d") + ) - s = sql.text(SQL_DOMAINS) - c = connection.execution_options(future_result=True).execute(s) + @reflection.cache + def _load_domains(self, connection, **kw): + # Load data types for domains: + result = connection.execute(self._domain_query) domains = {} - for domain in c.mappings(): + for domain in result.mappings(): domain = domain # strip (30) from character varying(30) attype = re.search(r"([^\(]+)", domain["attype"]).group(1) diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 6cb97ece4a..ce9a3bb6c1 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -107,6 +107,8 @@ from .base import PGIdentifierPreparer from .json import JSON from .json import JSONB from .json import JSONPathType +from .pg_catalog import _SpaceVector +from .pg_catalog import OIDVECTOR from ... import exc from ... 
import util from ...engine import processors @@ -245,6 +247,10 @@ class _PGARRAY(PGARRAY): render_bind_cast = True +class _PGOIDVECTOR(_SpaceVector, OIDVECTOR): + pass + + _server_side_id = util.counter() @@ -376,6 +382,7 @@ class PGDialect_pg8000(PGDialect): sqltypes.BigInteger: _PGBigInteger, sqltypes.Enum: _PGEnum, sqltypes.ARRAY: _PGARRAY, + OIDVECTOR: _PGOIDVECTOR, }, ) diff --git a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py new file mode 100644 index 0000000000..a77e7ccf6b --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py @@ -0,0 +1,292 @@ +# postgresql/pg_catalog.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .array import ARRAY +from .types import OID +from .types import REGCLASS +from ... import Column +from ... import func +from ... import MetaData +from ... import Table +from ...types import BigInteger +from ...types import Boolean +from ...types import CHAR +from ...types import Float +from ...types import Integer +from ...types import SmallInteger +from ...types import String +from ...types import Text +from ...types import TypeDecorator + + +# types +class NAME(TypeDecorator): + impl = String(64, collation="C") + cache_ok = True + + +class PG_NODE_TREE(TypeDecorator): + impl = Text(collation="C") + cache_ok = True + + +class INT2VECTOR(TypeDecorator): + impl = ARRAY(SmallInteger) + cache_ok = True + + +class OIDVECTOR(TypeDecorator): + impl = ARRAY(OID) + cache_ok = True + + +class _SpaceVector: + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return value + return [int(p) for p in value.split(" ")] + + return process + + +REGPROC = REGCLASS # seems an alias + +# functions +_pg_cat = func.pg_catalog +quote_ident = _pg_cat.quote_ident +pg_table_is_visible = _pg_cat.pg_table_is_visible +pg_type_is_visible = _pg_cat.pg_type_is_visible +pg_get_viewdef = _pg_cat.pg_get_viewdef +pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence +format_type = _pg_cat.format_type +pg_get_expr = _pg_cat.pg_get_expr +pg_get_constraintdef = _pg_cat.pg_get_constraintdef + +# constants +RELKINDS_TABLE_NO_FOREIGN = ("r", "p") +RELKINDS_TABLE = RELKINDS_TABLE_NO_FOREIGN + ("f",) +RELKINDS_VIEW = ("v",) +RELKINDS_MAT_VIEW = ("m",) +RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW + +# tables +pg_catalog_meta = MetaData() + +pg_namespace = Table( + "pg_namespace", + pg_catalog_meta, + Column("oid", OID), + Column("nspname", NAME), + Column("nspowner", OID), + schema="pg_catalog", +) + +pg_class = Table( + "pg_class", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("relname", NAME), + Column("relnamespace", OID), + Column("reltype", OID), + Column("reloftype", OID), + Column("relowner", OID), + Column("relam", OID), + Column("relfilenode", OID), + Column("reltablespace", OID), + Column("relpages", Integer), + Column("reltuples", Float), + Column("relallvisible", Integer, info={"server_version": (9, 2)}), + Column("reltoastrelid", OID), + Column("relhasindex", Boolean), + Column("relisshared", Boolean), + Column("relpersistence", CHAR, info={"server_version": (9, 1)}), + Column("relkind", CHAR), + Column("relnatts", SmallInteger), + Column("relchecks", SmallInteger), + Column("relhasrules", Boolean), + Column("relhastriggers", Boolean), + 
Column("relhassubclass", Boolean), + Column("relrowsecurity", Boolean), + Column("relforcerowsecurity", Boolean, info={"server_version": (9, 5)}), + Column("relispopulated", Boolean, info={"server_version": (9, 3)}), + Column("relreplident", CHAR, info={"server_version": (9, 4)}), + Column("relispartition", Boolean, info={"server_version": (10,)}), + Column("relrewrite", OID, info={"server_version": (11,)}), + Column("reloptions", ARRAY(Text)), + schema="pg_catalog", +) + +pg_type = Table( + "pg_type", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("typname", NAME), + Column("typnamespace", OID), + Column("typowner", OID), + Column("typlen", SmallInteger), + Column("typbyval", Boolean), + Column("typtype", CHAR), + Column("typcategory", CHAR), + Column("typispreferred", Boolean), + Column("typisdefined", Boolean), + Column("typdelim", CHAR), + Column("typrelid", OID), + Column("typelem", OID), + Column("typarray", OID), + Column("typinput", REGPROC), + Column("typoutput", REGPROC), + Column("typreceive", REGPROC), + Column("typsend", REGPROC), + Column("typmodin", REGPROC), + Column("typmodout", REGPROC), + Column("typanalyze", REGPROC), + Column("typalign", CHAR), + Column("typstorage", CHAR), + Column("typnotnull", Boolean), + Column("typbasetype", OID), + Column("typtypmod", Integer), + Column("typndims", Integer), + Column("typcollation", OID, info={"server_version": (9, 1)}), + Column("typdefault", Text), + schema="pg_catalog", +) + +pg_index = Table( + "pg_index", + pg_catalog_meta, + Column("indexrelid", OID), + Column("indrelid", OID), + Column("indnatts", SmallInteger), + Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}), + Column("indisunique", Boolean), + Column("indisprimary", Boolean), + Column("indisexclusion", Boolean, info={"server_version": (9, 1)}), + Column("indimmediate", Boolean), + Column("indisclustered", Boolean), + Column("indisvalid", Boolean), + Column("indcheckxmin", Boolean), + Column("indisready", Boolean), + Column("indislive", Boolean, info={"server_version": (9, 3)}), # 9.3 + Column("indisreplident", Boolean), + Column("indkey", INT2VECTOR), + Column("indcollation", OIDVECTOR, info={"server_version": (9, 1)}), # 9.1 + Column("indclass", OIDVECTOR), + Column("indoption", INT2VECTOR), + Column("indexprs", PG_NODE_TREE), + Column("indpred", PG_NODE_TREE), + schema="pg_catalog", +) + +pg_attribute = Table( + "pg_attribute", + pg_catalog_meta, + Column("attrelid", OID), + Column("attname", NAME), + Column("atttypid", OID), + Column("attstattarget", Integer), + Column("attlen", SmallInteger), + Column("attnum", SmallInteger), + Column("attndims", Integer), + Column("attcacheoff", Integer), + Column("atttypmod", Integer), + Column("attbyval", Boolean), + Column("attstorage", CHAR), + Column("attalign", CHAR), + Column("attnotnull", Boolean), + Column("atthasdef", Boolean), + Column("atthasmissing", Boolean, info={"server_version": (11,)}), + Column("attidentity", CHAR, info={"server_version": (10,)}), + Column("attgenerated", CHAR, info={"server_version": (12,)}), + Column("attisdropped", Boolean), + Column("attislocal", Boolean), + Column("attinhcount", Integer), + Column("attcollation", OID, info={"server_version": (9, 1)}), + schema="pg_catalog", +) + +pg_constraint = Table( + "pg_constraint", + pg_catalog_meta, + Column("oid", OID), # 9.3 + Column("conname", NAME), + Column("connamespace", OID), + Column("contype", CHAR), + Column("condeferrable", Boolean), + Column("condeferred", Boolean), + 
Column("convalidated", Boolean, info={"server_version": (9, 1)}), + Column("conrelid", OID), + Column("contypid", OID), + Column("conindid", OID), + Column("conparentid", OID, info={"server_version": (11,)}), + Column("confrelid", OID), + Column("confupdtype", CHAR), + Column("confdeltype", CHAR), + Column("confmatchtype", CHAR), + Column("conislocal", Boolean), + Column("coninhcount", Integer), + Column("connoinherit", Boolean, info={"server_version": (9, 2)}), + Column("conkey", ARRAY(SmallInteger)), + Column("confkey", ARRAY(SmallInteger)), + schema="pg_catalog", +) + +pg_sequence = Table( + "pg_sequence", + pg_catalog_meta, + Column("seqrelid", OID), + Column("seqtypid", OID), + Column("seqstart", BigInteger), + Column("seqincrement", BigInteger), + Column("seqmax", BigInteger), + Column("seqmin", BigInteger), + Column("seqcache", BigInteger), + Column("seqcycle", Boolean), + schema="pg_catalog", + info={"server_version": (10,)}, +) + +pg_attrdef = Table( + "pg_attrdef", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("adrelid", OID), + Column("adnum", SmallInteger), + Column("adbin", PG_NODE_TREE), + schema="pg_catalog", +) + +pg_description = Table( + "pg_description", + pg_catalog_meta, + Column("objoid", OID), + Column("classoid", OID), + Column("objsubid", Integer), + Column("description", Text(collation="C")), + schema="pg_catalog", +) + +pg_enum = Table( + "pg_enum", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("enumtypid", OID), + Column("enumsortorder", Float(), info={"server_version": (9, 1)}), + Column("enumlabel", NAME), + schema="pg_catalog", +) + +pg_am = Table( + "pg_am", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("amname", NAME), + Column("amhandler", REGPROC, info={"server_version": (9, 6)}), + Column("amtype", CHAR, info={"server_version": (9, 6)}), + schema="pg_catalog", +) diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py new file mode 100644 index 0000000000..55735953b5 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -0,0 +1,485 @@ +# Copyright (C) 2013-2022 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import datetime as dt +from typing import Any + +from ... import schema +from ... import util +from ...sql import sqltypes +from ...sql.ddl import InvokeDDLBase + + +_DECIMAL_TYPES = (1231, 1700) +_FLOAT_TYPES = (700, 701, 1021, 1022) +_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) + + +class PGUuid(sqltypes.UUID): + render_bind_cast = True + render_literal_cast = True + + +class BYTEA(sqltypes.LargeBinary[bytes]): + __visit_name__ = "BYTEA" + + +class INET(sqltypes.TypeEngine[str]): + __visit_name__ = "INET" + + +PGInet = INET + + +class CIDR(sqltypes.TypeEngine[str]): + __visit_name__ = "CIDR" + + +PGCidr = CIDR + + +class MACADDR(sqltypes.TypeEngine[str]): + __visit_name__ = "MACADDR" + + +PGMacAddr = MACADDR + + +class MONEY(sqltypes.TypeEngine[str]): + + r"""Provide the PostgreSQL MONEY type. + + Depending on driver, result rows using this type may return a + string value which includes currency symbols. 
+ + For this reason, it may be preferable to provide conversion to a + numerically-based currency datatype using :class:`_types.TypeDecorator`:: + + import re + import decimal + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def process_result_value(self, value: Any, dialect: Any) -> None: + if value is not None: + # adjust this for the currency and numeric + m = re.match(r"\$([\d.]+)", value) + if m: + value = decimal.Decimal(m.group(1)) + return value + + Alternatively, the conversion may be applied as a CAST using + the :meth:`_types.TypeDecorator.column_expression` method as follows:: + + import decimal + from sqlalchemy import cast + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def column_expression(self, column: Any): + return cast(column, Numeric()) + + .. versionadded:: 1.2 + + """ + + __visit_name__ = "MONEY" + + +class OID(sqltypes.TypeEngine[int]): + + """Provide the PostgreSQL OID type. + + .. versionadded:: 0.9.5 + + """ + + __visit_name__ = "OID" + + +class REGCLASS(sqltypes.TypeEngine[str]): + + """Provide the PostgreSQL REGCLASS type. + + .. versionadded:: 1.2.7 + + """ + + __visit_name__ = "REGCLASS" + + +class TIMESTAMP(sqltypes.TIMESTAMP): + def __init__(self, timezone=False, precision=None): + super(TIMESTAMP, self).__init__(timezone=timezone) + self.precision = precision + + +class TIME(sqltypes.TIME): + def __init__(self, timezone=False, precision=None): + super(TIME, self).__init__(timezone=timezone) + self.precision = precision + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + + """PostgreSQL INTERVAL type.""" + + __visit_name__ = "INTERVAL" + native = True + + def __init__(self, precision=None, fields=None): + """Construct an INTERVAL. + + :param precision: optional integer precision value + :param fields: string fields specifier. allows storage of fields + to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``, + etc. + + .. versionadded:: 1.2 + + """ + self.precision = precision + self.fields = fields + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + return INTERVAL(precision=interval.second_precision) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval(native=True, second_precision=self.precision) + + @property + def python_type(self): + return dt.timedelta + + +PGInterval = INTERVAL + + +class BIT(sqltypes.TypeEngine[int]): + __visit_name__ = "BIT" + + def __init__(self, length=None, varying=False): + if not varying: + # BIT without VARYING defaults to length 1 + self.length = length or 1 + else: + # but BIT VARYING can be unlimited-length, so no default + self.length = length + self.varying = varying + + +PGBit = BIT + + +class TSVECTOR(sqltypes.TypeEngine[Any]): + + """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL + text search type TSVECTOR. + + It can be used to do full text queries on natural language + documents. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`postgresql_match` + + """ + + __visit_name__ = "TSVECTOR" + + +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): + + """PostgreSQL ENUM type. + + This is a subclass of :class:`_types.Enum` which includes + support for PG's ``CREATE TYPE`` and ``DROP TYPE``. 
+ + When the builtin type :class:`_types.Enum` is used and the + :paramref:`.Enum.native_enum` flag is left at its default of + True, the PostgreSQL backend will use a :class:`_postgresql.ENUM` + type as the implementation, so the special create/drop rules + will be used. + + The create/drop behavior of ENUM is necessarily intricate, due to the + awkward relationship the ENUM type has in relationship to the + parent table, in that it may be "owned" by just a single table, or + may be shared among many tables. + + When using :class:`_types.Enum` or :class:`_postgresql.ENUM` + in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted + corresponding to when the :meth:`_schema.Table.create` and + :meth:`_schema.Table.drop` + methods are called:: + + table = Table('sometable', metadata, + Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) + ) + + table.create(engine) # will emit CREATE ENUM and CREATE TABLE + table.drop(engine) # will emit DROP TABLE and DROP ENUM + + To use a common enumerated type between multiple tables, the best + practice is to declare the :class:`_types.Enum` or + :class:`_postgresql.ENUM` independently, and associate it with the + :class:`_schema.MetaData` object itself:: + + my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) + + t1 = Table('sometable_one', metadata, + Column('some_enum', myenum) + ) + + t2 = Table('sometable_two', metadata, + Column('some_enum', myenum) + ) + + When this pattern is used, care must still be taken at the level + of individual table creates. Emitting CREATE TABLE without also + specifying ``checkfirst=True`` will still cause issues:: + + t1.create(engine) # will fail: no such type 'myenum' + + If we specify ``checkfirst=True``, the individual table-level create + operation will check for the ``ENUM`` and create if not exists:: + + # will check if enum exists, and emit CREATE TYPE if not + t1.create(engine, checkfirst=True) + + When using a metadata-level ENUM type, the type will always be created + and dropped if either the metadata-wide create/drop is called:: + + metadata.create_all(engine) # will emit CREATE TYPE + metadata.drop_all(engine) # will emit DROP TYPE + + The type can also be created and dropped directly:: + + my_enum.create(engine) + my_enum.drop(engine) + + .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type + now behaves more strictly with regards to CREATE/DROP. A metadata-level + ENUM type will only be created and dropped at the metadata level, + not the table level, with the exception of + ``table.create(checkfirst=True)``. + The ``table.drop()`` call will now emit a DROP TYPE for a table-level + enumerated type. + + """ + + native_enum = True + + def __init__(self, *enums, **kw): + """Construct an :class:`_postgresql.ENUM`. + + Arguments are the same as that of + :class:`_types.Enum`, but also including + the following parameters. + + :param create_type: Defaults to True. + Indicates that ``CREATE TYPE`` should be + emitted, after optionally checking for the + presence of the type, when the parent + table is being created; and additionally + that ``DROP TYPE`` is called when the table + is dropped. When ``False``, no check + will be performed and no ``CREATE TYPE`` + or ``DROP TYPE`` is emitted, unless + :meth:`~.postgresql.ENUM.create` + or :meth:`~.postgresql.ENUM.drop` + are called directly. 
+ Setting to ``False`` is helpful + when invoking a creation scheme to a SQL file + without access to the actual database - + the :meth:`~.postgresql.ENUM.create` and + :meth:`~.postgresql.ENUM.drop` methods can + be used to emit SQL to a target bind. + + """ + native_enum = kw.pop("native_enum", None) + if native_enum is False: + util.warn( + "the native_enum flag does not apply to the " + "sqlalchemy.dialects.postgresql.ENUM datatype; this type " + "always refers to ENUM. Use sqlalchemy.types.Enum for " + "non-native enum." + ) + self.create_type = kw.pop("create_type", True) + super(ENUM, self).__init__(*enums, **kw) + + @classmethod + def adapt_emulated_to_native(cls, impl, **kw): + """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain + :class:`.Enum`. + + """ + kw.setdefault("validate_strings", impl.validate_strings) + kw.setdefault("name", impl.name) + kw.setdefault("schema", impl.schema) + kw.setdefault("inherit_schema", impl.inherit_schema) + kw.setdefault("metadata", impl.metadata) + kw.setdefault("_create_events", False) + kw.setdefault("values_callable", impl.values_callable) + kw.setdefault("omit_aliases", impl._omit_aliases) + return cls(**kw) + + def create(self, bind=None, checkfirst=True): + """Emit ``CREATE TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL CREATE TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type does not exist already before + creating. + + """ + if not bind.dialect.supports_native_enum: + return + + bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=True): + """Emit ``DROP TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL DROP TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type actually exists before dropping. 
+ + """ + if not bind.dialect.supports_native_enum: + return + + bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) + + class EnumGenerator(InvokeDDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_create_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return not self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_create_enum(enum): + return + + self.connection.execute(CreateEnumType(enum)) + + class EnumDropper(InvokeDDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_drop_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_drop_enum(enum): + return + + self.connection.execute(DropEnumType(enum)) + + def get_dbapi_type(self, dbapi): + """dont return dbapi.STRING for ENUM in PostgreSQL, since that's + a different type""" + + return None + + def _check_for_name_in_memos(self, checkfirst, kw): + """Look in the 'ddl runner' for 'memos', then + note our name in that collection. + + This to ensure a particular named enum is operated + upon only once within any kind of create/drop + sequence without relying upon "checkfirst". + + """ + if not self.create_type: + return True + if "_ddl_runner" in kw: + ddl_runner = kw["_ddl_runner"] + if "_pg_enums" in ddl_runner.memo: + pg_enums = ddl_runner.memo["_pg_enums"] + else: + pg_enums = ddl_runner.memo["_pg_enums"] = set() + present = (self.schema, self.name) in pg_enums + pg_enums.add((self.schema, self.name)) + return present + else: + return False + + def _on_table_create(self, target, bind, checkfirst=False, **kw): + if ( + checkfirst + or ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + ) + ) and not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_table_drop(self, target, bind, checkfirst=False, **kw): + if ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + and not self._check_for_name_in_memos(checkfirst, kw) + ): + self.drop(bind=bind, checkfirst=checkfirst) + + def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) + + +class CreateEnumType(schema._CreateDropBase): + __visit_name__ = "create_enum_type" + + +class DropEnumType(schema._CreateDropBase): + __visit_name__ = "drop_enum_type" diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index fdcd1340bb..22f003e38f 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -867,6 +867,7 @@ from ... 
import util from ...engine import default from ...engine import processors from ...engine import reflection +from ...engine.reflection import ReflectionDefaults from ...sql import coercions from ...sql import ColumnElement from ...sql import compiler @@ -2053,28 +2054,27 @@ class SQLiteDialect(default.DefaultDialect): return [db[1] for db in dl if db[1] != "temp"] - @reflection.cache - def get_table_names(self, connection, schema=None, **kw): + def _format_schema(self, schema, table_name): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema + name = f"{qschema}.{table_name}" else: - master = "sqlite_master" - s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % ( - master, - ) - rs = connection.exec_driver_sql(s) - return [row[0] for row in rs] + name = table_name + return name @reflection.cache - def get_temp_table_names(self, connection, **kw): - s = ( - "SELECT name FROM sqlite_temp_master " - "WHERE type='table' ORDER BY name " - ) - rs = connection.exec_driver_sql(s) + def get_table_names(self, connection, schema=None, **kw): + main = self._format_schema(schema, "sqlite_master") + s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names - return [row[0] for row in rs] + @reflection.cache + def get_temp_table_names(self, connection, **kw): + main = "sqlite_temp_master" + s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names @reflection.cache def get_temp_view_names(self, connection, **kw): @@ -2082,11 +2082,11 @@ class SQLiteDialect(default.DefaultDialect): "SELECT name FROM sqlite_temp_master " "WHERE type='view' ORDER BY name " ) - rs = connection.exec_driver_sql(s) - - return [row[0] for row in rs] + names = connection.exec_driver_sql(s).scalars().all() + return names - def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): self._ensure_has_table_connection(connection) info = self._get_table_pragma( @@ -2099,23 +2099,16 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_view_names(self, connection, schema=None, **kw): - if schema is not None: - qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema - else: - master = "sqlite_master" - s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % ( - master, - ) - rs = connection.exec_driver_sql(s) - - return [row[0] for row in rs] + main = self._format_schema(schema, "sqlite_master") + s = f"SELECT name FROM {main} WHERE type='view' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema + master = f"{qschema}.sqlite_master" s = ("SELECT sql FROM %s WHERE name = ? 
AND type='view'") % ( master, ) @@ -2140,6 +2133,10 @@ class SQLiteDialect(default.DefaultDialect): result = rs.fetchall() if result: return result[0].sql + else: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -2186,7 +2183,14 @@ class SQLiteDialect(default.DefaultDialect): tablesql, ) ) - return columns + if columns: + return columns + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.columns() def _get_column_info( self, @@ -2216,7 +2220,6 @@ class SQLiteDialect(default.DefaultDialect): "type": coltype, "nullable": nullable, "default": default, - "autoincrement": "auto", "primary_key": primary_key, } if generated: @@ -2295,13 +2298,16 @@ class SQLiteDialect(default.DefaultDialect): constraint_name = result.group(1) if result else None cols = self.get_columns(connection, table_name, schema, **kw) + # consider only pk columns. This also avoids sorting the cached + # value returned by get_columns + cols = [col for col in cols if col.get("primary_key", 0) > 0] cols.sort(key=lambda col: col.get("primary_key")) - pkeys = [] - for col in cols: - if col["primary_key"]: - pkeys.append(col["name"]) + pkeys = [col["name"] for col in cols] - return {"constrained_columns": pkeys, "name": constraint_name} + if pkeys: + return {"constrained_columns": pkeys, "name": constraint_name} + else: + return ReflectionDefaults.pk_constraint() @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -2321,12 +2327,14 @@ class SQLiteDialect(default.DefaultDialect): # original DDL. The referred columns of the foreign key # constraint are therefore the primary key of the referred # table. - referred_pk = self.get_pk_constraint( - connection, rtbl, schema=schema, **kw - ) - # note that if table doesn't exist, we still get back a record, - # just it has no columns in it - referred_columns = referred_pk["constrained_columns"] + try: + referred_pk = self.get_pk_constraint( + connection, rtbl, schema=schema, **kw + ) + referred_columns = referred_pk["constrained_columns"] + except exc.NoSuchTableError: + # ignore not existing parents + referred_columns = [] else: # note we use this list only if this is the first column # in the constraint. for subsequent columns we ignore the @@ -2378,11 +2386,11 @@ class SQLiteDialect(default.DefaultDialect): ) table_data = self._get_table_sql(connection, table_name, schema=schema) - if table_data is None: - # system tables, etc. - return [] def parse_fks(): + if table_data is None: + # system tables, etc. + return FK_PATTERN = ( r"(?:CONSTRAINT (\w+) +)?" r"FOREIGN KEY *\( *(.+?) *\) +" @@ -2453,7 +2461,10 @@ class SQLiteDialect(default.DefaultDialect): # use them as is as it's extremely difficult to parse inline # constraints fkeys.extend(keys_by_signature.values()) - return fkeys + if fkeys: + return fkeys + else: + return ReflectionDefaults.foreign_keys() def _find_cols_in_sig(self, sig): for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I): @@ -2480,12 +2491,11 @@ class SQLiteDialect(default.DefaultDialect): table_data = self._get_table_sql( connection, table_name, schema=schema, **kw ) - if not table_data: - return [] - unique_constraints = [] def parse_uqs(): + if table_data is None: + return UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? 
+)?UNIQUE *\((.+?)\)' INLINE_UNIQUE_PATTERN = ( r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) ' @@ -2513,15 +2523,16 @@ class SQLiteDialect(default.DefaultDialect): unique_constraints.append(parsed_constraint) # NOTE: auto_index_by_sig might not be empty here, # the PRIMARY KEY may have an entry. - return unique_constraints + if unique_constraints: + return unique_constraints + else: + return ReflectionDefaults.unique_constraints() @reflection.cache def get_check_constraints(self, connection, table_name, schema=None, **kw): table_data = self._get_table_sql( connection, table_name, schema=schema, **kw ) - if not table_data: - return [] CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *" check_constraints = [] @@ -2531,7 +2542,7 @@ class SQLiteDialect(default.DefaultDialect): # necessarily makes assumptions as to how the CREATE TABLE # was emitted. - for match in re.finditer(CHECK_PATTERN, table_data, re.I): + for match in re.finditer(CHECK_PATTERN, table_data or "", re.I): name = match.group(1) if name: @@ -2539,7 +2550,10 @@ class SQLiteDialect(default.DefaultDialect): check_constraints.append({"sqltext": match.group(2), "name": name}) - return check_constraints + if check_constraints: + return check_constraints + else: + return ReflectionDefaults.check_constraints() @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): @@ -2561,7 +2575,7 @@ class SQLiteDialect(default.DefaultDialect): # loop thru unique indexes to get the column names. for idx in list(indexes): pragma_index = self._get_table_pragma( - connection, "index_info", idx["name"] + connection, "index_info", idx["name"], schema=schema ) for row in pragma_index: @@ -2574,7 +2588,23 @@ class SQLiteDialect(default.DefaultDialect): break else: idx["column_names"].append(row[2]) - return indexes + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + if indexes: + return indexes + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.indexes() + + def _is_sys_table(self, table_name): + return table_name in { + "sqlite_schema", + "sqlite_master", + "sqlite_temp_schema", + "sqlite_temp_master", + } @reflection.cache def _get_table_sql(self, connection, table_name, schema=None, **kw): @@ -2590,22 +2620,25 @@ class SQLiteDialect(default.DefaultDialect): " (SELECT * FROM %(schema)ssqlite_master UNION ALL " " SELECT * FROM %(schema)ssqlite_temp_master) " "WHERE name = ? " - "AND type = 'table'" % {"schema": schema_expr} + "AND type in ('table', 'view')" % {"schema": schema_expr} ) rs = connection.exec_driver_sql(s, (table_name,)) except exc.DBAPIError: s = ( "SELECT sql FROM %(schema)ssqlite_master " "WHERE name = ? " - "AND type = 'table'" % {"schema": schema_expr} + "AND type in ('table', 'view')" % {"schema": schema_expr} ) rs = connection.exec_driver_sql(s, (table_name,)) - return rs.scalar() + value = rs.scalar() + if value is None and not self._is_sys_table(table_name): + raise exc.NoSuchTableError(f"{schema_expr}{table_name}") + return value def _get_table_pragma(self, connection, pragma, table_name, schema=None): quote = self.identifier_preparer.quote_identifier if schema is not None: - statements = ["PRAGMA %s." 
% quote(schema)] + statements = [f"PRAGMA {quote(schema)}."] else: # because PRAGMA looks in all attached databases if no schema # given, need to specify "main" schema, however since we want @@ -2615,7 +2648,7 @@ class SQLiteDialect(default.DefaultDialect): qtable = quote(table_name) for statement in statements: - statement = "%s%s(%s)" % (statement, pragma, qtable) + statement = f"{statement}{pragma}({qtable})" cursor = connection.exec_driver_sql(statement) if not cursor._soft_closed: # work around SQLite issue whereby cursor.description diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index afba170759..77c2fea408 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -38,6 +38,8 @@ from .interfaces import ExecutionContext as ExecutionContext from .interfaces import TypeCompiler as TypeCompiler from .mock import create_mock_engine as create_mock_engine from .reflection import Inspector as Inspector +from .reflection import ObjectKind as ObjectKind +from .reflection import ObjectScope as ObjectScope from .result import ChunkedIteratorResult as ChunkedIteratorResult from .result import FrozenResult as FrozenResult from .result import IteratorResult as IteratorResult diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index df35e7128c..40af062529 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -45,6 +45,8 @@ from .interfaces import CacheStats from .interfaces import DBAPICursor from .interfaces import Dialect from .interfaces import ExecutionContext +from .reflection import ObjectKind +from .reflection import ObjectScope from .. import event from .. import exc from .. import pool @@ -508,15 +510,22 @@ class DefaultDialect(Dialect): """ return type_api.adapt_type(typeobj, self.colspecs) - def has_index(self, connection, table_name, index_name, schema=None): - if not self.has_table(connection, table_name, schema=schema): + def has_index(self, connection, table_name, index_name, schema=None, **kw): + if not self.has_table(connection, table_name, schema=schema, **kw): return False - for idx in self.get_indexes(connection, table_name, schema=schema): + for idx in self.get_indexes( + connection, table_name, schema=schema, **kw + ): if idx["name"] == index_name: return True else: return False + def has_schema( + self, connection: Connection, schema_name: str, **kw: Any + ) -> bool: + return schema_name in self.get_schema_names(connection, **kw) + def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: raise exc.IdentifierError( @@ -769,6 +778,122 @@ class DefaultDialect(Dialect): def get_driver_connection(self, connection): return connection + def _overrides_default(self, method): + return ( + getattr(type(self), method).__code__ + is not getattr(DefaultDialect, method).__code__ + ) + + def _default_multi_reflect( + self, + single_tbl_method, + connection, + kind, + schema, + filter_names, + scope, + **kw, + ): + + names_fns = [] + temp_names_fns = [] + if ObjectKind.TABLE in kind: + names_fns.append(self.get_table_names) + temp_names_fns.append(self.get_temp_table_names) + if ObjectKind.VIEW in kind: + names_fns.append(self.get_view_names) + temp_names_fns.append(self.get_temp_view_names) + if ObjectKind.MATERIALIZED_VIEW in kind: + names_fns.append(self.get_materialized_view_names) + # no temp materialized view at the moment + # temp_names_fns.append(self.get_temp_materialized_view_names) + + unreflectable = kw.pop("unreflectable", {}) 
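The ``_default_multi_reflect`` helper is what lets a dialect that only implements the classic single-table methods still satisfy the new ``get_multi_*`` interface: it collects candidate names for the requested ``kind`` and ``scope`` and then yields one ``((schema, name), result)`` pair per reflected object. A condensed sketch of that behaviour, using a hypothetical ``emulate_multi`` function and omitting the temporary-object handling and error handling shown above::

    from sqlalchemy.engine.reflection import ObjectKind, ObjectScope

    def emulate_multi(dialect, connection, single_tbl_method,
                      kind=ObjectKind.ANY, scope=ObjectScope.DEFAULT,
                      schema=None, filter_names=None, **kw):
        names = []
        if ObjectKind.TABLE in kind and ObjectScope.DEFAULT in scope:
            names += dialect.get_table_names(connection, schema=schema)
        if ObjectKind.VIEW in kind and ObjectScope.DEFAULT in scope:
            names += dialect.get_view_names(connection, schema=schema)
        wanted = set(filter_names) if filter_names else None
        for name in names:
            if wanted is None or name in wanted:
                # one single-table call (and hence one set of queries)
                # per object, as the pre-batched API behaved
                yield (schema, name), single_tbl_method(
                    connection, name, schema=schema, **kw
                )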
+ + if ( + filter_names + and scope is ObjectScope.ANY + and kind is ObjectKind.ANY + ): + # if names are given and no qualification on type of table + # (i.e. the Table(..., autoload) case), take the names as given, + # don't run names queries. If a table does not exit + # NoSuchTableError is raised and it's skipped + + # this also suits the case for mssql where we can reflect + # individual temp tables but there's no temp_names_fn + names = filter_names + else: + names = [] + name_kw = {"schema": schema, **kw} + fns = [] + if ObjectScope.DEFAULT in scope: + fns.extend(names_fns) + if ObjectScope.TEMPORARY in scope: + fns.extend(temp_names_fns) + + for fn in fns: + try: + names.extend(fn(connection, **name_kw)) + except NotImplementedError: + pass + + if filter_names: + filter_names = set(filter_names) + + # iterate over all the tables/views and call the single table method + for table in names: + if not filter_names or table in filter_names: + key = (schema, table) + try: + yield ( + key, + single_tbl_method( + connection, table, schema=schema, **kw + ), + ) + except exc.UnreflectableTableError as err: + if key not in unreflectable: + unreflectable[key] = err + except exc.NoSuchTableError: + pass + + def get_multi_table_options(self, connection, **kw): + return self._default_multi_reflect( + self.get_table_options, connection, **kw + ) + + def get_multi_columns(self, connection, **kw): + return self._default_multi_reflect(self.get_columns, connection, **kw) + + def get_multi_pk_constraint(self, connection, **kw): + return self._default_multi_reflect( + self.get_pk_constraint, connection, **kw + ) + + def get_multi_foreign_keys(self, connection, **kw): + return self._default_multi_reflect( + self.get_foreign_keys, connection, **kw + ) + + def get_multi_indexes(self, connection, **kw): + return self._default_multi_reflect(self.get_indexes, connection, **kw) + + def get_multi_unique_constraints(self, connection, **kw): + return self._default_multi_reflect( + self.get_unique_constraints, connection, **kw + ) + + def get_multi_check_constraints(self, connection, **kw): + return self._default_multi_reflect( + self.get_check_constraints, connection, **kw + ) + + def get_multi_table_comment(self, connection, **kw): + return self._default_multi_reflect( + self.get_table_comment, connection, **kw + ) + class StrCompileDialect(DefaultDialect): diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index b8e85b6468..28ed03f99d 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -15,7 +15,9 @@ from typing import Any from typing import Awaitable from typing import Callable from typing import ClassVar +from typing import Collection from typing import Dict +from typing import Iterable from typing import List from typing import Mapping from typing import MutableMapping @@ -324,7 +326,7 @@ class ReflectedColumn(TypedDict): nullable: bool """column nullability""" - default: str + default: Optional[str] """column default expression as a SQL string""" autoincrement: NotRequired[bool] @@ -343,11 +345,11 @@ class ReflectedColumn(TypedDict): comment: NotRequired[Optional[str]] """comment for the column, if present""" - computed: NotRequired[Optional[ReflectedComputed]] + computed: NotRequired[ReflectedComputed] """indicates this column is computed at insert (possibly update) time by the database.""" - identity: NotRequired[Optional[ReflectedIdentity]] + identity: NotRequired[ReflectedIdentity] """indicates this column is an IDENTITY 
column""" dialect_options: NotRequired[Dict[str, Any]] @@ -390,6 +392,9 @@ class ReflectedUniqueConstraint(TypedDict): column_names: List[str] """column names which comprise the constraint""" + duplicates_index: NotRequired[Optional[str]] + "Indicates if this unique constraint duplicates an index with this name" + dialect_options: NotRequired[Dict[str, Any]] """Additional dialect-specific options detected for this reflected object""" @@ -439,7 +444,7 @@ class ReflectedForeignKeyConstraint(TypedDict): referred_columns: List[str] """referenced column names""" - dialect_options: NotRequired[Dict[str, Any]] + options: NotRequired[Dict[str, Any]] """Additional dialect-specific options detected for this reflected object""" @@ -462,9 +467,8 @@ class ReflectedIndex(TypedDict): unique: bool """whether or not the index has a unique flag""" - duplicates_constraint: NotRequired[bool] - """boolean indicating this index mirrors a unique constraint of the same - name""" + duplicates_constraint: NotRequired[Optional[str]] + "Indicates if this index mirrors a unique constraint with this name" include_columns: NotRequired[List[str]] """columns to include in the INCLUDE clause for supporting databases. @@ -472,7 +476,7 @@ class ReflectedIndex(TypedDict): .. deprecated:: 2.0 Legacy value, will be replaced with - ``d["dialect_options"][]["include"]`` + ``d["dialect_options"]["_include"]`` """ @@ -494,7 +498,7 @@ class ReflectedTableComment(TypedDict): """ - text: str + text: Optional[str] """text of the comment""" @@ -547,6 +551,7 @@ class BindTyping(Enum): VersionInfoType = Tuple[Union[int, str], ...] +TableKey = Tuple[Optional[str], str] class Dialect(EventTarget): @@ -1040,7 +1045,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def initialize(self, connection: "Connection") -> None: + def initialize(self, connection: Connection) -> None: """Called during strategized creation of the dialect with a connection. @@ -1060,9 +1065,14 @@ class Dialect(EventTarget): pass + if TYPE_CHECKING: + + def _overrides_default(self, method_name: str) -> bool: + ... + def get_columns( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1074,13 +1084,40 @@ class Dialect(EventTarget): information as a list of dictionaries corresponding to the :class:`.ReflectedColumn` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_columns`. + """ + + def get_multi_columns( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedColumn]]]: + """Return information about columns in all tables in the + given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_columns`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. 
versionadded:: 2.0 + """ raise NotImplementedError() def get_pk_constraint( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1093,13 +1130,41 @@ class Dialect(EventTarget): key information as a dictionary corresponding to the :class:`.ReflectedPrimaryKeyConstraint` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_pk_constraint`. + + """ + raise NotImplementedError() + + def get_multi_pk_constraint( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, ReflectedPrimaryKeyConstraint]]: + """Return information about primary key constraints in + all tables in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_pk_constraint`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 """ raise NotImplementedError() def get_foreign_keys( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1111,42 +1176,104 @@ class Dialect(EventTarget): key information as a list of dicts corresponding to the :class:`.ReflectedForeignKeyConstraint` dictionary. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_foreign_keys`. + """ + + raise NotImplementedError() + + def get_multi_foreign_keys( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedForeignKeyConstraint]]]: + """Return information about foreign_keys in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_foreign_keys`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + """ raise NotImplementedError() def get_table_names( - self, connection: "Connection", schema: Optional[str] = None, **kw: Any + self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: - """Return a list of table names for ``schema``.""" + """Return a list of table names for ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_table_names`. + + """ raise NotImplementedError() def get_temp_table_names( - self, connection: "Connection", schema: Optional[str] = None, **kw: Any + self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: """Return a list of temporary table names on the given connection, if supported by the underlying backend. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_temp_table_names`. 
+ """ raise NotImplementedError() def get_view_names( - self, connection: "Connection", schema: Optional[str] = None, **kw: Any + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of all non-materialized view names available in the + database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_view_names`. + + :param schema: schema name to query, if not the default schema. + + """ + + raise NotImplementedError() + + def get_materialized_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: - """Return a list of all view names available in the database. + """Return a list of all materialized view names available in the + database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_materialized_view_names`. :param schema: schema name to query, if not the default schema. + + .. versionadded:: 2.0 + """ raise NotImplementedError() def get_sequence_names( - self, connection: "Connection", schema: Optional[str] = None, **kw: Any + self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: """Return a list of all sequence names available in the database. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_sequence_names`. + :param schema: schema name to query, if not the default schema. .. versionadded:: 1.4 @@ -1155,26 +1282,40 @@ class Dialect(EventTarget): raise NotImplementedError() def get_temp_view_names( - self, connection: "Connection", schema: Optional[str] = None, **kw: Any + self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: """Return a list of temporary view names on the given connection, if supported by the underlying backend. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_temp_view_names`. + """ raise NotImplementedError() + def get_schema_names(self, connection: Connection, **kw: Any) -> List[str]: + """Return a list of all schema names available in the database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_schema_names`. + """ + raise NotImplementedError() + def get_view_definition( self, - connection: "Connection", + connection: Connection, view_name: str, schema: Optional[str] = None, **kw: Any, ) -> str: - """Return view definition. + """Return plain or materialized view definition. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_view_definition`. Given a :class:`_engine.Connection`, a string - `view_name`, and an optional string ``schema``, return the view + ``view_name``, and an optional string ``schema``, return the view definition. """ @@ -1182,7 +1323,7 @@ class Dialect(EventTarget): def get_indexes( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1194,13 +1335,42 @@ class Dialect(EventTarget): information as a list of dictionaries corresponding to the :class:`.ReflectedIndex` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_indexes`. 
+ """ + + raise NotImplementedError() + + def get_multi_indexes( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedIndex]]]: + """Return information about indexes in in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_indexes`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + """ raise NotImplementedError() def get_unique_constraints( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1211,13 +1381,42 @@ class Dialect(EventTarget): unique constraint information as a list of dicts corresponding to the :class:`.ReflectedUniqueConstraint` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_unique_constraints`. + """ + + raise NotImplementedError() + + def get_multi_unique_constraints( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedUniqueConstraint]]]: + """Return information about unique constraints in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_unique_constraints`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + """ raise NotImplementedError() def get_check_constraints( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1228,26 +1427,86 @@ class Dialect(EventTarget): check constraint information as a list of dicts corresponding to the :class:`.ReflectedCheckConstraint` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_check_constraints`. + + .. versionadded:: 1.1.0 + + """ + + raise NotImplementedError() + + def get_multi_check_constraints( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedCheckConstraint]]]: + """Return information about check constraints in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_check_constraints`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. 
Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + """ raise NotImplementedError() def get_table_options( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, - ) -> Optional[Dict[str, Any]]: - r"""Return the "options" for the table identified by ``table_name`` - as a dictionary. + ) -> Dict[str, Any]: + """Return a dictionary of options specified when ``table_name`` + was created. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_table_options`. """ - return None + raise NotImplementedError() + + def get_multi_table_options( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, Dict[str, Any]]]: + """Return a dictionary of options specified when the tables in the + given schema were created. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_table_options`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + raise NotImplementedError() def get_table_comment( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1258,6 +1517,8 @@ class Dialect(EventTarget): table comment information as a dictionary corresponding to the :class:`.ReflectedTableComment` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_table_comment`. :raise: ``NotImplementedError`` for dialects that don't support comments. @@ -1268,6 +1529,33 @@ class Dialect(EventTarget): raise NotImplementedError() + def get_multi_table_comment( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, ReflectedTableComment]]: + """Return information about the table comment in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_table_comment`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + def normalize_name(self, name: str) -> str: """convert the given name to lowercase if it is detected as case insensitive. 
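The notes above describe the fallback behaviour that :class:`_engine.DefaultDialect` provides for the ``get_multi_*`` family. A minimal sketch of that strategy, using a hypothetical helper name and only the single-table methods documented here (not the actual implementation), looks roughly like this::

    from typing import Any, Collection, Iterable, Optional, Tuple

    def _multi_from_single(
        dialect: Any,
        connection: Any,
        schema: Optional[str],
        filter_names: Optional[Collection[str]],
        **kw: Any,
    ) -> Iterable[Tuple[Tuple[Optional[str], str], Any]]:
        # hypothetical sketch: enumerate objects of the requested kind;
        # only tables are shown here
        names = dialect.get_table_names(connection, schema, **kw)
        if filter_names is not None:
            allowed = set(filter_names)
            names = [name for name in names if name in allowed]
        for name in names:
            # one query per object, which is exactly what dialects
            # implementing the batched methods avoid
            yield (schema, name), dialect.get_table_comment(
                connection, name, schema, **kw
            )
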
@@ -1290,7 +1578,7 @@ class Dialect(EventTarget): def has_table( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1327,21 +1615,24 @@ class Dialect(EventTarget): def has_index( self, - connection: "Connection", + connection: Connection, table_name: str, index_name: str, schema: Optional[str] = None, + **kw: Any, ) -> bool: """Check the existence of a particular index name in the database. Given a :class:`_engine.Connection` object, a string - ``table_name`` and string index name, return True if an index of the - given name on the given table exists, false otherwise. + ``table_name`` and string index name, return ``True`` if an index of + the given name on the given table exists, ``False`` otherwise. The :class:`.DefaultDialect` implements this in terms of the :meth:`.Dialect.has_table` and :meth:`.Dialect.get_indexes` methods, however dialects can implement a more performant version. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_index`. .. versionadded:: 1.4 @@ -1351,7 +1642,7 @@ class Dialect(EventTarget): def has_sequence( self, - connection: "Connection", + connection: Connection, sequence_name: str, schema: Optional[str] = None, **kw: Any, @@ -1359,13 +1650,39 @@ class Dialect(EventTarget): """Check the existence of a particular sequence in the database. Given a :class:`_engine.Connection` object and a string - `sequence_name`, return True if the given sequence exists in - the database, False otherwise. + `sequence_name`, return ``True`` if the given sequence exists in + the database, ``False`` otherwise. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_sequence`. + """ + + raise NotImplementedError() + + def has_schema( + self, connection: Connection, schema_name: str, **kw: Any + ) -> bool: + """Check the existence of a particular schema name in the database. + + Given a :class:`_engine.Connection` object, a string + ``schema_name``, return ``True`` if a schema of the + given exists, ``False`` otherwise. + + The :class:`.DefaultDialect` implements this by checking + the presence of ``schema_name`` among the schemas returned by + :meth:`.Dialect.get_schema_names`, + however dialects can implement a more performant version. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_schema`. + + .. versionadded:: 2.0 + """ raise NotImplementedError() - def _get_server_version_info(self, connection: "Connection") -> Any: + def _get_server_version_info(self, connection: Connection) -> Any: """Retrieve the server version info from the given connection. This is used by the default implementation to populate the @@ -1376,7 +1693,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def _get_default_schema_name(self, connection: "Connection") -> str: + def _get_default_schema_name(self, connection: Connection) -> str: """Return the string name of the currently selected schema from the given connection. @@ -1481,7 +1798,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def do_savepoint(self, connection: "Connection", name: str) -> None: + def do_savepoint(self, connection: Connection, name: str) -> None: """Create a savepoint with the given name. :param connection: a :class:`_engine.Connection`. 
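At the application level these existence checks are exposed through :class:`.Inspector`; a short sketch, using placeholder schema, table, index and sequence names::

    from sqlalchemy import create_engine, inspect

    # placeholder URL and object names for illustration only
    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    insp = inspect(engine)

    if insp.has_schema("accounting"):
        print(insp.has_table("invoice", schema="accounting"))
        print(insp.has_index("invoice", "ix_invoice_customer", schema="accounting"))
        print(insp.has_sequence("invoice_id_seq", schema="accounting"))
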
@@ -1492,7 +1809,7 @@ class Dialect(EventTarget): raise NotImplementedError() def do_rollback_to_savepoint( - self, connection: "Connection", name: str + self, connection: Connection, name: str ) -> None: """Rollback a connection to the named savepoint. @@ -1503,9 +1820,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def do_release_savepoint( - self, connection: "Connection", name: str - ) -> None: + def do_release_savepoint(self, connection: Connection, name: str) -> None: """Release the named savepoint on a connection. :param connection: a :class:`_engine.Connection`. @@ -1514,7 +1829,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def do_begin_twophase(self, connection: "Connection", xid: Any) -> None: + def do_begin_twophase(self, connection: Connection, xid: Any) -> None: """Begin a two phase transaction on the given connection. :param connection: a :class:`_engine.Connection`. @@ -1524,7 +1839,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def do_prepare_twophase(self, connection: "Connection", xid: Any) -> None: + def do_prepare_twophase(self, connection: Connection, xid: Any) -> None: """Prepare a two phase transaction on the given connection. :param connection: a :class:`_engine.Connection`. @@ -1536,7 +1851,7 @@ class Dialect(EventTarget): def do_rollback_twophase( self, - connection: "Connection", + connection: Connection, xid: Any, is_prepared: bool = True, recover: bool = False, @@ -1555,7 +1870,7 @@ class Dialect(EventTarget): def do_commit_twophase( self, - connection: "Connection", + connection: Connection, xid: Any, is_prepared: bool = True, recover: bool = False, @@ -1573,7 +1888,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def do_recover_twophase(self, connection: "Connection") -> List[Any]: + def do_recover_twophase(self, connection: Connection) -> List[Any]: """Recover list of uncommitted prepared two phase transaction identifiers on the given connection. diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 4fc57d5f49..32c89106b7 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -27,39 +27,148 @@ methods such as get_table_names, get_columns, etc. from __future__ import annotations import contextlib +from dataclasses import dataclass +from enum import auto +from enum import Flag +from enum import unique +from typing import Any +from typing import Callable +from typing import Collection +from typing import Dict +from typing import Generator +from typing import Iterable from typing import List from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union from .base import Connection from .base import Engine -from .interfaces import ReflectedColumn from .. import exc from .. import inspection from .. import sql from .. 
import util from ..sql import operators from ..sql import schema as sa_schema +from ..sql.cache_key import _ad_hoc_cache_key_from_args +from ..sql.elements import TextClause from ..sql.type_api import TypeEngine +from ..sql.visitors import InternalTraversal from ..util import topological +from ..util.typing import final + +if TYPE_CHECKING: + from .interfaces import Dialect + from .interfaces import ReflectedCheckConstraint + from .interfaces import ReflectedColumn + from .interfaces import ReflectedForeignKeyConstraint + from .interfaces import ReflectedIndex + from .interfaces import ReflectedPrimaryKeyConstraint + from .interfaces import ReflectedTableComment + from .interfaces import ReflectedUniqueConstraint + from .interfaces import TableKey + +_R = TypeVar("_R") @util.decorator -def cache(fn, self, con, *args, **kw): +def cache( + fn: Callable[..., _R], + self: Dialect, + con: Connection, + *args: Any, + **kw: Any, +) -> _R: info_cache = kw.get("info_cache", None) if info_cache is None: return fn(self, con, *args, **kw) + exclude = {"info_cache", "unreflectable"} key = ( fn.__name__, tuple(a for a in args if isinstance(a, str)), - tuple((k, v) for k, v in kw.items() if k != "info_cache"), + tuple((k, v) for k, v in kw.items() if k not in exclude), ) - ret = info_cache.get(key) + ret: _R = info_cache.get(key) if ret is None: ret = fn(self, con, *args, **kw) info_cache[key] = ret return ret +def flexi_cache( + *traverse_args: Tuple[str, InternalTraversal] +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: + @util.decorator + def go( + fn: Callable[..., _R], + self: Dialect, + con: Connection, + *args: Any, + **kw: Any, + ) -> _R: + info_cache = kw.get("info_cache", None) + if info_cache is None: + return fn(self, con, *args, **kw) + key = _ad_hoc_cache_key_from_args((fn.__name__,), traverse_args, args) + ret: _R = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + + return go + + +@unique +class ObjectKind(Flag): + """Enumerator that indicates which kind of object to return when calling + the ``get_multi`` methods. + + This is a Flag enum, so custom combinations can be passed. For example, + to reflect tables and plain views ``ObjectKind.TABLE | ObjectKind.VIEW`` + may be used. + + .. note:: + Not all dialect may support all kind of object. If a dialect does + not support a particular object an empty dict is returned. + In case a dialect supports an object, but the requested method + is not applicable for the specified kind the default value + will be returned for each reflected object. For example reflecting + check constraints of view return a dict with all the views with + empty lists as values. + """ + + TABLE = auto() + "Reflect table objects" + VIEW = auto() + "Reflect plain view objects" + MATERIALIZED_VIEW = auto() + "Reflect materialized view object" + + ANY_VIEW = VIEW | MATERIALIZED_VIEW + "Reflect any kind of view objects" + ANY = TABLE | VIEW | MATERIALIZED_VIEW + "Reflect all type of objects" + + +@unique +class ObjectScope(Flag): + """Enumerator that indicates which scope to use when calling + the ``get_multi`` methods. + """ + + DEFAULT = auto() + "Include default scope" + TEMPORARY = auto() + "Include only temp scope" + ANY = DEFAULT | TEMPORARY + "Include both default and temp scope" + + @inspection._self_inspects class Inspector(inspection.Inspectable["Inspector"]): """Performs database schema inspection. 
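Since :class:`.ObjectKind` and :class:`.ObjectScope` are ``Flag`` enumerations, arbitrary combinations may be passed to the ``get_multi_*`` methods; a sketch using a placeholder connection URL::

    from sqlalchemy import create_engine, inspect
    from sqlalchemy.engine.reflection import ObjectKind, ObjectScope

    # placeholder URL for illustration only
    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    insp = inspect(engine)

    # reflect columns for tables and plain views, default scope only,
    # in a single batched pass on dialects that support it
    columns = insp.get_multi_columns(
        kind=ObjectKind.TABLE | ObjectKind.VIEW,
        scope=ObjectScope.DEFAULT,
    )
    for (schema, name), cols in columns.items():
        print(schema, name, [col["name"] for col in cols])
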
@@ -85,6 +194,12 @@ class Inspector(inspection.Inspectable["Inspector"]): """ + bind: Union[Engine, Connection] + engine: Engine + _op_context_requires_connect: bool + dialect: Dialect + info_cache: Dict[Any, Any] + @util.deprecated( "1.4", "The __init__() method on :class:`_reflection.Inspector` " @@ -96,7 +211,7 @@ class Inspector(inspection.Inspectable["Inspector"]): "in order to " "acquire an :class:`_reflection.Inspector`.", ) - def __init__(self, bind): + def __init__(self, bind: Union[Engine, Connection]): """Initialize a new :class:`_reflection.Inspector`. :param bind: a :class:`~sqlalchemy.engine.Connection`, @@ -108,38 +223,51 @@ class Inspector(inspection.Inspectable["Inspector"]): :meth:`_reflection.Inspector.from_engine` """ - return self._init_legacy(bind) + self._init_legacy(bind) @classmethod - def _construct(cls, init, bind): + def _construct( + cls, init: Callable[..., Any], bind: Union[Engine, Connection] + ) -> Inspector: if hasattr(bind.dialect, "inspector"): - cls = bind.dialect.inspector + cls = bind.dialect.inspector # type: ignore[attr-defined] self = cls.__new__(cls) init(self, bind) return self - def _init_legacy(self, bind): + def _init_legacy(self, bind: Union[Engine, Connection]) -> None: if hasattr(bind, "exec_driver_sql"): - self._init_connection(bind) + self._init_connection(bind) # type: ignore[arg-type] else: - self._init_engine(bind) + self._init_engine(bind) # type: ignore[arg-type] - def _init_engine(self, engine): + def _init_engine(self, engine: Engine) -> None: self.bind = self.engine = engine engine.connect().close() self._op_context_requires_connect = True self.dialect = self.engine.dialect self.info_cache = {} - def _init_connection(self, connection): + def _init_connection(self, connection: Connection) -> None: self.bind = connection self.engine = connection.engine self._op_context_requires_connect = False self.dialect = self.engine.dialect self.info_cache = {} + def clear_cache(self) -> None: + """reset the cache for this :class:`.Inspector`. + + Inspection methods that have data cached will emit SQL queries + when next called to get new data. + + .. versionadded:: 2.0 + + """ + self.info_cache.clear() + @classmethod @util.deprecated( "1.4", @@ -152,7 +280,7 @@ class Inspector(inspection.Inspectable["Inspector"]): "in order to " "acquire an :class:`_reflection.Inspector`.", ) - def from_engine(cls, bind): + def from_engine(cls, bind: Engine) -> Inspector: """Construct a new dialect-specific Inspector object from the given engine or connection. @@ -172,15 +300,15 @@ class Inspector(inspection.Inspectable["Inspector"]): return cls._construct(cls._init_legacy, bind) @inspection._inspects(Engine) - def _engine_insp(bind): + def _engine_insp(bind: Engine) -> Inspector: # type: ignore[misc] return Inspector._construct(Inspector._init_engine, bind) @inspection._inspects(Connection) - def _connection_insp(bind): + def _connection_insp(bind: Connection) -> Inspector: # type: ignore[misc] return Inspector._construct(Inspector._init_connection, bind) @contextlib.contextmanager - def _operation_context(self): + def _operation_context(self) -> Generator[Connection, None, None]: """Return a context that optimizes for multiple operations on a single transaction. @@ -189,10 +317,11 @@ class Inspector(inspection.Inspectable["Inspector"]): :class:`_engine.Connection`. 
""" + conn: Connection if self._op_context_requires_connect: - conn = self.bind.connect() + conn = self.bind.connect() # type: ignore[union-attr] else: - conn = self.bind + conn = self.bind # type: ignore[assignment] try: yield conn finally: @@ -200,7 +329,7 @@ class Inspector(inspection.Inspectable["Inspector"]): conn.close() @contextlib.contextmanager - def _inspection_context(self): + def _inspection_context(self) -> Generator[Inspector, None, None]: """Return an :class:`_reflection.Inspector` from this one that will run all operations on a single connection. @@ -213,7 +342,7 @@ class Inspector(inspection.Inspectable["Inspector"]): yield sub_insp @property - def default_schema_name(self): + def default_schema_name(self) -> Optional[str]: """Return the default schema name presented by the dialect for the current engine's database user. @@ -223,30 +352,38 @@ class Inspector(inspection.Inspectable["Inspector"]): """ return self.dialect.default_schema_name - def get_schema_names(self): - """Return all schema names.""" + def get_schema_names(self, **kw: Any) -> List[str]: + r"""Return all schema names. - if hasattr(self.dialect, "get_schema_names"): - with self._operation_context() as conn: - return self.dialect.get_schema_names( - conn, info_cache=self.info_cache - ) - return [] + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + """ - def get_table_names(self, schema=None): - """Return all table names in referred to within a particular schema. + with self._operation_context() as conn: + return self.dialect.get_schema_names( + conn, info_cache=self.info_cache, **kw + ) + + def get_table_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all table names within a particular schema. The names are expected to be real tables only, not views. Views are instead returned using the - :meth:`_reflection.Inspector.get_view_names` - method. - + :meth:`_reflection.Inspector.get_view_names` and/or + :meth:`_reflection.Inspector.get_materialized_view_names` + methods. :param schema: Schema name. If ``schema`` is left at ``None``, the database's default schema is used, else the named schema is searched. If the database does not support named schemas, behavior is undefined if ``schema`` is not passed as ``None``. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. .. seealso:: @@ -258,43 +395,105 @@ class Inspector(inspection.Inspectable["Inspector"]): with self._operation_context() as conn: return self.dialect.get_table_names( - conn, schema, info_cache=self.info_cache + conn, schema, info_cache=self.info_cache, **kw ) - def has_table(self, table_name, schema=None): - """Return True if the backend has a table or view of the given name. + def has_table( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: + r"""Return True if the backend has a table or view of the given name. :param table_name: name of the table to check :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. .. versionadded:: 1.4 - the :meth:`.Inspector.has_table` method replaces the :meth:`_engine.Engine.has_table` method. - .. 
versionchanged:: 2.0:: The method checks also for views. + .. versionchanged:: 2.0:: The method checks also for any type of + views (plain or materialized). In previous version this behaviour was dialect specific. New dialect suite tests were added to ensure all dialect conform with this behaviour. """ - # TODO: info_cache? with self._operation_context() as conn: - return self.dialect.has_table(conn, table_name, schema) + return self.dialect.has_table( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) - def has_sequence(self, sequence_name, schema=None): - """Return True if the backend has a table of the given name. + def has_sequence( + self, sequence_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: + r"""Return True if the backend has a sequence with the given name. - :param sequence_name: name of the table to check + :param sequence_name: name of the sequence to check :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. .. versionadded:: 1.4 """ - # TODO: info_cache? with self._operation_context() as conn: - return self.dialect.has_sequence(conn, sequence_name, schema) + return self.dialect.has_sequence( + conn, sequence_name, schema, info_cache=self.info_cache, **kw + ) + + def has_index( + self, + table_name: str, + index_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: + r"""Check the existence of a particular index name in the database. + + :param table_name: the name of the table the index belongs to + :param index_name: the name of the index to check + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect.has_index( + conn, + table_name, + index_name, + schema, + info_cache=self.info_cache, + **kw, + ) + + def has_schema(self, schema_name: str, **kw: Any) -> bool: + r"""Return True if the backend has a schema with the given name. + + :param schema_name: name of the schema to check + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect.has_schema( + conn, schema_name, info_cache=self.info_cache, **kw + ) - def get_sorted_table_and_fkc_names(self, schema=None): - """Return dependency-sorted table and foreign key constraint names in + def get_sorted_table_and_fkc_names( + self, + schema: Optional[str] = None, + **kw: Any, + ) -> List[Tuple[Optional[str], List[Tuple[str, Optional[str]]]]]: + r"""Return dependency-sorted table and foreign key constraint names in referred to within a particular schema. This will yield 2-tuples of @@ -309,6 +508,11 @@ class Inspector(inspection.Inspectable["Inspector"]): .. versionadded:: 1.0.- + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + .. 
seealso:: :meth:`_reflection.Inspector.get_table_names` @@ -317,24 +521,74 @@ class Inspector(inspection.Inspectable["Inspector"]): with an already-given :class:`_schema.MetaData`. """ - with self._operation_context() as conn: - tnames = self.dialect.get_table_names( - conn, schema, info_cache=self.info_cache + + return [ + ( + table_key[1] if table_key else None, + [(tname, fks) for (_, tname), fks in fk_collection], ) + for ( + table_key, + fk_collection, + ) in self.sort_tables_on_foreign_key_dependency( + consider_schemas=(schema,) + ) + ] - tuples = set() - remaining_fkcs = set() + def sort_tables_on_foreign_key_dependency( + self, + consider_schemas: Collection[Optional[str]] = (None,), + **kw: Any, + ) -> List[ + Tuple[ + Optional[Tuple[Optional[str], str]], + List[Tuple[Tuple[Optional[str], str], Optional[str]]], + ] + ]: + r"""Return dependency-sorted table and foreign key constraint names + referred to within multiple schemas. + + This method may be compared to + :meth:`.Inspector.get_sorted_table_and_fkc_names`, which + works on one schema at a time; here, the method is a generalization + that will consider multiple schemas at once including that it will + resolve for cross-schema foreign keys. + + .. versionadded:: 2.0 - fknames_for_table = {} - for tname in tnames: - fkeys = self.get_foreign_keys(tname, schema) - fknames_for_table[tname] = set([fk["name"] for fk in fkeys]) - for fkey in fkeys: - if tname != fkey["referred_table"]: - tuples.add((fkey["referred_table"], tname)) + """ + SchemaTab = Tuple[Optional[str], str] + + tuples: Set[Tuple[SchemaTab, SchemaTab]] = set() + remaining_fkcs: Set[Tuple[SchemaTab, Optional[str]]] = set() + fknames_for_table: Dict[SchemaTab, Set[Optional[str]]] = {} + tnames: List[SchemaTab] = [] + + for schname in consider_schemas: + schema_fkeys = self.get_multi_foreign_keys(schname, **kw) + tnames.extend(schema_fkeys) + for (_, tname), fkeys in schema_fkeys.items(): + fknames_for_table[(schname, tname)] = set( + [fk["name"] for fk in fkeys] + ) + for fkey in fkeys: + if ( + tname != fkey["referred_table"] + or schname != fkey["referred_schema"] + ): + tuples.add( + ( + ( + fkey["referred_schema"], + fkey["referred_table"], + ), + (schname, tname), + ) + ) try: candidate_sort = list(topological.sort(tuples, tnames)) except exc.CircularDependencyError as err: + edge: Tuple[SchemaTab, SchemaTab] for edge in err.edges: tuples.remove(edge) remaining_fkcs.update( @@ -342,16 +596,32 @@ class Inspector(inspection.Inspectable["Inspector"]): ) candidate_sort = list(topological.sort(tuples, tnames)) - return [ - (tname, fknames_for_table[tname].difference(remaining_fkcs)) - for tname in candidate_sort - ] + [(None, list(remaining_fkcs))] + ret: List[ + Tuple[Optional[SchemaTab], List[Tuple[SchemaTab, Optional[str]]]] + ] + ret = [ + ( + (schname, tname), + [ + ((schname, tname), fk) + for fk in fknames_for_table[(schname, tname)].difference( + name for _, name in remaining_fkcs + ) + ], + ) + for (schname, tname) in candidate_sort + ] + return ret + [(None, list(remaining_fkcs))] - def get_temp_table_names(self): - """Return a list of temporary table names for the current bind. + def get_temp_table_names(self, **kw: Any) -> List[str]: + r"""Return a list of temporary table names for the current bind. This method is unsupported by most dialects; currently - only SQLite implements it. + only Oracle, PostgreSQL and SQLite implements it. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. 
See the documentation of the dialect + in use for more information. .. versionadded:: 1.0.0 @@ -359,28 +629,35 @@ class Inspector(inspection.Inspectable["Inspector"]): with self._operation_context() as conn: return self.dialect.get_temp_table_names( - conn, info_cache=self.info_cache + conn, info_cache=self.info_cache, **kw ) - def get_temp_view_names(self): - """Return a list of temporary view names for the current bind. + def get_temp_view_names(self, **kw: Any) -> List[str]: + r"""Return a list of temporary view names for the current bind. This method is unsupported by most dialects; currently - only SQLite implements it. + only PostgreSQL and SQLite implements it. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. .. versionadded:: 1.0.0 """ with self._operation_context() as conn: return self.dialect.get_temp_view_names( - conn, info_cache=self.info_cache + conn, info_cache=self.info_cache, **kw ) - def get_table_options(self, table_name, schema=None, **kw): - """Return a dictionary of options specified when the table of the + def get_table_options( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> Dict[str, Any]: + r"""Return a dictionary of options specified when the table of the given name was created. - This currently includes some options that apply to MySQL tables. + This currently includes some options that apply to MySQL and Oracle + tables. :param table_name: string name of the table. For special quoting, use :class:`.quoted_name`. @@ -389,60 +666,172 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dict with the table options. The returned keys depend on the + dialect in use. Each one is prefixed with the dialect name. + """ - if hasattr(self.dialect, "get_table_options"): - with self._operation_context() as conn: - return self.dialect.get_table_options( - conn, table_name, schema, info_cache=self.info_cache, **kw - ) - return {} + with self._operation_context() as conn: + return self.dialect.get_table_options( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_table_options( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, Dict[str, Any]]: + r"""Return a dictionary of options specified when the tables in the + given schema were created. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + This currently includes some options that apply to MySQL and Oracle + tables. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if options of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. 
+ + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries with the table options. + The returned keys in each dict depend on the + dialect in use. Each one is prefixed with the dialect name. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + with self._operation_context() as conn: + res = self.dialect.get_multi_table_options( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + return dict(res) - def get_view_names(self, schema=None): - """Return all view names in `schema`. + def get_view_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all non-materialized view names in `schema`. :param schema: Optional, retrieve names from a non-default schema. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + + .. versionchanged:: 2.0 For those dialects that previously included + the names of materialized views in this list (currently PostgreSQL), + this method no longer returns the names of materialized views. + the :meth:`.Inspector.get_materialized_view_names` method should + be used instead. + + .. seealso:: + + :meth:`.Inspector.get_materialized_view_names` """ with self._operation_context() as conn: return self.dialect.get_view_names( - conn, schema, info_cache=self.info_cache + conn, schema, info_cache=self.info_cache, **kw + ) + + def get_materialized_view_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all materialized view names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Inspector.get_view_names` + + """ + + with self._operation_context() as conn: + return self.dialect.get_materialized_view_names( + conn, schema, info_cache=self.info_cache, **kw ) - def get_sequence_names(self, schema=None): - """Return all sequence names in `schema`. + def get_sequence_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all sequence names in `schema`. :param schema: Optional, retrieve names from a non-default schema. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. """ with self._operation_context() as conn: return self.dialect.get_sequence_names( - conn, schema, info_cache=self.info_cache + conn, schema, info_cache=self.info_cache, **kw ) - def get_view_definition(self, view_name, schema=None): - """Return definition for `view_name`. + def get_view_definition( + self, view_name: str, schema: Optional[str] = None, **kw: Any + ) -> str: + r"""Return definition for the plain or materialized view called + ``view_name``. + :param view_name: Name of the view. :param schema: Optional, retrieve names from a non-default schema. For special quoting, use :class:`.quoted_name`. 
+ :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. """ with self._operation_context() as conn: return self.dialect.get_view_definition( - conn, view_name, schema, info_cache=self.info_cache + conn, view_name, schema, info_cache=self.info_cache, **kw ) def get_columns( - self, table_name: str, schema: Optional[str] = None, **kw + self, table_name: str, schema: Optional[str] = None, **kw: Any ) -> List[ReflectedColumn]: - """Return information about columns in `table_name`. + r"""Return information about columns in ``table_name``. - Given a string `table_name` and an optional string `schema`, return - column information as a list of dicts with these keys: + Given a string ``table_name`` and an optional string ``schema``, + return column information as a list of dicts with these keys: * ``name`` - the column's name @@ -487,6 +876,10 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + :return: list of dictionaries, each representing the definition of a database column. @@ -496,17 +889,83 @@ class Inspector(inspection.Inspectable["Inspector"]): col_defs = self.dialect.get_columns( conn, table_name, schema, info_cache=self.info_cache, **kw ) - for col_def in col_defs: - # make this easy and only return instances for coltype - coltype = col_def["type"] - if not isinstance(coltype, TypeEngine): - col_def["type"] = coltype() + if col_defs: + self._instantiate_types([col_defs]) return col_defs - def get_pk_constraint(self, table_name, schema=None, **kw): - """Return information about primary key constraint on `table_name`. + def _instantiate_types( + self, data: Iterable[List[ReflectedColumn]] + ) -> None: + # make this easy and only return instances for coltype + for col_defs in data: + for col_def in col_defs: + coltype = col_def["type"] + if not isinstance(coltype, TypeEngine): + col_def["type"] = coltype() + + def get_multi_columns( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedColumn]]: + r"""Return information about columns in all objects in the given schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. + + The column information is as described in + :meth:`Inspector.get_columns`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if columns of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of a database column. 
+ The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + table_col_defs = dict( + self.dialect.get_multi_columns( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + self._instantiate_types(table_col_defs.values()) + return table_col_defs + + def get_pk_constraint( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> ReflectedPrimaryKeyConstraint: + r"""Return information about primary key constraint in ``table_name``. - Given a string `table_name`, and an optional string `schema`, return + Given a string ``table_name``, and an optional string `schema`, return primary key information as a dictionary with these keys: * ``constrained_columns`` - @@ -522,16 +981,80 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary representing the definition of + a primary key constraint. + """ with self._operation_context() as conn: return self.dialect.get_pk_constraint( conn, table_name, schema, info_cache=self.info_cache, **kw ) - def get_foreign_keys(self, table_name, schema=None, **kw): - """Return information about foreign_keys in `table_name`. + def get_multi_pk_constraint( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, ReflectedPrimaryKeyConstraint]: + r"""Return information about primary key constraints in + all tables in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + The primary key information is as described in + :meth:`Inspector.get_pk_constraint`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if primary keys of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries, each representing the + definition of a primary key constraint. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_pk_constraint( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_foreign_keys( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedForeignKeyConstraint]: + r"""Return information about foreign_keys in ``table_name``. 
- Given a string `table_name`, and an optional string `schema`, return + Given a string ``table_name``, and an optional string `schema`, return foreign key information as a list of dicts with these keys: * ``constrained_columns`` - @@ -557,6 +1080,13 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + a foreign key definition. + """ with self._operation_context() as conn: @@ -564,10 +1094,68 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, table_name, schema, info_cache=self.info_cache, **kw ) - def get_indexes(self, table_name, schema=None, **kw): - """Return information about indexes in `table_name`. + def get_multi_foreign_keys( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedForeignKeyConstraint]]: + r"""Return information about foreign_keys in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + The foreign key informations as described in + :meth:`Inspector.get_foreign_keys`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if foreign keys of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing + a foreign key definition. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_foreign_keys( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) - Given a string `table_name` and an optional string `schema`, return + def get_indexes( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedIndex]: + r"""Return information about indexes in ``table_name``. + + Given a string ``table_name`` and an optional string `schema`, return index information as a list of dicts with these keys: * ``name`` - @@ -598,6 +1186,13 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of an index. 
+ """ with self._operation_context() as conn: @@ -605,10 +1200,71 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, table_name, schema, info_cache=self.info_cache, **kw ) - def get_unique_constraints(self, table_name, schema=None, **kw): - """Return information about unique constraints in `table_name`. + def get_multi_indexes( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedIndex]]: + r"""Return information about indexes in in all objects + in the given schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. + + The foreign key information is as described in + :meth:`Inspector.get_foreign_keys`. + + The indexes information as described in + :meth:`Inspector.get_indexes`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if indexes of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of an index. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_indexes( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_unique_constraints( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedUniqueConstraint]: + r"""Return information about unique constraints in ``table_name``. - Given a string `table_name` and an optional string `schema`, return + Given a string ``table_name`` and an optional string `schema`, return unique constraint information as a list of dicts with these keys: * ``name`` - @@ -624,6 +1280,13 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of an unique constraint. + """ with self._operation_context() as conn: @@ -631,8 +1294,66 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, table_name, schema, info_cache=self.info_cache, **kw ) - def get_table_comment(self, table_name, schema=None, **kw): - """Return information about the table comment for ``table_name``. 
+ def get_multi_unique_constraints( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedUniqueConstraint]]: + r"""Return information about unique constraints in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + The unique constraint information is as described in + :meth:`Inspector.get_unique_constraints`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if constraints of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of an unique constraint. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_unique_constraints( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_table_comment( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> ReflectedTableComment: + r"""Return information about the table comment for ``table_name``. Given a string ``table_name`` and an optional string ``schema``, return table comment information as a dictionary with these keys: @@ -643,8 +1364,20 @@ class Inspector(inspection.Inspectable["Inspector"]): Raises ``NotImplementedError`` for a dialect that does not support comments. - .. versionadded:: 1.2 + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary, with the table comment. + .. versionadded:: 1.2 """ with self._operation_context() as conn: @@ -652,10 +1385,71 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, table_name, schema, info_cache=self.info_cache, **kw ) - def get_check_constraints(self, table_name, schema=None, **kw): - """Return information about check constraints in `table_name`. + def get_multi_table_comment( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, ReflectedTableComment]: + r"""Return information about the table comment in all objects + in the given schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. 
+ + The comment information is as described in + :meth:`Inspector.get_table_comment`. + + Raises ``NotImplementedError`` for a dialect that does not support + comments. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if comments of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries, representing the + table comments. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_table_comment( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_check_constraints( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedCheckConstraint]: + r"""Return information about check constraints in ``table_name``. - Given a string `table_name` and an optional string `schema`, return + Given a string ``table_name`` and an optional string `schema`, return check constraint information as a list of dicts with these keys: * ``name`` - @@ -677,6 +1471,13 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of a check constraints. + .. versionadded:: 1.1.0 """ @@ -686,14 +1487,71 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, table_name, schema, info_cache=self.info_cache, **kw ) + def get_multi_check_constraints( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedCheckConstraint]]: + r"""Return information about check constraints in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + The check constraint information is as described in + :meth:`Inspector.get_check_constraints`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if constraints of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. 
See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of a check constraints. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_check_constraints( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + def reflect_table( self, - table, - include_columns, - exclude_columns=(), - resolve_fks=True, - _extend_on=None, - ): + table: sa_schema.Table, + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str] = (), + resolve_fks: bool = True, + _extend_on: Optional[Set[sa_schema.Table]] = None, + _reflect_info: Optional[_ReflectionInfo] = None, + ) -> None: """Given a :class:`_schema.Table` object, load its internal constructs based on introspection. @@ -741,21 +1599,34 @@ class Inspector(inspection.Inspectable["Inspector"]): if k in table.dialect_kwargs ) + table_key = (schema, table_name) + if _reflect_info is None or table_key not in _reflect_info.columns: + _reflect_info = self._get_reflection_info( + schema, + filter_names=[table_name], + kind=ObjectKind.ANY, + scope=ObjectScope.ANY, + _reflect_info=_reflect_info, + **table.dialect_kwargs, + ) + if table_key in _reflect_info.unreflectable: + raise _reflect_info.unreflectable[table_key] + + if table_key not in _reflect_info.columns: + raise exc.NoSuchTableError(table_name) + # reflect table options, like mysql_engine - tbl_opts = self.get_table_options( - table_name, schema, **table.dialect_kwargs - ) - if tbl_opts: - # add additional kwargs to the Table if the dialect - # returned them - table._validate_dialect_kwargs(tbl_opts) + if _reflect_info.table_options: + tbl_opts = _reflect_info.table_options.get(table_key) + if tbl_opts: + # add additional kwargs to the Table if the dialect + # returned them + table._validate_dialect_kwargs(tbl_opts) found_table = False - cols_by_orig_name = {} + cols_by_orig_name: Dict[str, sa_schema.Column[Any]] = {} - for col_d in self.get_columns( - table_name, schema, **table.dialect_kwargs - ): + for col_d in _reflect_info.columns[table_key]: found_table = True self._reflect_column( @@ -771,12 +1642,12 @@ class Inspector(inspection.Inspectable["Inspector"]): raise exc.NoSuchTableError(table_name) self._reflect_pk( - table_name, schema, table, cols_by_orig_name, exclude_columns + _reflect_info, table_key, table, cols_by_orig_name, exclude_columns ) self._reflect_fk( - table_name, - schema, + _reflect_info, + table_key, table, cols_by_orig_name, include_columns, @@ -787,8 +1658,8 @@ class Inspector(inspection.Inspectable["Inspector"]): ) self._reflect_indexes( - table_name, - schema, + _reflect_info, + table_key, table, cols_by_orig_name, include_columns, @@ -797,8 +1668,8 @@ class Inspector(inspection.Inspectable["Inspector"]): ) self._reflect_unique_constraints( - table_name, - schema, + _reflect_info, + table_key, table, cols_by_orig_name, include_columns, @@ -807,8 +1678,8 @@ class Inspector(inspection.Inspectable["Inspector"]): ) self._reflect_check_constraints( - table_name, - schema, + _reflect_info, + table_key, table, cols_by_orig_name, include_columns, @@ -817,17 +1688,27 @@ class Inspector(inspection.Inspectable["Inspector"]): ) self._reflect_table_comment( - table_name, schema, table, reflection_options + _reflect_info, + 
table_key, + table, + reflection_options, ) def _reflect_column( - self, table, col_d, include_columns, exclude_columns, cols_by_orig_name - ): + self, + table: sa_schema.Table, + col_d: ReflectedColumn, + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + ) -> None: orig_name = col_d["name"] table.metadata.dispatch.column_reflect(self, table, col_d) - table.dispatch.column_reflect(self, table, col_d) + table.dispatch.column_reflect( # type: ignore[attr-defined] + self, table, col_d + ) # fetch name again as column_reflect is allowed to # change it @@ -840,7 +1721,7 @@ class Inspector(inspection.Inspectable["Inspector"]): coltype = col_d["type"] col_kw = dict( - (k, col_d[k]) + (k, col_d[k]) # type: ignore[literal-required] for k in [ "nullable", "autoincrement", @@ -856,15 +1737,20 @@ class Inspector(inspection.Inspectable["Inspector"]): col_kw.update(col_d["dialect_options"]) colargs = [] + default: Any if col_d.get("default") is not None: - default = col_d["default"] - if isinstance(default, sql.elements.TextClause): - default = sa_schema.DefaultClause(default, _reflected=True) - elif not isinstance(default, sa_schema.FetchedValue): + default_text = col_d["default"] + assert default_text is not None + if isinstance(default_text, TextClause): default = sa_schema.DefaultClause( - sql.text(col_d["default"]), _reflected=True + default_text, _reflected=True ) - + elif not isinstance(default_text, sa_schema.FetchedValue): + default = sa_schema.DefaultClause( + sql.text(default_text), _reflected=True + ) + else: + default = default_text colargs.append(default) if "computed" in col_d: @@ -872,11 +1758,8 @@ class Inspector(inspection.Inspectable["Inspector"]): colargs.append(computed) if "identity" in col_d: - computed = sa_schema.Identity(**col_d["identity"]) - colargs.append(computed) - - if "sequence" in col_d: - self._reflect_col_sequence(col_d, colargs) + identity = sa_schema.Identity(**col_d["identity"]) + colargs.append(identity) cols_by_orig_name[orig_name] = col = sa_schema.Column( name, coltype, *colargs, **col_kw @@ -886,23 +1769,15 @@ class Inspector(inspection.Inspectable["Inspector"]): col.primary_key = True table.append_column(col, replace_existing=True) - def _reflect_col_sequence(self, col_d, colargs): - if "sequence" in col_d: - # TODO: mssql is using this. 
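# The ``column_reflect`` hooks dispatched earlier in ``_reflect_column``
# correspond to the public ``DDLEvents.column_reflect`` event, which may
# rewrite the reflected column dictionary before the ``Column`` object is
# built.  A minimal sketch of a listener (the key-renaming rule is only an
# example):
#
#     from sqlalchemy import Table, event
#
#     @event.listens_for(Table, "column_reflect")
#     def receive_column_reflect(inspector, table, column_info):
#         # expose the reflected column under a lower-cased attribute key,
#         # regardless of how the backend reports the name
#         column_info["key"] = column_info["name"].lower()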
- seq = col_d["sequence"] - sequence = sa_schema.Sequence(seq["name"], 1, 1) - if "start" in seq: - sequence.start = seq["start"] - if "increment" in seq: - sequence.increment = seq["increment"] - colargs.append(sequence) - def _reflect_pk( - self, table_name, schema, table, cols_by_orig_name, exclude_columns - ): - pk_cons = self.get_pk_constraint( - table_name, schema, **table.dialect_kwargs - ) + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + exclude_columns: Collection[str], + ) -> None: + pk_cons = _reflect_info.pk_constraint.get(table_key) if pk_cons: pk_cols = [ cols_by_orig_name[pk] @@ -919,19 +1794,17 @@ class Inspector(inspection.Inspectable["Inspector"]): def _reflect_fk( self, - table_name, - schema, - table, - cols_by_orig_name, - include_columns, - exclude_columns, - resolve_fks, - _extend_on, - reflection_options, - ): - fkeys = self.get_foreign_keys( - table_name, schema, **table.dialect_kwargs - ) + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + resolve_fks: bool, + _extend_on: Optional[Set[sa_schema.Table]], + reflection_options: Dict[str, Any], + ) -> None: + fkeys = _reflect_info.foreign_keys.get(table_key, []) for fkey_d in fkeys: conname = fkey_d["name"] # look for columns by orig name in cols_by_orig_name, @@ -963,6 +1836,7 @@ class Inspector(inspection.Inspectable["Inspector"]): schema=referred_schema, autoload_with=self.bind, _extend_on=_extend_on, + _reflect_info=_reflect_info, **reflection_options, ) for column in referred_columns: @@ -977,6 +1851,7 @@ class Inspector(inspection.Inspectable["Inspector"]): autoload_with=self.bind, schema=sa_schema.BLANK_SCHEMA, _extend_on=_extend_on, + _reflect_info=_reflect_info, **reflection_options, ) for column in referred_columns: @@ -1005,16 +1880,16 @@ class Inspector(inspection.Inspectable["Inspector"]): def _reflect_indexes( self, - table_name, - schema, - table, - cols_by_orig_name, - include_columns, - exclude_columns, - reflection_options, - ): + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: # Indexes - indexes = self.get_indexes(table_name, schema) + indexes = _reflect_info.indexes.get(table_key, []) for index_d in indexes: name = index_d["name"] columns = index_d["column_names"] @@ -1034,6 +1909,7 @@ class Inspector(inspection.Inspectable["Inspector"]): continue # look for columns by orig name in cols_by_orig_name, # but support columns that are in-Python only as fallback + idx_col: Any idx_cols = [] for c in columns: try: @@ -1045,7 +1921,7 @@ class Inspector(inspection.Inspectable["Inspector"]): except KeyError: util.warn( "%s key '%s' was not located in " - "columns for table '%s'" % (flavor, c, table_name) + "columns for table '%s'" % (flavor, c, table.name) ) continue c_sorting = column_sorting.get(c, ()) @@ -1063,22 +1939,16 @@ class Inspector(inspection.Inspectable["Inspector"]): def _reflect_unique_constraints( self, - table_name, - schema, - table, - cols_by_orig_name, - include_columns, - exclude_columns, - reflection_options, - ): - + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + 
cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: + constraints = _reflect_info.unique_constraints.get(table_key, []) # Unique Constraints - try: - constraints = self.get_unique_constraints(table_name, schema) - except NotImplementedError: - # optional dialect feature - return - for const_d in constraints: conname = const_d["name"] columns = const_d["column_names"] @@ -1104,7 +1974,7 @@ class Inspector(inspection.Inspectable["Inspector"]): except KeyError: util.warn( "unique constraint key '%s' was not located in " - "columns for table '%s'" % (c, table_name) + "columns for table '%s'" % (c, table.name) ) else: constrained_cols.append(constrained_col) @@ -1114,29 +1984,166 @@ class Inspector(inspection.Inspectable["Inspector"]): def _reflect_check_constraints( self, - table_name, - schema, - table, - cols_by_orig_name, - include_columns, - exclude_columns, - reflection_options, - ): - try: - constraints = self.get_check_constraints(table_name, schema) - except NotImplementedError: - # optional dialect feature - return - + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: + constraints = _reflect_info.check_constraints.get(table_key, []) for const_d in constraints: table.append_constraint(sa_schema.CheckConstraint(**const_d)) def _reflect_table_comment( - self, table_name, schema, table, reflection_options - ): - try: - comment_dict = self.get_table_comment(table_name, schema) - except NotImplementedError: - return + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + reflection_options: Dict[str, Any], + ) -> None: + comment_dict = _reflect_info.table_comment.get(table_key) + if comment_dict: + table.comment = comment_dict["text"] + + def _get_reflection_info( + self, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + available: Optional[Collection[str]] = None, + _reflect_info: Optional[_ReflectionInfo] = None, + **kw: Any, + ) -> _ReflectionInfo: + kw["schema"] = schema + + if filter_names and available and len(filter_names) > 100: + fraction = len(filter_names) / len(available) + else: + fraction = None + + unreflectable: Dict[TableKey, exc.UnreflectableTableError] + kw["unreflectable"] = unreflectable = {} + + has_result: bool = True + + def run( + meth: Any, + *, + optional: bool = False, + check_filter_names_from_meth: bool = False, + ) -> Any: + nonlocal has_result + # simple heuristic to improve reflection performance if a + # dialect implements multi_reflection: + # if more than 50% of the tables in the db are in filter_names + # load all the tables, since it's most likely faster to avoid + # a filter on that many tables. + if ( + fraction is None + or fraction <= 0.5 + or not self.dialect._overrides_default(meth.__name__) + ): + _fn = filter_names + else: + _fn = None + try: + if has_result: + res = meth(filter_names=_fn, **kw) + if check_filter_names_from_meth and not res: + # method returned no result data. 
+ # skip any future call methods + has_result = False + else: + res = {} + except NotImplementedError: + if not optional: + raise + res = {} + return res + + info = _ReflectionInfo( + columns=run( + self.get_multi_columns, check_filter_names_from_meth=True + ), + pk_constraint=run(self.get_multi_pk_constraint), + foreign_keys=run(self.get_multi_foreign_keys), + indexes=run(self.get_multi_indexes), + unique_constraints=run( + self.get_multi_unique_constraints, optional=True + ), + table_comment=run(self.get_multi_table_comment, optional=True), + check_constraints=run( + self.get_multi_check_constraints, optional=True + ), + table_options=run(self.get_multi_table_options, optional=True), + unreflectable=unreflectable, + ) + if _reflect_info: + _reflect_info.update(info) + return _reflect_info else: - table.comment = comment_dict.get("text", None) + return info + + +@final +class ReflectionDefaults: + """provides blank default values for reflection methods.""" + + @classmethod + def columns(cls) -> List[ReflectedColumn]: + return [] + + @classmethod + def pk_constraint(cls) -> ReflectedPrimaryKeyConstraint: + return { # type: ignore # pep-655 not supported + "name": None, + "constrained_columns": [], + } + + @classmethod + def foreign_keys(cls) -> List[ReflectedForeignKeyConstraint]: + return [] + + @classmethod + def indexes(cls) -> List[ReflectedIndex]: + return [] + + @classmethod + def unique_constraints(cls) -> List[ReflectedUniqueConstraint]: + return [] + + @classmethod + def check_constraints(cls) -> List[ReflectedCheckConstraint]: + return [] + + @classmethod + def table_options(cls) -> Dict[str, Any]: + return {} + + @classmethod + def table_comment(cls) -> ReflectedTableComment: + return {"text": None} + + +@dataclass +class _ReflectionInfo: + columns: Dict[TableKey, List[ReflectedColumn]] + pk_constraint: Dict[TableKey, Optional[ReflectedPrimaryKeyConstraint]] + foreign_keys: Dict[TableKey, List[ReflectedForeignKeyConstraint]] + indexes: Dict[TableKey, List[ReflectedIndex]] + # optionals + unique_constraints: Dict[TableKey, List[ReflectedUniqueConstraint]] + table_comment: Dict[TableKey, Optional[ReflectedTableComment]] + check_constraints: Dict[TableKey, List[ReflectedCheckConstraint]] + table_options: Dict[TableKey, Dict[str, Any]] + unreflectable: Dict[TableKey, exc.UnreflectableTableError] + + def update(self, other: _ReflectionInfo) -> None: + for k, v in self.__dict__.items(): + ov = getattr(other, k) + if ov is not None: + if v is None: + setattr(self, k, ov) + else: + v.update(ov) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 391f747726..70c01d8d3c 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -536,7 +536,7 @@ class DialectKWArgs: util.portable_instancemethod(self._kw_reg_for_dialect_cls) ) - def _validate_dialect_kwargs(self, kwargs: Any) -> None: + def _validate_dialect_kwargs(self, kwargs: Dict[str, Any]) -> None: # validate remaining kwargs that they all specify DB prefixes if not kwargs: diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index c16fbdae13..5922c2db01 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -12,6 +12,7 @@ from itertools import zip_longest import typing from typing import Any from typing import Dict +from typing import Iterable from typing import Iterator from typing import List from typing import MutableMapping @@ -546,6 +547,43 @@ class CacheKey(NamedTuple): return target_element.params(translate) +def 
_ad_hoc_cache_key_from_args( + tokens: Tuple[Any, ...], + traverse_args: Iterable[Tuple[str, InternalTraversal]], + args: Iterable[Any], +) -> Tuple[Any, ...]: + """a quick cache key generator used by reflection.flexi_cache.""" + bindparams: List[BindParameter[Any]] = [] + + _anon_map = anon_map() + + tup = tokens + + for (attrname, sym), arg in zip(traverse_args, args): + key = sym.name + visit_key = key.replace("dp_", "visit_") + + if arg is None: + tup += (attrname, None) + continue + + meth = getattr(_cache_key_traversal_visitor, visit_key) + if meth is CACHE_IN_PLACE: + tup += (attrname, arg) + elif meth in ( + CALL_GEN_CACHE_KEY, + STATIC_CACHE_KEY, + ANON_NAME, + PROPAGATE_ATTRS, + ): + raise NotImplementedError( + f"Haven't implemented symbol {meth} for ad-hoc key from args" + ) + else: + tup += meth(attrname, arg, None, _anon_map, bindparams) + return tup + + class _CacheKeyTraversal(HasTraversalDispatch): # very common elements are inlined into the main _get_cache_key() method # to produce a dramatic savings in Python function call overhead diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index c37b600039..1c4b3b0cee 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -38,6 +38,7 @@ import typing from typing import Any from typing import Callable from typing import cast +from typing import Collection from typing import Dict from typing import Iterable from typing import Iterator @@ -99,6 +100,7 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.interfaces import ExecutionContext from ..engine.mock import MockConnection + from ..engine.reflection import _ReflectionInfo from ..sql.selectable import FromClause _T = TypeVar("_T", bound="Any") @@ -493,7 +495,7 @@ class Table( keep_existing: bool = False, extend_existing: bool = False, resolve_fks: bool = True, - include_columns: Optional[Iterable[str]] = None, + include_columns: Optional[Collection[str]] = None, implicit_returning: bool = True, comment: Optional[str] = None, info: Optional[Dict[Any, Any]] = None, @@ -829,6 +831,7 @@ class Table( self.fullname = self.name self.implicit_returning = implicit_returning + _reflect_info = kw.pop("_reflect_info", None) self.comment = comment @@ -852,6 +855,7 @@ class Table( autoload_with, include_columns, _extend_on=_extend_on, + _reflect_info=_reflect_info, resolve_fks=resolve_fks, ) @@ -869,10 +873,11 @@ class Table( self, metadata: MetaData, autoload_with: Union[Engine, Connection], - include_columns: Optional[Iterable[str]], - exclude_columns: Iterable[str] = (), + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str] = (), resolve_fks: bool = True, _extend_on: Optional[Set[Table]] = None, + _reflect_info: _ReflectionInfo | None = None, ) -> None: insp = inspection.inspect(autoload_with) with insp._inspection_context() as conn_insp: @@ -882,6 +887,7 @@ class Table( exclude_columns, resolve_fks, _extend_on=_extend_on, + _reflect_info=_reflect_info, ) @property @@ -924,6 +930,7 @@ class Table( autoload_replace = kwargs.pop("autoload_replace", True) schema = kwargs.pop("schema", None) _extend_on = kwargs.pop("_extend_on", None) + _reflect_info = kwargs.pop("_reflect_info", None) # these arguments are only used with _init() kwargs.pop("extend_existing", False) kwargs.pop("keep_existing", False) @@ -972,6 +979,7 @@ class Table( exclude_columns, resolve_fks, _extend_on=_extend_on, + _reflect_info=_reflect_info, ) self._extra_kwargs(**kwargs) @@ -3165,7 +3173,7 @@ class 
IdentityOptions: nominvalue: Optional[bool] = None, nomaxvalue: Optional[bool] = None, cycle: Optional[bool] = None, - cache: Optional[bool] = None, + cache: Optional[int] = None, order: Optional[bool] = None, ) -> None: """Construct a :class:`.IdentityOptions` object. @@ -5130,6 +5138,7 @@ class MetaData(HasSchemaAttr): sorted(self.tables.values(), key=lambda t: t.key) # type: ignore ) + @util.preload_module("sqlalchemy.engine.reflection") def reflect( self, bind: Union[Engine, Connection], @@ -5159,7 +5168,7 @@ class MetaData(HasSchemaAttr): is used, if any. :param views: - If True, also reflect views. + If True, also reflect views (materialized and plain). :param only: Optional. Load only a sub-set of available named tables. May be @@ -5225,7 +5234,7 @@ class MetaData(HasSchemaAttr): """ with inspection.inspect(bind)._inspection_context() as insp: - reflect_opts = { + reflect_opts: Any = { "autoload_with": insp, "extend_existing": extend_existing, "autoload_replace": autoload_replace, @@ -5241,15 +5250,21 @@ class MetaData(HasSchemaAttr): if schema is not None: reflect_opts["schema"] = schema + kind = util.preloaded.engine_reflection.ObjectKind.TABLE available: util.OrderedSet[str] = util.OrderedSet( insp.get_table_names(schema) ) if views: + kind = util.preloaded.engine_reflection.ObjectKind.ANY available.update(insp.get_view_names(schema)) + try: + available.update(insp.get_materialized_view_names(schema)) + except NotImplementedError: + pass if schema is not None: available_w_schema: util.OrderedSet[str] = util.OrderedSet( - ["%s.%s" % (schema, name) for name in available] + [f"{schema}.{name}" for name in available] ) else: available_w_schema = available @@ -5282,6 +5297,17 @@ class MetaData(HasSchemaAttr): for name in only if extend_existing or name not in current ] + # pass the available tables so the inspector can + # choose to ignore the filter_names + _reflect_info = insp._get_reflection_info( + schema=schema, + filter_names=load, + available=available, + kind=kind, + scope=util.preloaded.engine_reflection.ObjectScope.ANY, + **dialect_kwargs, + ) + reflect_opts["_reflect_info"] = _reflect_info for name in load: try: @@ -5489,7 +5515,7 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem): nominvalue: Optional[bool] = None, nomaxvalue: Optional[bool] = None, cycle: Optional[bool] = None, - cache: Optional[bool] = None, + cache: Optional[int] = None, order: Optional[bool] = None, ) -> None: """Construct a GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY DDL diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 9888d7c18d..9377063632 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -644,13 +644,21 @@ class AssertsCompiledSQL: class ComparesTables: - def assert_tables_equal(self, table, reflected_table, strict_types=False): + def assert_tables_equal( + self, + table, + reflected_table, + strict_types=False, + strict_constraints=True, + ): assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): eq_(c.name, reflected_c.name) assert reflected_c is reflected_table.c[c.name] - eq_(c.primary_key, reflected_c.primary_key) - eq_(c.nullable, reflected_c.nullable) + + if strict_constraints: + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) if strict_types: msg = "Type '%s' doesn't correspond to type '%s'" @@ -664,18 +672,20 @@ class ComparesTables: if isinstance(c.type, sqltypes.String): eq_(c.type.length, 
reflected_c.type.length) - eq_( - {f.column.name for f in c.foreign_keys}, - {f.column.name for f in reflected_c.foreign_keys}, - ) + if strict_constraints: + eq_( + {f.column.name for f in c.foreign_keys}, + {f.column.name for f in reflected_c.foreign_keys}, + ) if c.server_default: assert isinstance( reflected_c.server_default, schema.FetchedValue ) - assert len(table.primary_key) == len(reflected_table.primary_key) - for c in table.primary_key: - assert reflected_table.primary_key.columns[c.name] is not None + if strict_constraints: + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] is not None def assert_types_base(self, c1, c2): assert c1.type._compare_type_affinity( diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index fa7d2ca19b..cea07b3053 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -741,13 +741,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): fn._sa_parametrize.append((argnames, pytest_params)) return fn else: + _fn_argnames = inspect.getfullargspec(fn).args[1:] if argnames is None: - _argnames = inspect.getfullargspec(fn).args[1:] + _argnames = _fn_argnames else: _argnames = re.split(r", *", argnames) if has_exclusions: - _argnames += ["_exclusions"] + existing_exl = sum( + 1 for n in _fn_argnames if n.startswith("_exclusions") + ) + current_exclusion_name = f"_exclusions_{existing_exl}" + _argnames += [current_exclusion_name] @_pytest_fn_decorator def check_exclusions(fn, *args, **kw): @@ -755,13 +760,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): if _exclusions: exlu = exclusions.compound().add(*_exclusions) fn = exlu(fn) - return fn(*args[0:-1], **kw) - - def process_metadata(spec): - spec.args.append("_exclusions") + return fn(*args[:-1], **kw) fn = check_exclusions( - fn, add_positional_parameters=("_exclusions",) + fn, add_positional_parameters=(current_exclusion_name,) ) return pytest.mark.parametrize(_argnames, pytest_params)(fn) diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index d384377324..498d92a775 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -230,7 +230,39 @@ def drop_all_schema_objects(cfg, eng): drop_all_schema_objects_pre_tables(cfg, eng) + drop_views(cfg, eng) + + if config.requirements.materialized_views.enabled: + drop_materialized_views(cfg, eng) + inspector = inspect(eng) + + consider_schemas = (None,) + if config.requirements.schemas.enabled_for_config(cfg): + consider_schemas += (cfg.test_schema, cfg.test_schema_2) + util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas) + + drop_all_schema_objects_post_tables(cfg, eng) + + if config.requirements.sequences.enabled_for_config(cfg): + with eng.begin() as conn: + for seq in inspector.get_sequence_names(): + conn.execute(ddl.DropSequence(schema.Sequence(seq))) + if config.requirements.schemas.enabled_for_config(cfg): + for schema_name in [cfg.test_schema, cfg.test_schema_2]: + for seq in inspector.get_sequence_names( + schema=schema_name + ): + conn.execute( + ddl.DropSequence( + schema.Sequence(seq, schema=schema_name) + ) + ) + + +def drop_views(cfg, eng): + inspector = inspect(eng) + try: view_names = inspector.get_view_names() except NotImplementedError: @@ -244,7 +276,7 @@ def drop_all_schema_objects(cfg, eng): if 
config.requirements.schemas.enabled_for_config(cfg): try: - view_names = inspector.get_view_names(schema="test_schema") + view_names = inspector.get_view_names(schema=cfg.test_schema) except NotImplementedError: pass else: @@ -255,32 +287,30 @@ def drop_all_schema_objects(cfg, eng): schema.Table( vname, schema.MetaData(), - schema="test_schema", + schema=cfg.test_schema, ) ) ) - util.drop_all_tables(eng, inspector) - if config.requirements.schemas.enabled_for_config(cfg): - util.drop_all_tables(eng, inspector, schema=cfg.test_schema) - util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2) - drop_all_schema_objects_post_tables(cfg, eng) +def drop_materialized_views(cfg, eng): + inspector = inspect(eng) - if config.requirements.sequences.enabled_for_config(cfg): + mview_names = inspector.get_materialized_view_names() + + with eng.begin() as conn: + for vname in mview_names: + conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}") + + if config.requirements.schemas.enabled_for_config(cfg): + mview_names = inspector.get_materialized_view_names( + schema=cfg.test_schema + ) with eng.begin() as conn: - for seq in inspector.get_sequence_names(): - conn.execute(ddl.DropSequence(schema.Sequence(seq))) - if config.requirements.schemas.enabled_for_config(cfg): - for schema_name in [cfg.test_schema, cfg.test_schema_2]: - for seq in inspector.get_sequence_names( - schema=schema_name - ): - conn.execute( - ddl.DropSequence( - schema.Sequence(seq, schema=schema_name) - ) - ) + for vname in mview_names: + conn.exec_driver_sql( + f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}" + ) @register.init diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 4f9c73cf6e..038f6e9bd1 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -64,6 +64,25 @@ class SuiteRequirements(Requirements): return exclusions.open() + @property + def foreign_keys_reflect_as_index(self): + """Target database creates an index that's reflected for + foreign keys.""" + + return exclusions.closed() + + @property + def unique_index_reflect_as_unique_constraints(self): + """Target database reflects unique indexes as unique constrains.""" + + return exclusions.closed() + + @property + def unique_constraints_reflect_as_index(self): + """Target database reflects unique constraints as indexes.""" + + return exclusions.closed() + @property def table_value_constructor(self): """Database / dialect supports a query like:: @@ -628,6 +647,12 @@ class SuiteRequirements(Requirements): def schema_reflection(self): return self.schemas + @property + def schema_create_delete(self): + """target database supports schema create and dropped with + 'CREATE SCHEMA' and 'DROP SCHEMA'""" + return exclusions.closed() + @property def primary_key_constraint_reflection(self): return exclusions.open() @@ -692,6 +717,12 @@ class SuiteRequirements(Requirements): """target database supports CREATE INDEX with per-column ASC/DESC.""" return exclusions.open() + @property + def reflect_indexes_with_ascdesc(self): + """target database supports reflecting INDEX with per-column + ASC/DESC.""" + return exclusions.open() + @property def indexes_with_expressions(self): """target database supports CREATE INDEX against SQL expressions.""" @@ -1567,3 +1598,18 @@ class SuiteRequirements(Requirements): def json_deserializer_binary(self): "indicates if the json_deserializer function is called with bytes" return exclusions.closed() + + @property + def reflect_table_options(self): + 
"""Target database must support reflecting table_options.""" + return exclusions.closed() + + @property + def materialized_views(self): + """Target database must support MATERIALIZED VIEWs.""" + return exclusions.closed() + + @property + def materialized_views_reflect_pk(self): + """Target database reflect MATERIALIZED VIEWs pks.""" + return exclusions.closed() diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index e4a92a7320..46cbf4759e 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -23,7 +23,7 @@ __all__ = ["Table", "Column"] table_options = {} -def Table(*args, **kw): +def Table(*args, **kw) -> schema.Table: """A schema.Table wrapper/hook for dialect-specific tweaks.""" test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} @@ -134,6 +134,19 @@ class eq_type_affinity: return self.target._type_affinity is not other._type_affinity +class eq_compile_type: + """similar to eq_type_affinity but uses compile""" + + def __init__(self, target): + self.target = target + + def __eq__(self, other): + return self.target == other.compile() + + def __ne__(self, other): + return self.target != other.compile() + + class eq_clause_element: """Helper to compare SQL structures based on compare()""" diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index b09b962275..7b8e2aa8bd 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -7,6 +7,8 @@ import sqlalchemy as sa from .. import config from .. import engines from .. import eq_ +from .. import expect_raises +from .. import expect_raises_message from .. import expect_warnings from .. import fixtures from .. import is_ @@ -24,12 +26,19 @@ from ... import MetaData from ... import String from ... import testing from ... 
import types as sql_types +from ...engine import Inspector +from ...engine import ObjectKind +from ...engine import ObjectScope +from ...exc import NoSuchTableError +from ...exc import UnreflectableTableError from ...schema import DDL from ...schema import Index from ...sql.elements import quoted_name from ...sql.schema import BLANK_SCHEMA +from ...testing import ComparesTables from ...testing import is_false from ...testing import is_true +from ...testing import mock metadata, users = None, None @@ -61,6 +70,19 @@ class HasTableTest(fixtures.TablesTest): is_false(config.db.dialect.has_table(conn, "test_table_s")) is_false(config.db.dialect.has_table(conn, "nonexistent_table")) + def test_has_table_cache(self, metadata): + insp = inspect(config.db) + is_true(insp.has_table("test_table")) + nt = Table("new_table", metadata, Column("col", Integer)) + is_false(insp.has_table("new_table")) + nt.create(config.db) + try: + is_false(insp.has_table("new_table")) + insp.clear_cache() + is_true(insp.has_table("new_table")) + finally: + nt.drop(config.db) + @testing.requires.schemas def test_has_table_schema(self): with config.db.begin() as conn: @@ -117,6 +139,7 @@ class HasIndexTest(fixtures.TablesTest): metadata, Column("id", Integer, primary_key=True), Column("data", String(50)), + Column("data2", String(50)), ) Index("my_idx", tt.c.data) @@ -130,40 +153,56 @@ class HasIndexTest(fixtures.TablesTest): ) Index("my_idx_s", tt.c.data) - def test_has_index(self): - with config.db.begin() as conn: - assert config.db.dialect.has_index(conn, "test_table", "my_idx") - assert not config.db.dialect.has_index( - conn, "test_table", "my_idx_s" - ) - assert not config.db.dialect.has_index( - conn, "nonexistent_table", "my_idx" - ) - assert not config.db.dialect.has_index( - conn, "test_table", "nonexistent_idx" - ) + kind = testing.combinations("dialect", "inspector", argnames="kind") + + def _has_index(self, kind, conn): + if kind == "dialect": + return lambda *a, **k: config.db.dialect.has_index(conn, *a, **k) + else: + return inspect(conn).has_index + + @kind + def test_has_index(self, kind, connection, metadata): + meth = self._has_index(kind, connection) + assert meth("test_table", "my_idx") + assert not meth("test_table", "my_idx_s") + assert not meth("nonexistent_table", "my_idx") + assert not meth("test_table", "nonexistent_idx") + + assert not meth("test_table", "my_idx_2") + assert not meth("test_table_2", "my_idx_3") + idx = Index("my_idx_2", self.tables.test_table.c.data2) + tbl = Table( + "test_table_2", + metadata, + Column("foo", Integer), + Index("my_idx_3", "foo"), + ) + idx.create(connection) + tbl.create(connection) + try: + if kind == "inspector": + assert not meth("test_table", "my_idx_2") + assert not meth("test_table_2", "my_idx_3") + meth.__self__.clear_cache() + assert meth("test_table", "my_idx_2") is True + assert meth("test_table_2", "my_idx_3") is True + finally: + tbl.drop(connection) + idx.drop(connection) @testing.requires.schemas - def test_has_index_schema(self): - with config.db.begin() as conn: - assert config.db.dialect.has_index( - conn, "test_table", "my_idx_s", schema=config.test_schema - ) - assert not config.db.dialect.has_index( - conn, "test_table", "my_idx", schema=config.test_schema - ) - assert not config.db.dialect.has_index( - conn, - "nonexistent_table", - "my_idx_s", - schema=config.test_schema, - ) - assert not config.db.dialect.has_index( - conn, - "test_table", - "nonexistent_idx_s", - schema=config.test_schema, - ) + @kind + def test_has_index_schema(self, 
kind, connection): + meth = self._has_index(kind, connection) + assert meth("test_table", "my_idx_s", schema=config.test_schema) + assert not meth("test_table", "my_idx", schema=config.test_schema) + assert not meth( + "nonexistent_table", "my_idx_s", schema=config.test_schema + ) + assert not meth( + "test_table", "nonexistent_idx_s", schema=config.test_schema + ) class QuotedNameArgumentTest(fixtures.TablesTest): @@ -264,7 +303,12 @@ class QuotedNameArgumentTest(fixtures.TablesTest): def test_get_table_options(self, name): insp = inspect(config.db) - insp.get_table_options(name) + if testing.requires.reflect_table_options.enabled: + res = insp.get_table_options(name) + is_true(isinstance(res, dict)) + else: + with expect_raises(NotImplementedError): + res = insp.get_table_options(name) @quote_fixtures @testing.requires.view_column_reflection @@ -311,7 +355,37 @@ class QuotedNameArgumentTest(fixtures.TablesTest): assert insp.get_check_constraints(name) -class ComponentReflectionTest(fixtures.TablesTest): +def _multi_combination(fn): + schema = testing.combinations( + None, + ( + lambda: config.test_schema, + testing.requires.schemas, + ), + argnames="schema", + ) + scope = testing.combinations( + ObjectScope.DEFAULT, + ObjectScope.TEMPORARY, + ObjectScope.ANY, + argnames="scope", + ) + kind = testing.combinations( + ObjectKind.TABLE, + ObjectKind.VIEW, + ObjectKind.MATERIALIZED_VIEW, + ObjectKind.ANY, + ObjectKind.ANY_VIEW, + ObjectKind.TABLE | ObjectKind.VIEW, + ObjectKind.TABLE | ObjectKind.MATERIALIZED_VIEW, + argnames="kind", + ) + filter_names = testing.combinations(True, False, argnames="use_filter") + + return schema(scope(kind(filter_names(fn)))) + + +class ComponentReflectionTest(ComparesTables, fixtures.TablesTest): run_inserts = run_deletes = None __backend__ = True @@ -354,6 +428,7 @@ class ComponentReflectionTest(fixtures.TablesTest): "%susers.user_id" % schema_prefix, name="user_id_fk" ), ), + sa.CheckConstraint("test2 > 0", name="test2_gt_zero"), schema=schema, test_needs_fk=True, ) @@ -364,6 +439,8 @@ class ComponentReflectionTest(fixtures.TablesTest): Column("user_id", sa.INT, primary_key=True), Column("test1", sa.CHAR(5), nullable=False), Column("test2", sa.Float(), nullable=False), + Column("parent_user_id", sa.Integer), + sa.CheckConstraint("test2 > 0", name="test2_gt_zero"), schema=schema, test_needs_fk=True, ) @@ -375,9 +452,19 @@ class ComponentReflectionTest(fixtures.TablesTest): Column( "address_id", sa.Integer, - sa.ForeignKey("%semail_addresses.address_id" % schema_prefix), + sa.ForeignKey( + "%semail_addresses.address_id" % schema_prefix, + name="email_add_id_fg", + ), + ), + Column("data", sa.String(30), unique=True), + sa.CheckConstraint( + "address_id > 0 AND address_id < 1000", + name="address_id_gt_zero", + ), + sa.UniqueConstraint( + "address_id", "dingaling_id", name="zz_dingalings_multiple" ), - Column("data", sa.String(30)), schema=schema, test_needs_fk=True, ) @@ -388,7 +475,7 @@ class ComponentReflectionTest(fixtures.TablesTest): Column( "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id) ), - Column("email_address", sa.String(20)), + Column("email_address", sa.String(20), index=True), sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"), schema=schema, test_needs_fk=True, @@ -406,6 +493,12 @@ class ComponentReflectionTest(fixtures.TablesTest): schema=schema, comment=r"""the test % ' " \ table comment""", ) + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + ) if 
testing.requires.cross_schema_fk_reflection.enabled: if schema is None: @@ -449,7 +542,10 @@ class ComponentReflectionTest(fixtures.TablesTest): ) if testing.requires.index_reflection.enabled: - cls.define_index(metadata, users) + Index("users_t_idx", users.c.test1, users.c.test2, unique=True) + Index( + "users_all_idx", users.c.user_id, users.c.test2, users.c.test1 + ) if not schema: # test_needs_fk is at the moment to force MySQL InnoDB @@ -468,7 +564,10 @@ class ComponentReflectionTest(fixtures.TablesTest): test_needs_fk=True, ) - if testing.requires.indexes_with_ascdesc.enabled: + if ( + testing.requires.indexes_with_ascdesc.enabled + and testing.requires.reflect_indexes_with_ascdesc.enabled + ): Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) @@ -477,12 +576,16 @@ class ComponentReflectionTest(fixtures.TablesTest): if not schema and testing.requires.temp_table_reflection.enabled: cls.define_temp_tables(metadata) + @classmethod + def temp_table_name(cls): + return get_temp_table_name( + config, config.db, f"user_tmp_{config.ident}" + ) + @classmethod def define_temp_tables(cls, metadata): kw = temp_table_keyword_args(config, config.db) - table_name = get_temp_table_name( - config, config.db, "user_tmp_%s" % config.ident - ) + table_name = cls.temp_table_name() user_tmp = Table( table_name, metadata, @@ -495,7 +598,7 @@ class ComponentReflectionTest(fixtures.TablesTest): # unique constraints created against temp tables in different # databases. # https://www.arbinada.com/en/node/1645 - sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident), + sa.UniqueConstraint("name", name=f"user_tmp_uq_{config.ident}"), sa.Index("user_tmp_ix", "foo"), **kw, ) @@ -513,33 +616,636 @@ class ComponentReflectionTest(fixtures.TablesTest): ) event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) - @classmethod - def define_index(cls, metadata, users): - Index("users_t_idx", users.c.test1, users.c.test2) - Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1) - @classmethod def define_views(cls, metadata, schema): - for table_name in ("users", "email_addresses"): + if testing.requires.materialized_views.enabled: + materialized = {"dingalings"} + else: + materialized = set() + for table_name in ("users", "email_addresses", "dingalings"): fullname = table_name if schema: - fullname = "%s.%s" % (schema, table_name) + fullname = f"{schema}.{table_name}" view_name = fullname + "_v" - query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - view_name, - fullname, + prefix = "MATERIALIZED " if table_name in materialized else "" + query = ( + f"CREATE {prefix}VIEW {view_name} AS SELECT * FROM {fullname}" ) event.listen(metadata, "after_create", DDL(query)) + if table_name in materialized: + index_name = "mat_index" + if schema and testing.against("oracle"): + index_name = f"{schema}.{index_name}" + idx = f"CREATE INDEX {index_name} ON {view_name}(data)" + event.listen(metadata, "after_create", DDL(idx)) event.listen( - metadata, "before_drop", DDL("DROP VIEW %s" % view_name) + metadata, "before_drop", DDL(f"DROP {prefix}VIEW {view_name}") + ) + + def _resolve_kind(self, kind, tables, views, materialized): + res = {} + if ObjectKind.TABLE in kind: + res.update(tables) + if ObjectKind.VIEW in kind: + res.update(views) + if ObjectKind.MATERIALIZED_VIEW in kind: + res.update(materialized) + return res + + def _resolve_views(self, views, materialized): + if not testing.requires.view_column_reflection.enabled: + 
materialized.clear() + views.clear() + elif not testing.requires.materialized_views.enabled: + views.update(materialized) + materialized.clear() + + def _resolve_names(self, schema, scope, filter_names, values): + scope_filter = lambda _: True # noqa: E731 + if scope is ObjectScope.DEFAULT: + scope_filter = lambda k: "tmp" not in k[1] # noqa: E731 + if scope is ObjectScope.TEMPORARY: + scope_filter = lambda k: "tmp" in k[1] # noqa: E731 + + removed = { + None: {"remote_table", "remote_table_2"}, + testing.config.test_schema: { + "local_table", + "noncol_idx_test_nopk", + "noncol_idx_test_pk", + "user_tmp_v", + self.temp_table_name(), + }, + } + if not testing.requires.cross_schema_fk_reflection.enabled: + removed[None].add("local_table") + removed[testing.config.test_schema].update( + ["remote_table", "remote_table_2"] + ) + if not testing.requires.index_reflection.enabled: + removed[None].update( + ["noncol_idx_test_nopk", "noncol_idx_test_pk"] ) + if ( + not testing.requires.temp_table_reflection.enabled + or not testing.requires.temp_table_names.enabled + ): + removed[None].update(["user_tmp_v", self.temp_table_name()]) + if not testing.requires.temporary_views.enabled: + removed[None].update(["user_tmp_v"]) + + res = { + k: v + for k, v in values.items() + if scope_filter(k) + and k[1] not in removed[schema] + and (not filter_names or k[1] in filter_names) + } + return res + + def exp_options( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + materialized = {(schema, "dingalings_v"): mock.ANY} + views = { + (schema, "email_addresses_v"): mock.ANY, + (schema, "users_v"): mock.ANY, + (schema, "user_tmp_v"): mock.ANY, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): mock.ANY, + (schema, "dingalings"): mock.ANY, + (schema, "email_addresses"): mock.ANY, + (schema, "comment_test"): mock.ANY, + (schema, "no_constraints"): mock.ANY, + (schema, "local_table"): mock.ANY, + (schema, "remote_table"): mock.ANY, + (schema, "remote_table_2"): mock.ANY, + (schema, "noncol_idx_test_nopk"): mock.ANY, + (schema, "noncol_idx_test_pk"): mock.ANY, + (schema, self.temp_table_name()): mock.ANY, + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def exp_comments( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + empty = {"text": None} + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): empty, + (schema, "dingalings"): empty, + (schema, "email_addresses"): empty, + (schema, "comment_test"): { + "text": r"""the test % ' " \ table comment""" + }, + (schema, "no_constraints"): empty, + (schema, "local_table"): empty, + (schema, "remote_table"): empty, + (schema, "remote_table_2"): empty, + (schema, "noncol_idx_test_nopk"): empty, + (schema, "noncol_idx_test_pk"): empty, + (schema, self.temp_table_name()): empty, + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def exp_columns( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def col( + name, auto=False, default=mock.ANY, comment=None, nullable=True + ): + res = { + "name": name, + "autoincrement": auto, + 
"type": mock.ANY, + "default": default, + "comment": comment, + "nullable": nullable, + } + if auto == "omit": + res.pop("autoincrement") + return res + + def pk(name, **kw): + kw = {"auto": True, "default": mock.ANY, "nullable": False, **kw} + return col(name, **kw) + + materialized = { + (schema, "dingalings_v"): [ + col("dingaling_id", auto="omit", nullable=mock.ANY), + col("address_id"), + col("data"), + ] + } + views = { + (schema, "email_addresses_v"): [ + col("address_id", auto="omit", nullable=mock.ANY), + col("remote_user_id"), + col("email_address"), + ], + (schema, "users_v"): [ + col("user_id", auto="omit", nullable=mock.ANY), + col("test1", nullable=mock.ANY), + col("test2", nullable=mock.ANY), + col("parent_user_id"), + ], + (schema, "user_tmp_v"): [ + col("id", auto="omit", nullable=mock.ANY), + col("name"), + col("foo"), + ], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + pk("user_id"), + col("test1", nullable=False), + col("test2", nullable=False), + col("parent_user_id"), + ], + (schema, "dingalings"): [ + pk("dingaling_id"), + col("address_id"), + col("data"), + ], + (schema, "email_addresses"): [ + pk("address_id"), + col("remote_user_id"), + col("email_address"), + ], + (schema, "comment_test"): [ + pk("id", comment="id comment"), + col("data", comment="data % comment"), + col( + "d2", + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + ], + (schema, "no_constraints"): [col("data")], + (schema, "local_table"): [pk("id"), col("data"), col("remote_id")], + (schema, "remote_table"): [pk("id"), col("local_id"), col("data")], + (schema, "remote_table_2"): [pk("id"), col("data")], + (schema, "noncol_idx_test_nopk"): [col("q")], + (schema, "noncol_idx_test_pk"): [pk("id"), col("q")], + (schema, self.temp_table_name()): [ + pk("id"), + col("name"), + col("foo"), + ], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_column_keys(self): + return {"name", "type", "nullable", "default"} + + def exp_pks( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def pk(*cols, name=mock.ANY): + return {"constrained_columns": list(cols), "name": name} + + empty = pk(name=None) + if testing.requires.materialized_views_reflect_pk.enabled: + materialized = {(schema, "dingalings_v"): pk("dingaling_id")} + else: + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): pk("user_id"), + (schema, "dingalings"): pk("dingaling_id"), + (schema, "email_addresses"): pk("address_id", name="email_ad_pk"), + (schema, "comment_test"): pk("id"), + (schema, "no_constraints"): empty, + (schema, "local_table"): pk("id"), + (schema, "remote_table"): pk("id"), + (schema, "remote_table_2"): pk("id"), + (schema, "noncol_idx_test_nopk"): empty, + (schema, "noncol_idx_test_pk"): pk("id"), + (schema, self.temp_table_name()): pk("id"), + } + if not testing.requires.reflects_pk_names.enabled: + for val in tables.values(): + if val["name"] is not None: + val["name"] = mock.ANY + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_pk_keys(self): + return {"name", "constrained_columns"} + + def exp_fks( + 
self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + class tt: + def __eq__(self, other): + return ( + other is None + or config.db.dialect.default_schema_name == other + ) + + def fk(cols, ref_col, ref_table, ref_schema=schema, name=mock.ANY): + return { + "constrained_columns": cols, + "referred_columns": ref_col, + "name": name, + "options": mock.ANY, + "referred_schema": ref_schema + if ref_schema is not None + else tt(), + "referred_table": ref_table, + } + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + fk(["parent_user_id"], ["user_id"], "users", name="user_id_fk") + ], + (schema, "dingalings"): [ + fk( + ["address_id"], + ["address_id"], + "email_addresses", + name="email_add_id_fg", + ) + ], + (schema, "email_addresses"): [ + fk(["remote_user_id"], ["user_id"], "users") + ], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [ + fk( + ["remote_id"], + ["id"], + "remote_table_2", + ref_schema=config.test_schema, + ) + ], + (schema, "remote_table"): [ + fk(["local_id"], ["id"], "local_table", ref_schema=None) + ], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [], + } + if not testing.requires.self_referential_foreign_keys.enabled: + tables[(schema, "users")].clear() + if not testing.requires.named_constraints.enabled: + for vals in tables.values(): + for val in vals: + if val["name"] is not mock.ANY: + val["name"] = mock.ANY + + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_fk_keys(self): + return { + "name", + "constrained_columns", + "referred_schema", + "referred_table", + "referred_columns", + } + + def exp_indexes( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def idx( + *cols, + name, + unique=False, + column_sorting=None, + duplicates=False, + fk=False, + ): + fk_req = testing.requires.foreign_keys_reflect_as_index + dup_req = testing.requires.unique_constraints_reflect_as_index + if (fk and not fk_req.enabled) or ( + duplicates and not dup_req.enabled + ): + return () + res = { + "unique": unique, + "column_names": list(cols), + "name": name, + "dialect_options": mock.ANY, + "include_columns": [], + } + if column_sorting: + res["column_sorting"] = {"q": ("desc",)} + if duplicates: + res["duplicates_constraint"] = name + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + if materialized: + materialized[(schema, "dingalings_v")].extend( + idx("data", name="mat_index") + ) + tables = { + (schema, "users"): [ + *idx("parent_user_id", name="user_id_fk", fk=True), + *idx("user_id", "test2", "test1", name="users_all_idx"), + *idx("test1", "test2", name="users_t_idx", unique=True), + ], + (schema, "dingalings"): [ + *idx("data", name=mock.ANY, unique=True, duplicates=True), + *idx( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + unique=True, + duplicates=True, + ), + ], + (schema, "email_addresses"): [ + *idx("email_address", name=mock.ANY), + 
*idx("remote_user_id", name=mock.ANY, fk=True), + ], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [ + *idx("remote_id", name=mock.ANY, fk=True) + ], + (schema, "remote_table"): [ + *idx("local_id", name=mock.ANY, fk=True) + ], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [ + *idx( + "q", + name="noncol_idx_nopk", + column_sorting={"q": ("desc",)}, + ) + ], + (schema, "noncol_idx_test_pk"): [ + *idx( + "q", name="noncol_idx_pk", column_sorting={"q": ("desc",)} + ) + ], + (schema, self.temp_table_name()): [ + *idx("foo", name="user_tmp_ix"), + *idx( + "name", + name=f"user_tmp_uq_{config.ident}", + duplicates=True, + unique=True, + ), + ], + } + if ( + not testing.requires.indexes_with_ascdesc.enabled + or not testing.requires.reflect_indexes_with_ascdesc.enabled + ): + tables[(schema, "noncol_idx_test_nopk")].clear() + tables[(schema, "noncol_idx_test_pk")].clear() + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_index_keys(self): + return {"name", "column_names", "unique"} + + def exp_ucs( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + all_=False, + ): + def uc(*cols, name, duplicates_index=None, is_index=False): + req = testing.requires.unique_index_reflect_as_unique_constraints + if is_index and not req.enabled: + return () + res = { + "column_names": list(cols), + "name": name, + } + if duplicates_index: + res["duplicates_index"] = duplicates_index + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + *uc( + "test1", + "test2", + name="users_t_idx", + duplicates_index="users_t_idx", + is_index=True, + ) + ], + (schema, "dingalings"): [ + *uc("data", name=mock.ANY, duplicates_index=mock.ANY), + *uc( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + duplicates_index="zz_dingalings_multiple", + ), + ], + (schema, "email_addresses"): [], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [], + (schema, "remote_table"): [], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [ + *uc("name", name=f"user_tmp_uq_{config.ident}") + ], + } + if all_: + return {**materialized, **views, **tables} + else: + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_unique_cst_keys(self): + return {"name", "column_names"} + + def exp_ccs( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + class tt(str): + def __eq__(self, other): + res = ( + other.lower() + .replace("(", "") + .replace(")", "") + .replace("`", "") + ) + return self in res + + def cc(text, name): + return {"sqltext": tt(text), "name": name} + + # print({1: "test2 > (0)::double precision"} == {1: tt("test2 > 0")}) + # assert 0 + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [cc("test2 > 0", "test2_gt_zero")], + 
(schema, "dingalings"): [ + cc( + "address_id > 0 and address_id < 1000", + name="address_id_gt_zero", + ), + ], + (schema, "email_addresses"): [], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [], + (schema, "remote_table"): [], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_cc_keys(self): + return {"name", "sqltext"} @testing.requires.schema_reflection - def test_get_schema_names(self): - insp = inspect(self.bind) + def test_get_schema_names(self, connection): + insp = inspect(connection) - self.assert_(testing.config.test_schema in insp.get_schema_names()) + is_true(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_has_schema(self, connection): + insp = inspect(connection) + + is_true(insp.has_schema(testing.config.test_schema)) + is_false(insp.has_schema("sa_fake_schema_foo")) @testing.requires.schema_reflection def test_get_schema_names_w_translate_map(self, connection): @@ -553,7 +1259,37 @@ class ComponentReflectionTest(fixtures.TablesTest): ) insp = inspect(connection) - self.assert_(testing.config.test_schema in insp.get_schema_names()) + is_true(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_has_schema_w_translate_map(self, connection): + connection = connection.execution_options( + schema_translate_map={ + "foo": "bar", + BLANK_SCHEMA: testing.config.test_schema, + } + ) + insp = inspect(connection) + + is_true(insp.has_schema(testing.config.test_schema)) + is_false(insp.has_schema("sa_fake_schema_foo")) + + @testing.requires.schema_reflection + @testing.requires.schema_create_delete + def test_schema_cache(self, connection): + insp = inspect(connection) + + is_false("foo_bar" in insp.get_schema_names()) + is_false(insp.has_schema("foo_bar")) + connection.execute(DDL("CREATE SCHEMA foo_bar")) + try: + is_false("foo_bar" in insp.get_schema_names()) + is_false(insp.has_schema("foo_bar")) + insp.clear_cache() + is_true("foo_bar" in insp.get_schema_names()) + is_true(insp.has_schema("foo_bar")) + finally: + connection.execute(DDL("DROP SCHEMA foo_bar")) @testing.requires.schema_reflection def test_dialect_initialize(self): @@ -562,113 +1298,115 @@ class ComponentReflectionTest(fixtures.TablesTest): assert hasattr(engine.dialect, "default_schema_name") @testing.requires.schema_reflection - def test_get_default_schema_name(self): - insp = inspect(self.bind) - eq_(insp.default_schema_name, self.bind.dialect.default_schema_name) + def test_get_default_schema_name(self, connection): + insp = inspect(connection) + eq_(insp.default_schema_name, connection.dialect.default_schema_name) - @testing.requires.foreign_key_constraint_reflection @testing.combinations( - (None, True, False, False), - (None, True, False, True, testing.requires.schemas), - ("foreign_key", True, False, False), - (None, False, True, False), - (None, False, True, True, testing.requires.schemas), - (None, True, True, False), - (None, True, True, True, testing.requires.schemas), - argnames="order_by,include_plain,include_views,use_schema", + None, + ("foreign_key", testing.requires.foreign_key_constraint_reflection), + argnames="order_by", ) - def test_get_table_names( - self, connection, 
order_by, include_plain, include_views, use_schema - ): + @testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_table_names(self, connection, order_by, use_schema): if use_schema: schema = config.test_schema else: schema = None - _ignore_tables = [ + _ignore_tables = { "comment_test", "noncol_idx_test_pk", "noncol_idx_test_nopk", "local_table", "remote_table", "remote_table_2", - ] + "no_constraints", + } insp = inspect(connection) - if include_views: - table_names = insp.get_view_names(schema) - table_names.sort() - answer = ["email_addresses_v", "users_v"] - eq_(sorted(table_names), answer) + if order_by: + tables = [ + rec[0] + for rec in insp.get_sorted_table_and_fkc_names(schema) + if rec[0] + ] + else: + tables = insp.get_table_names(schema) + table_names = [t for t in tables if t not in _ignore_tables] - if include_plain: - if order_by: - tables = [ - rec[0] - for rec in insp.get_sorted_table_and_fkc_names(schema) - if rec[0] - ] - else: - tables = insp.get_table_names(schema) - table_names = [t for t in tables if t not in _ignore_tables] + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] + eq_(table_names, answer) + else: + answer = ["dingalings", "email_addresses", "users"] + eq_(sorted(table_names), answer) - if order_by == "foreign_key": - answer = ["users", "email_addresses", "dingalings"] - eq_(table_names, answer) - else: - answer = ["dingalings", "email_addresses", "users"] - eq_(sorted(table_names), answer) + @testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_view_names(self, connection, use_schema): + insp = inspect(connection) + if use_schema: + schema = config.test_schema + else: + schema = None + table_names = insp.get_view_names(schema) + if testing.requires.materialized_views.enabled: + eq_(sorted(table_names), ["email_addresses_v", "users_v"]) + eq_(insp.get_materialized_view_names(schema), ["dingalings_v"]) + else: + answer = ["dingalings_v", "email_addresses_v", "users_v"] + eq_(sorted(table_names), answer) @testing.requires.temp_table_names - def test_get_temp_table_names(self): - insp = inspect(self.bind) + def test_get_temp_table_names(self, connection): + insp = inspect(connection) temp_table_names = insp.get_temp_table_names() - eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident]) + eq_(sorted(temp_table_names), [f"user_tmp_{config.ident}"]) @testing.requires.view_reflection - @testing.requires.temp_table_names @testing.requires.temporary_views - def test_get_temp_view_names(self): - insp = inspect(self.bind) + def test_get_temp_view_names(self, connection): + insp = inspect(connection) temp_table_names = insp.get_temp_view_names() eq_(sorted(temp_table_names), ["user_tmp_v"]) @testing.requires.comment_reflection - def test_get_comments(self): - self._test_get_comments() + def test_get_comments(self, connection): + self._test_get_comments(connection) @testing.requires.comment_reflection @testing.requires.schemas - def test_get_comments_with_schema(self): - self._test_get_comments(testing.config.test_schema) - - def _test_get_comments(self, schema=None): - insp = inspect(self.bind) + def test_get_comments_with_schema(self, connection): + self._test_get_comments(connection, testing.config.test_schema) + def _test_get_comments(self, connection, schema=None): + insp = inspect(connection) + exp = self.exp_comments(schema=schema) eq_( insp.get_table_comment("comment_test", schema=schema), - {"text": r"""the test % ' 
" \ table comment"""}, + exp[(schema, "comment_test")], ) - eq_(insp.get_table_comment("users", schema=schema), {"text": None}) + eq_( + insp.get_table_comment("users", schema=schema), + exp[(schema, "users")], + ) eq_( - [ - {"name": rec["name"], "comment": rec["comment"]} - for rec in insp.get_columns("comment_test", schema=schema) - ], - [ - {"comment": "id comment", "name": "id"}, - {"comment": "data % comment", "name": "data"}, - { - "comment": ( - r"""Comment types type speedily ' " \ '' Fun!""" - ), - "name": "d2", - }, - ], + insp.get_table_comment("comment_test", schema=schema), + exp[(schema, "comment_test")], + ) + + no_cst = self.tables.no_constraints.name + eq_( + insp.get_table_comment(no_cst, schema=schema), + exp[(schema, no_cst)], ) @testing.combinations( @@ -691,7 +1429,7 @@ class ComponentReflectionTest(fixtures.TablesTest): users, addresses = (self.tables.users, self.tables.email_addresses) if use_views: - table_names = ["users_v", "email_addresses_v"] + table_names = ["users_v", "email_addresses_v", "dingalings_v"] else: table_names = ["users", "email_addresses"] @@ -699,7 +1437,7 @@ class ComponentReflectionTest(fixtures.TablesTest): for table_name, table in zip(table_names, (users, addresses)): schema_name = schema cols = insp.get_columns(table_name, schema=schema_name) - self.assert_(len(cols) > 0, len(cols)) + is_true(len(cols) > 0, len(cols)) # should be in order @@ -721,7 +1459,7 @@ class ComponentReflectionTest(fixtures.TablesTest): # assert that the desired type and return type share # a base within one of the generic types. - self.assert_( + is_true( len( set(ctype.__mro__) .intersection(ctype_def.__mro__) @@ -745,15 +1483,29 @@ class ComponentReflectionTest(fixtures.TablesTest): if not col.primary_key: assert cols[i]["default"] is None + # The case of a table with no column + # is tested below in TableNoColumnsTest + @testing.requires.temp_table_reflection - def test_get_temp_table_columns(self): - table_name = get_temp_table_name( - config, self.bind, "user_tmp_%s" % config.ident + def test_reflect_table_temp_table(self, connection): + + table_name = self.temp_table_name() + user_tmp = self.tables[table_name] + + reflected_user_tmp = Table( + table_name, MetaData(), autoload_with=connection ) + self.assert_tables_equal( + user_tmp, reflected_user_tmp, strict_constraints=False + ) + + @testing.requires.temp_table_reflection + def test_get_temp_table_columns(self, connection): + table_name = self.temp_table_name() user_tmp = self.tables[table_name] - insp = inspect(self.bind) + insp = inspect(connection) cols = insp.get_columns(table_name) - self.assert_(len(cols) > 0, len(cols)) + is_true(len(cols) > 0, len(cols)) for i, col in enumerate(user_tmp.columns): eq_(col.name, cols[i]["name"]) @@ -761,8 +1513,8 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.temp_table_reflection @testing.requires.view_column_reflection @testing.requires.temporary_views - def test_get_temp_view_columns(self): - insp = inspect(self.bind) + def test_get_temp_view_columns(self, connection): + insp = inspect(connection) cols = insp.get_columns("user_tmp_v") eq_([col["name"] for col in cols], ["id", "name", "foo"]) @@ -778,18 +1530,27 @@ class ComponentReflectionTest(fixtures.TablesTest): users, addresses = self.tables.users, self.tables.email_addresses insp = inspect(connection) + exp = self.exp_pks(schema=schema) users_cons = insp.get_pk_constraint(users.name, schema=schema) - users_pkeys = users_cons["constrained_columns"] - eq_(users_pkeys, ["user_id"]) + 
self._check_list( + [users_cons], [exp[(schema, users.name)]], self._required_pk_keys + ) addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) - addr_pkeys = addr_cons["constrained_columns"] - eq_(addr_pkeys, ["address_id"]) + exp_cols = exp[(schema, addresses.name)]["constrained_columns"] + eq_(addr_cons["constrained_columns"], exp_cols) with testing.requires.reflects_pk_names.fail_if(): eq_(addr_cons["name"], "email_ad_pk") + no_cst = self.tables.no_constraints.name + self._check_list( + [insp.get_pk_constraint(no_cst, schema=schema)], + [exp[(schema, no_cst)]], + self._required_pk_keys, + ) + @testing.combinations( (False,), (True, testing.requires.schemas), argnames="use_schema" ) @@ -815,31 +1576,33 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(fkey1["referred_schema"], expected_schema) eq_(fkey1["referred_table"], users.name) eq_(fkey1["referred_columns"], ["user_id"]) - if testing.requires.self_referential_foreign_keys.enabled: - eq_(fkey1["constrained_columns"], ["parent_user_id"]) + eq_(fkey1["constrained_columns"], ["parent_user_id"]) # addresses addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) fkey1 = addr_fkeys[0] with testing.requires.implicitly_named_constraints.fail_if(): - self.assert_(fkey1["name"] is not None) + is_true(fkey1["name"] is not None) eq_(fkey1["referred_schema"], expected_schema) eq_(fkey1["referred_table"], users.name) eq_(fkey1["referred_columns"], ["user_id"]) eq_(fkey1["constrained_columns"], ["remote_user_id"]) + no_cst = self.tables.no_constraints.name + eq_(insp.get_foreign_keys(no_cst, schema=schema), []) + @testing.requires.cross_schema_fk_reflection @testing.requires.schemas - def test_get_inter_schema_foreign_keys(self): + def test_get_inter_schema_foreign_keys(self, connection): local_table, remote_table, remote_table_2 = self.tables( - "%s.local_table" % self.bind.dialect.default_schema_name, + "%s.local_table" % connection.dialect.default_schema_name, "%s.remote_table" % testing.config.test_schema, "%s.remote_table_2" % testing.config.test_schema, ) - insp = inspect(self.bind) + insp = inspect(connection) local_fkeys = insp.get_foreign_keys(local_table.name) eq_(len(local_fkeys), 1) @@ -857,25 +1620,21 @@ class ComponentReflectionTest(fixtures.TablesTest): fkey2 = remote_fkeys[0] - assert fkey2["referred_schema"] in ( - None, - self.bind.dialect.default_schema_name, + is_true( + fkey2["referred_schema"] + in ( + None, + connection.dialect.default_schema_name, + ) ) eq_(fkey2["referred_table"], local_table.name) eq_(fkey2["referred_columns"], ["id"]) eq_(fkey2["constrained_columns"], ["local_id"]) - def _assert_insp_indexes(self, indexes, expected_indexes): - index_names = [d["name"] for d in indexes] - for e_index in expected_indexes: - assert e_index["name"] in index_names - index = indexes[index_names.index(e_index["name"])] - for key in e_index: - eq_(e_index[key], index[key]) - @testing.combinations( (False,), (True, testing.requires.schemas), argnames="use_schema" ) + @testing.requires.index_reflection def test_get_indexes(self, connection, use_schema): if use_schema: @@ -885,21 +1644,19 @@ class ComponentReflectionTest(fixtures.TablesTest): # The database may decide to create indexes for foreign keys, etc. # so there may be more indexes than expected. 
- insp = inspect(self.bind) + insp = inspect(connection) indexes = insp.get_indexes("users", schema=schema) - expected_indexes = [ - { - "unique": False, - "column_names": ["test1", "test2"], - "name": "users_t_idx", - }, - { - "unique": False, - "column_names": ["user_id", "test2", "test1"], - "name": "users_all_idx", - }, - ] - self._assert_insp_indexes(indexes, expected_indexes) + exp = self.exp_indexes(schema=schema) + self._check_list( + indexes, exp[(schema, "users")], self._required_index_keys + ) + + no_cst = self.tables.no_constraints.name + self._check_list( + insp.get_indexes(no_cst, schema=schema), + exp[(schema, no_cst)], + self._required_index_keys, + ) @testing.combinations( ("noncol_idx_test_nopk", "noncol_idx_nopk"), @@ -908,15 +1665,15 @@ class ComponentReflectionTest(fixtures.TablesTest): ) @testing.requires.index_reflection @testing.requires.indexes_with_ascdesc + @testing.requires.reflect_indexes_with_ascdesc def test_get_noncol_index(self, connection, tname, ixname): insp = inspect(connection) indexes = insp.get_indexes(tname) - # reflecting an index that has "x DESC" in it as the column. # the DB may or may not give us "x", but make sure we get the index # back, it has a name, it's connected to the table. - expected_indexes = [{"unique": False, "name": ixname}] - self._assert_insp_indexes(indexes, expected_indexes) + expected_indexes = self.exp_indexes()[(None, tname)] + self._check_list(indexes, expected_indexes, self._required_index_keys) t = Table(tname, MetaData(), autoload_with=connection) eq_(len(t.indexes), 1) @@ -925,29 +1682,17 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.temp_table_reflection @testing.requires.unique_constraint_reflection - def test_get_temp_table_unique_constraints(self): - insp = inspect(self.bind) - reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident) - for refl in reflected: - # Different dialects handle duplicate index and constraints - # differently, so ignore this flag - refl.pop("duplicates_index", None) - eq_( - reflected, - [ - { - "column_names": ["name"], - "name": "user_tmp_uq_%s" % config.ident, - } - ], - ) + def test_get_temp_table_unique_constraints(self, connection): + insp = inspect(connection) + name = self.temp_table_name() + reflected = insp.get_unique_constraints(name) + exp = self.exp_ucs(all_=True)[(None, name)] + self._check_list(reflected, exp, self._required_index_keys) @testing.requires.temp_table_reflect_indexes - def test_get_temp_table_indexes(self): - insp = inspect(self.bind) - table_name = get_temp_table_name( - config, config.db, "user_tmp_%s" % config.ident - ) + def test_get_temp_table_indexes(self, connection): + insp = inspect(connection) + table_name = self.temp_table_name() indexes = insp.get_indexes(table_name) for ind in indexes: ind.pop("dialect_options", None) @@ -1005,9 +1750,9 @@ class ComponentReflectionTest(fixtures.TablesTest): ) table.create(connection) - inspector = inspect(connection) + insp = inspect(connection) reflected = sorted( - inspector.get_unique_constraints("testtbl", schema=schema), + insp.get_unique_constraints("testtbl", schema=schema), key=operator.itemgetter("name"), ) @@ -1047,6 +1792,9 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(names_that_duplicate_index, idx_names) eq_(uq_names, set()) + no_cst = self.tables.no_constraints.name + eq_(insp.get_unique_constraints(no_cst, schema=schema), []) + @testing.requires.view_reflection @testing.combinations( (False,), (True, testing.requires.schemas), 
argnames="use_schema" @@ -1056,32 +1804,21 @@ class ComponentReflectionTest(fixtures.TablesTest): schema = config.test_schema else: schema = None - view_name1 = "users_v" - view_name2 = "email_addresses_v" insp = inspect(connection) - v1 = insp.get_view_definition(view_name1, schema=schema) - self.assert_(v1) - v2 = insp.get_view_definition(view_name2, schema=schema) - self.assert_(v2) + for view in ["users_v", "email_addresses_v", "dingalings_v"]: + v = insp.get_view_definition(view, schema=schema) + is_true(bool(v)) - # why is this here if it's PG specific ? - @testing.combinations( - ("users", False), - ("users", True, testing.requires.schemas), - argnames="table_name,use_schema", - ) - @testing.only_on("postgresql", "PG specific feature") - def test_get_table_oid(self, connection, table_name, use_schema): - if use_schema: - schema = config.test_schema - else: - schema = None + @testing.requires.view_reflection + def test_get_view_definition_does_not_exist(self, connection): insp = inspect(connection) - oid = insp.get_table_oid(table_name, schema) - self.assert_(isinstance(oid, int)) + with expect_raises(NoSuchTableError): + insp.get_view_definition("view_does_not_exist") + with expect_raises(NoSuchTableError): + insp.get_view_definition("users") # a table @testing.requires.table_reflection - def test_autoincrement_col(self): + def test_autoincrement_col(self, connection): """test that 'autoincrement' is reflected according to sqla's policy. Don't mark this test as unsupported for any backend ! @@ -1094,7 +1831,7 @@ class ComponentReflectionTest(fixtures.TablesTest): """ - insp = inspect(self.bind) + insp = inspect(connection) for tname, cname in [ ("users", "user_id"), @@ -1105,6 +1842,330 @@ class ComponentReflectionTest(fixtures.TablesTest): id_ = {c["name"]: c for c in cols}[cname] assert id_.get("autoincrement", True) + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + def test_get_table_options(self, use_schema): + insp = inspect(config.db) + schema = config.test_schema if use_schema else None + + if testing.requires.reflect_table_options.enabled: + res = insp.get_table_options("users", schema=schema) + is_true(isinstance(res, dict)) + # NOTE: can't really create a table with no option + res = insp.get_table_options("no_constraints", schema=schema) + is_true(isinstance(res, dict)) + else: + with expect_raises(NotImplementedError): + res = insp.get_table_options("users", schema=schema) + + @testing.combinations((True, testing.requires.schemas), False) + def test_multi_get_table_options(self, use_schema): + insp = inspect(config.db) + if testing.requires.reflect_table_options.enabled: + schema = config.test_schema if use_schema else None + res = insp.get_multi_table_options(schema=schema) + + exp = { + (schema, table): insp.get_table_options(table, schema=schema) + for table in insp.get_table_names(schema=schema) + } + eq_(res, exp) + else: + with expect_raises(NotImplementedError): + res = insp.get_multi_table_options() + + @testing.fixture + def get_multi_exp(self, connection): + def provide_fixture( + schema, scope, kind, use_filter, single_reflect_fn, exp_method + ): + insp = inspect(connection) + # call the reflection function at least once to avoid + # "Unexpected success" errors if the result is actually empty + # and NotImplementedError is not raised + single_reflect_fn(insp, "email_addresses") + kw = {"scope": scope, "kind": kind} + if schema: + schema = schema() + + filter_names = [] + + if ObjectKind.TABLE in kind: + 
filter_names.extend( + ["comment_test", "users", "does-not-exist"] + ) + if ObjectKind.VIEW in kind: + filter_names.extend(["email_addresses_v", "does-not-exist"]) + if ObjectKind.MATERIALIZED_VIEW in kind: + filter_names.extend(["dingalings_v", "does-not-exist"]) + + if schema: + kw["schema"] = schema + if use_filter: + kw["filter_names"] = filter_names + + exp = exp_method( + schema=schema, + scope=scope, + kind=kind, + filter_names=kw.get("filter_names"), + ) + kws = [kw] + if scope == ObjectScope.DEFAULT: + nkw = kw.copy() + nkw.pop("scope") + kws.append(nkw) + if kind == ObjectKind.TABLE: + nkw = kw.copy() + nkw.pop("kind") + kws.append(nkw) + + return inspect(connection), kws, exp + + return provide_fixture + + @testing.requires.reflect_table_options + @_multi_combination + def test_multi_get_table_options_tables( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_table_options, + self.exp_options, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_table_options(**kw) + eq_(result, exp) + + @testing.requires.comment_reflection + @_multi_combination + def test_get_multi_table_comment( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_table_comment, + self.exp_comments, + ) + for kw in kws: + insp.clear_cache() + eq_(insp.get_multi_table_comment(**kw), exp) + + def _check_list(self, result, exp, req_keys=None, msg=None): + if req_keys is None: + eq_(result, exp, msg) + else: + eq_(len(result), len(exp), msg) + for r, e in zip(result, exp): + for k in set(r) | set(e): + if k in req_keys or (k in r and k in e): + eq_(r[k], e[k], f"{msg} - {k} - {r}") + + def _check_table_dict(self, result, exp, req_keys=None, make_lists=False): + eq_(set(result.keys()), set(exp.keys())) + for k in result: + r, e = result[k], exp[k] + if make_lists: + r, e = [r], [e] + self._check_list(r, e, req_keys, k) + + @_multi_combination + def test_get_multi_columns( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_columns, + self.exp_columns, + ) + + for kw in kws: + insp.clear_cache() + result = insp.get_multi_columns(**kw) + self._check_table_dict(result, exp, self._required_column_keys) + + @testing.requires.primary_key_constraint_reflection + @_multi_combination + def test_get_multi_pk_constraint( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_pk_constraint, + self.exp_pks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_pk_constraint(**kw) + self._check_table_dict( + result, exp, self._required_pk_keys, make_lists=True + ) + + def _adjust_sort(self, result, expected, key): + if not testing.requires.implicitly_named_constraints.enabled: + for obj in [result, expected]: + for val in obj.values(): + if len(val) > 1 and any( + v.get("name") in (None, mock.ANY) for v in val + ): + val.sort(key=key) + + @testing.requires.foreign_key_constraint_reflection + @_multi_combination + def test_get_multi_foreign_keys( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_foreign_keys, + self.exp_fks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_foreign_keys(**kw) + 
self._adjust_sort( + result, exp, lambda d: tuple(d["constrained_columns"]) + ) + self._check_table_dict(result, exp, self._required_fk_keys) + + @testing.requires.index_reflection + @_multi_combination + def test_get_multi_indexes( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_indexes, + self.exp_indexes, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_indexes(**kw) + self._check_table_dict(result, exp, self._required_index_keys) + + @testing.requires.unique_constraint_reflection + @_multi_combination + def test_get_multi_unique_constraints( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_unique_constraints, + self.exp_ucs, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_unique_constraints(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["column_names"])) + self._check_table_dict(result, exp, self._required_unique_cst_keys) + + @testing.requires.check_constraint_reflection + @_multi_combination + def test_get_multi_check_constraints( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_check_constraints, + self.exp_ccs, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_check_constraints(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["sqltext"])) + self._check_table_dict(result, exp, self._required_cc_keys) + + @testing.combinations( + ("get_table_options", testing.requires.reflect_table_options), + "get_columns", + ( + "get_pk_constraint", + testing.requires.primary_key_constraint_reflection, + ), + ( + "get_foreign_keys", + testing.requires.foreign_key_constraint_reflection, + ), + ("get_indexes", testing.requires.index_reflection), + ( + "get_unique_constraints", + testing.requires.unique_constraint_reflection, + ), + ( + "get_check_constraints", + testing.requires.check_constraint_reflection, + ), + ("get_table_comment", testing.requires.comment_reflection), + argnames="method", + ) + def test_not_existing_table(self, method, connection): + insp = inspect(connection) + meth = getattr(insp, method) + with expect_raises(NoSuchTableError): + meth("table_does_not_exists") + + def test_unreflectable(self, connection): + mc = Inspector.get_multi_columns + + def patched(*a, **k): + ur = k.setdefault("unreflectable", {}) + ur[(None, "some_table")] = UnreflectableTableError("err") + return mc(*a, **k) + + with mock.patch.object(Inspector, "get_multi_columns", patched): + with expect_raises_message(UnreflectableTableError, "err"): + inspect(connection).reflect_table( + Table("some_table", MetaData()), None + ) + + @testing.combinations(True, False, argnames="use_schema") + @testing.combinations( + (True, testing.requires.views), False, argnames="views" + ) + def test_metadata(self, connection, use_schema, views): + m = MetaData() + schema = config.test_schema if use_schema else None + m.reflect(connection, schema=schema, views=views, resolve_fks=False) + + insp = inspect(connection) + tables = insp.get_table_names(schema) + if views: + tables += insp.get_view_names(schema) + try: + tables += insp.get_materialized_view_names(schema) + except NotImplementedError: + pass + if schema: + tables = [f"{schema}.{t}" for t in tables] + eq_(sorted(m.tables), sorted(tables)) + class TableNoColumnsTest(fixtures.TestBase): 
__requires__ = ("reflect_tables_no_columns",) @@ -1117,9 +2178,6 @@ class TableNoColumnsTest(fixtures.TestBase): @testing.fixture def view_no_columns(self, connection, metadata): - Table("empty", metadata) - metadata.create_all(connection) - Table("empty", metadata) event.listen( metadata, @@ -1134,31 +2192,32 @@ class TableNoColumnsTest(fixtures.TestBase): ) metadata.create_all(connection) - @testing.requires.reflect_tables_no_columns def test_reflect_table_no_columns(self, connection, table_no_columns): t2 = Table("empty", MetaData(), autoload_with=connection) eq_(list(t2.c), []) - @testing.requires.reflect_tables_no_columns def test_get_columns_table_no_columns(self, connection, table_no_columns): - eq_(inspect(connection).get_columns("empty"), []) + insp = inspect(connection) + eq_(insp.get_columns("empty"), []) + multi = insp.get_multi_columns() + eq_(multi, {(None, "empty"): []}) - @testing.requires.reflect_tables_no_columns def test_reflect_incl_table_no_columns(self, connection, table_no_columns): m = MetaData() m.reflect(connection) assert set(m.tables).intersection(["empty"]) @testing.requires.views - @testing.requires.reflect_tables_no_columns def test_reflect_view_no_columns(self, connection, view_no_columns): t2 = Table("empty_v", MetaData(), autoload_with=connection) eq_(list(t2.c), []) @testing.requires.views - @testing.requires.reflect_tables_no_columns def test_get_columns_view_no_columns(self, connection, view_no_columns): - eq_(inspect(connection).get_columns("empty_v"), []) + insp = inspect(connection) + eq_(insp.get_columns("empty_v"), []) + multi = insp.get_multi_columns(kind=ObjectKind.VIEW) + eq_(multi, {(None, "empty_v"): []}) class ComponentReflectionTestExtra(fixtures.TestBase): @@ -1185,12 +2244,18 @@ class ComponentReflectionTestExtra(fixtures.TestBase): ), schema=schema, ) + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + ) metadata.create_all(connection) - inspector = inspect(connection) + insp = inspect(connection) reflected = sorted( - inspector.get_check_constraints("sa_cc", schema=schema), + insp.get_check_constraints("sa_cc", schema=schema), key=operator.itemgetter("name"), ) @@ -1213,6 +2278,8 @@ class ComponentReflectionTestExtra(fixtures.TestBase): {"name": "cc1", "sqltext": "a > 1 and a < 5"}, ], ) + no_cst = "no_constraints" + eq_(insp.get_check_constraints(no_cst, schema=schema), []) @testing.requires.indexes_with_expressions def test_reflect_expression_based_indexes(self, metadata, connection): @@ -1642,7 +2709,8 @@ class IdentityReflectionTest(fixtures.TablesTest): if col["name"] == "normal": is_false("identity" in col) elif col["name"] == "id1": - is_true(col["autoincrement"] in (True, "auto")) + if "autoincrement" in col: + is_true(col["autoincrement"]) eq_(col["default"], None) is_true("identity" in col) self.check( @@ -1659,7 +2727,8 @@ class IdentityReflectionTest(fixtures.TablesTest): approx=True, ) elif col["name"] == "id2": - is_true(col["autoincrement"] in (True, "auto")) + if "autoincrement" in col: + is_true(col["autoincrement"]) eq_(col["default"], None) is_true("identity" in col) self.check( @@ -1685,7 +2754,8 @@ class IdentityReflectionTest(fixtures.TablesTest): if col["name"] == "normal": is_false("identity" in col) elif col["name"] == "id1": - is_true(col["autoincrement"] in (True, "auto")) + if "autoincrement" in col: + is_true(col["autoincrement"]) eq_(col["default"], None) is_true("identity" in col) self.check( @@ -1735,16 +2805,16 @@ class 
CompositeKeyReflectionTest(fixtures.TablesTest): ) @testing.requires.primary_key_constraint_reflection - def test_pk_column_order(self): + def test_pk_column_order(self, connection): # test for issue #5661 - insp = inspect(self.bind) + insp = inspect(connection) primary_key = insp.get_pk_constraint(self.tables.tb1.name) eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"]) @testing.requires.foreign_key_constraint_reflection - def test_fk_column_order(self): + def test_fk_column_order(self, connection): # test for issue #5661 - insp = inspect(self.bind) + insp = inspect(connection) foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) eq_(len(foreign_keys), 1) fkey1 = foreign_keys[0] diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index eae0519921..e15fad6425 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -194,16 +194,23 @@ class HasSequenceTest(fixtures.TablesTest): ) def test_has_sequence(self, connection): - eq_( - inspect(connection).has_sequence("user_id_seq"), - True, - ) + eq_(inspect(connection).has_sequence("user_id_seq"), True) + + def test_has_sequence_cache(self, connection, metadata): + insp = inspect(connection) + eq_(insp.has_sequence("user_id_seq"), True) + ss = Sequence("new_seq", metadata=metadata) + eq_(insp.has_sequence("new_seq"), False) + ss.create(connection) + try: + eq_(insp.has_sequence("new_seq"), False) + insp.clear_cache() + eq_(insp.has_sequence("new_seq"), True) + finally: + ss.drop(connection) def test_has_sequence_other_object(self, connection): - eq_( - inspect(connection).has_sequence("user_id_table"), - False, - ) + eq_(inspect(connection).has_sequence("user_id_table"), False) @testing.requires.schemas def test_has_sequence_schema(self, connection): @@ -215,10 +222,7 @@ class HasSequenceTest(fixtures.TablesTest): ) def test_has_sequence_neg(self, connection): - eq_( - inspect(connection).has_sequence("some_sequence"), - False, - ) + eq_(inspect(connection).has_sequence("some_sequence"), False) @testing.requires.schemas def test_has_sequence_schemas_neg(self, connection): @@ -240,10 +244,7 @@ class HasSequenceTest(fixtures.TablesTest): @testing.requires.schemas def test_has_sequence_remote_not_in_default(self, connection): - eq_( - inspect(connection).has_sequence("schema_seq"), - False, - ) + eq_(inspect(connection).has_sequence("schema_seq"), False) def test_get_sequence_names(self, connection): exp = {"other_seq", "user_id_seq"} diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 0070b4d670..6fd42af702 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -393,36 +393,55 @@ def drop_all_tables_from_metadata(metadata, engine_or_connection): go(engine_or_connection) -def drop_all_tables(engine, inspector, schema=None, include_names=None): +def drop_all_tables( + engine, + inspector, + schema=None, + consider_schemas=(None,), + include_names=None, +): if include_names is not None: include_names = set(include_names) + if schema is not None: + assert consider_schemas == ( + None, + ), "consider_schemas and schema are mutually exclusive" + consider_schemas = (schema,) + with engine.begin() as conn: - for tname, fkcs in reversed( - inspector.get_sorted_table_and_fkc_names(schema=schema) + for table_key, fkcs in reversed( + inspector.sort_tables_on_foreign_key_dependency( + consider_schemas=consider_schemas + ) ): - if tname: - if include_names is not None 
and tname not in include_names: + if table_key: + if ( + include_names is not None + and table_key[1] not in include_names + ): continue conn.execute( - DropTable(Table(tname, MetaData(), schema=schema)) + DropTable( + Table(table_key[1], MetaData(), schema=table_key[0]) + ) ) elif fkcs: if not engine.dialect.supports_alter: continue - for tname, fkc in fkcs: + for t_key, fkc in fkcs: if ( include_names is not None - and tname not in include_names + and t_key[1] not in include_names ): continue tb = Table( - tname, + t_key[1], MetaData(), Column("x", Integer), Column("y", Integer), - schema=schema, + schema=t_key[0], ) conn.execute( DropConstraint( diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index 24e478b573..620e3bbb71 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -10,6 +10,7 @@ from __future__ import annotations from typing import Any +from typing import Collection from typing import DefaultDict from typing import Iterable from typing import Iterator @@ -27,7 +28,7 @@ __all__ = ["sort", "sort_as_subsets", "find_cycles"] def sort_as_subsets( - tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T] + tuples: Collection[Tuple[_T, _T]], allitems: Collection[_T] ) -> Iterator[Sequence[_T]]: edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set) @@ -56,8 +57,8 @@ def sort_as_subsets( def sort( - tuples: Iterable[Tuple[_T, _T]], - allitems: Iterable[_T], + tuples: Collection[Tuple[_T, _T]], + allitems: Collection[_T], deterministic_order: bool = True, ) -> Iterator[_T]: """sort the given list of items by dependency. @@ -76,8 +77,7 @@ def sort( def find_cycles( - tuples: Iterable[Tuple[_T, _T]], - allitems: Iterable[_T], + tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T] ) -> Set[_T]: # adapted from: # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index eb625e06e2..4e76554c73 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -78,11 +78,13 @@ if typing.TYPE_CHECKING or compat.py38: from typing import Protocol as Protocol from typing import TypedDict as TypedDict from typing import Final as Final + from typing import final as final else: from typing_extensions import Literal as Literal # noqa: F401 from typing_extensions import Protocol as Protocol # noqa: F401 from typing_extensions import TypedDict as TypedDict # noqa: F401 from typing_extensions import Final as Final # noqa: F401 + from typing_extensions import final as final # noqa: F401 typing_get_args = get_args typing_get_origin = get_origin diff --git a/pyproject.toml b/pyproject.toml index 812d60e915..3807b9c103 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,17 +60,4 @@ module = [ warn_unused_ignores = false strict = true -[[tool.mypy.overrides]] - -##################################################################### -# interim list of modules that need some level of type checking to -# pass -module = [ - - "sqlalchemy.engine.reflection", - -] - -ignore_errors = true -warn_unused_ignores = false diff --git a/setup.cfg b/setup.cfg index 1b3a6bcc68..f24a978280 100644 --- a/setup.cfg +++ b/setup.cfg @@ -128,6 +128,9 @@ profile_file = test/profiles.txt # create public database link test_link connect to scott identified by tiger # using 'xe'; oracle_db_link = test_link +# create public database link test_link2 connect to test_schema identified by tiger +# using 'xe'; +oracle_db_link2 = test_link2 # 
host name of a postgres database that has the postgres_fdw extension. # to create this run: @@ -162,8 +165,8 @@ mariadb = mariadb+mysqldb://scott:tiger@127.0.0.1:3306/test mariadb_connector = mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008 -docker_mssql = mssql+pymssql://scott:tiger^5HHH@127.0.0.1:1433/test +docker_mssql = mssql+pyodbc://scott:tiger^5HHH@127.0.0.1:1433/test?driver=ODBC+Driver+17+for+SQL+Server oracle = oracle+cx_oracle://scott:tiger@oracle18c/xe cxoracle = oracle+cx_oracle://scott:tiger@oracle18c/xe -oracle_oracledb = oracle+oracledb://scott:tiger@oracle18c/xe oracledb = oracle+oracledb://scott:tiger@oracle18c/xe +docker_oracle = oracle+cx_oracle://scott:tiger@127.0.0.1:1521/?service_name=XEPDB1 \ No newline at end of file diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py index f414c9c376..846001347c 100644 --- a/test/dialect/mysql/test_reflection.py +++ b/test/dialect/mysql/test_reflection.py @@ -33,9 +33,9 @@ from sqlalchemy import UniqueConstraint from sqlalchemy.dialects.mysql import base as mysql from sqlalchemy.dialects.mysql import reflection as _reflection from sqlalchemy.schema import CreateIndex -from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -558,6 +558,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_skip_not_describable(self, metadata, connection): + """This test is the only one that test the _default_multi_reflect + behaviour with UnreflectableTableError + """ + @event.listens_for(metadata, "before_drop") def cleanup(*arg, **kw): with testing.db.begin() as conn: @@ -579,14 +583,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): m.reflect(views=True, bind=conn) eq_(m.tables["test_t2"].name, "test_t2") - assert_raises_message( - exc.UnreflectableTableError, - "references invalid table", - Table, - "test_v", - MetaData(), - autoload_with=conn, - ) + with expect_raises_message( + exc.UnreflectableTableError, "references invalid table" + ): + Table("test_v", MetaData(), autoload_with=conn) @testing.exclude("mysql", "<", (5, 0, 0), "no information_schema support") def test_system_views(self): diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index bf76dca43d..53eb94df30 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -27,14 +27,18 @@ from sqlalchemy.dialects.oracle.base import BINARY_FLOAT from sqlalchemy.dialects.oracle.base import DOUBLE_PRECISION from sqlalchemy.dialects.oracle.base import NUMBER from sqlalchemy.dialects.oracle.base import REAL +from sqlalchemy.engine import ObjectKind from sqlalchemy.testing import assert_warns from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import config from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import eq_compile_type from 
sqlalchemy.testing.schema import Table @@ -384,6 +388,7 @@ class SystemTableTablenamesTest(fixtures.TestBase): __backend__ = True def setup_test(self): + with testing.db.begin() as conn: conn.exec_driver_sql("create table my_table (id integer)") conn.exec_driver_sql( @@ -417,6 +422,14 @@ class SystemTableTablenamesTest(fixtures.TestBase): set(["my_table", "foo_table"]), ) + def test_reflect_system_table(self): + meta = MetaData() + t = Table("foo_table", meta, autoload_with=testing.db) + assert t.columns.keys() == ["id"] + + t = Table("my_temp_table", meta, autoload_with=testing.db) + assert t.columns.keys() == ["id"] + class DontReflectIOTTest(fixtures.TestBase): """test that index overflow tables aren't included in @@ -509,6 +522,228 @@ class TableReflectionTest(fixtures.TestBase): tbl = Table("test_compress", m2, autoload_with=connection) assert tbl.dialect_options["oracle"]["compress"] == "OLTP" + def test_reflect_hidden_column(self): + with testing.db.begin() as conn: + conn.exec_driver_sql( + "CREATE TABLE my_table(id integer, hide integer INVISIBLE)" + ) + + try: + insp = inspect(conn) + cols = insp.get_columns("my_table") + assert len(cols) == 1 + assert cols[0]["name"] == "id" + finally: + conn.exec_driver_sql("DROP TABLE my_table") + + +class ViewReflectionTest(fixtures.TestBase): + __only_on__ = "oracle" + __backend__ = True + + @classmethod + def setup_test_class(cls): + sql = """ + CREATE TABLE tbl ( + id INTEGER PRIMARY KEY, + data INTEGER + ); + + CREATE VIEW tbl_plain_v AS + SELECT id, data FROM tbl WHERE id > 100; + + -- comments on plain views are created with "comment on table" + -- because why not.. + COMMENT ON TABLE tbl_plain_v IS 'view comment'; + + CREATE MATERIALIZED VIEW tbl_v AS + SELECT id, data FROM tbl WHERE id > 42; + + COMMENT ON MATERIALIZED VIEW tbl_v IS 'my mat view comment'; + + CREATE MATERIALIZED VIEW tbl_v2 AS + SELECT id, data FROM tbl WHERE id < 42; + + COMMENT ON MATERIALIZED VIEW tbl_v2 IS 'my other mat view comment'; + + CREATE SYNONYM view_syn FOR tbl_plain_v; + CREATE SYNONYM %(test_schema)s.ts_v_s FOR tbl_plain_v; + + CREATE VIEW %(test_schema)s.schema_view AS + SELECT 1 AS value FROM dual; + + COMMENT ON TABLE %(test_schema)s.schema_view IS 'schema view comment'; + CREATE SYNONYM syn_schema_view FOR %(test_schema)s.schema_view; + """ + if testing.requires.oracle_test_dblink.enabled: + cls.dblink = config.file_config.get( + "sqla_testing", "oracle_db_link" + ) + sql += """ + CREATE SYNONYM syn_link FOR tbl_plain_v@%(link)s; + """ % { + "link": cls.dblink + } + with testing.db.begin() as conn: + for stmt in ( + sql % {"test_schema": testing.config.test_schema} + ).split(";"): + if stmt.strip(): + conn.exec_driver_sql(stmt) + + @classmethod + def teardown_test_class(cls): + sql = """ + DROP MATERIALIZED VIEW tbl_v; + DROP MATERIALIZED VIEW tbl_v2; + DROP VIEW tbl_plain_v; + DROP TABLE tbl; + DROP VIEW %(test_schema)s.schema_view; + DROP SYNONYM view_syn; + DROP SYNONYM %(test_schema)s.ts_v_s; + DROP SYNONYM syn_schema_view; + """ + if testing.requires.oracle_test_dblink.enabled: + sql += """ + DROP SYNONYM syn_link; + """ + with testing.db.begin() as conn: + for stmt in ( + sql % {"test_schema": testing.config.test_schema} + ).split(";"): + if stmt.strip(): + conn.exec_driver_sql(stmt) + + def test_get_names(self, connection): + insp = inspect(connection) + eq_(insp.get_table_names(), ["tbl"]) + eq_(insp.get_view_names(), ["tbl_plain_v"]) + eq_(insp.get_materialized_view_names(), ["tbl_v", "tbl_v2"]) + eq_( + 
insp.get_view_names(schema=testing.config.test_schema), + ["schema_view"], + ) + + def test_get_table_comment_on_view(self, connection): + insp = inspect(connection) + eq_(insp.get_table_comment("tbl_v"), {"text": "my mat view comment"}) + eq_(insp.get_table_comment("tbl_plain_v"), {"text": "view comment"}) + + def test_get_multi_view_comment(self, connection): + insp = inspect(connection) + plain = {(None, "tbl_plain_v"): {"text": "view comment"}} + mat = { + (None, "tbl_v"): {"text": "my mat view comment"}, + (None, "tbl_v2"): {"text": "my other mat view comment"}, + } + eq_(insp.get_multi_table_comment(kind=ObjectKind.VIEW), plain) + eq_( + insp.get_multi_table_comment(kind=ObjectKind.MATERIALIZED_VIEW), + mat, + ) + eq_( + insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW), + {**plain, **mat}, + ) + ts = testing.config.test_schema + eq_( + insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW, schema=ts), + {(ts, "schema_view"): {"text": "schema view comment"}}, + ) + eq_(insp.get_multi_table_comment(), {(None, "tbl"): {"text": None}}) + + def test_get_table_comment_synonym(self, connection): + insp = inspect(connection) + eq_( + insp.get_table_comment("view_syn", oracle_resolve_synonyms=True), + {"text": "view comment"}, + ) + eq_( + insp.get_table_comment( + "syn_schema_view", oracle_resolve_synonyms=True + ), + {"text": "schema view comment"}, + ) + eq_( + insp.get_table_comment( + "ts_v_s", + oracle_resolve_synonyms=True, + schema=testing.config.test_schema, + ), + {"text": "view comment"}, + ) + + def test_get_multi_view_comment_synonym(self, connection): + insp = inspect(connection) + exp = { + (None, "view_syn"): {"text": "view comment"}, + (None, "syn_schema_view"): {"text": "schema view comment"}, + } + if testing.requires.oracle_test_dblink.enabled: + exp[(None, "syn_link")] = {"text": "view comment"} + eq_( + insp.get_multi_table_comment( + oracle_resolve_synonyms=True, kind=ObjectKind.ANY_VIEW + ), + exp, + ) + ts = testing.config.test_schema + eq_( + insp.get_multi_table_comment( + oracle_resolve_synonyms=True, + schema=ts, + kind=ObjectKind.ANY_VIEW, + ), + {(ts, "ts_v_s"): {"text": "view comment"}}, + ) + + def test_get_view_definition(self, connection): + insp = inspect(connection) + eq_( + insp.get_view_definition("tbl_plain_v"), + "SELECT id, data FROM tbl WHERE id > 100", + ) + eq_( + insp.get_view_definition("tbl_v"), + "SELECT id, data FROM tbl WHERE id > 42", + ) + with expect_raises(exc.NoSuchTableError): + eq_(insp.get_view_definition("view_syn"), None) + eq_( + insp.get_view_definition("view_syn", oracle_resolve_synonyms=True), + "SELECT id, data FROM tbl WHERE id > 100", + ) + eq_( + insp.get_view_definition( + "syn_schema_view", oracle_resolve_synonyms=True + ), + "SELECT 1 AS value FROM dual", + ) + eq_( + insp.get_view_definition( + "ts_v_s", + oracle_resolve_synonyms=True, + schema=testing.config.test_schema, + ), + "SELECT id, data FROM tbl WHERE id > 100", + ) + + @testing.requires.oracle_test_dblink + def test_get_view_definition_dblink(self, connection): + insp = inspect(connection) + eq_( + insp.get_view_definition("syn_link", oracle_resolve_synonyms=True), + "SELECT id, data FROM tbl WHERE id > 100", + ) + eq_( + insp.get_view_definition("tbl_plain_v", dblink=self.dblink), + "SELECT id, data FROM tbl WHERE id > 100", + ) + eq_( + insp.get_view_definition("tbl_v", dblink=self.dblink), + "SELECT id, data FROM tbl WHERE id > 42", + ) + class RoundTripIndexTest(fixtures.TestBase): __only_on__ = "oracle" @@ -722,8 +957,6 @@ class 
DBLinkReflectionTest(fixtures.TestBase): @classmethod def setup_test_class(cls): - from sqlalchemy.testing import config - cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link") # note that the synonym here is still not totally functional @@ -863,3 +1096,320 @@ class IdentityReflectionTest(fixtures.TablesTest): exp = common.copy() exp["order"] = True eq_(col["identity"], exp) + + +class AdditionalReflectionTests(fixtures.TestBase): + __only_on__ = "oracle" + __backend__ = True + + @classmethod + def setup_test_class(cls): + # currently assuming full DBA privs for the user. + # don't really know how else to go here unless + # we connect as the other user. + + sql = """ +CREATE TABLE %(schema)sparent( + id INTEGER, + data VARCHAR2(50), + CONSTRAINT parent_pk_%(schema_id)s PRIMARY KEY (id) +); +CREATE TABLE %(schema)smy_table( + id INTEGER, + name VARCHAR2(125), + related INTEGER, + data%(schema_id)s NUMBER NOT NULL, + CONSTRAINT my_table_pk_%(schema_id)s PRIMARY KEY (id), + CONSTRAINT my_table_fk_%(schema_id)s FOREIGN KEY(related) + REFERENCES %(schema)sparent(id), + CONSTRAINT my_table_check_%(schema_id)s CHECK (data%(schema_id)s > 42), + CONSTRAINT data_unique%(schema_id)s UNIQUE (data%(schema_id)s) +); +CREATE INDEX my_table_index_%(schema_id)s on %(schema)smy_table (id, name); +COMMENT ON TABLE %(schema)smy_table IS 'my table comment %(schema_id)s'; +COMMENT ON COLUMN %(schema)smy_table.name IS +'my table.name comment %(schema_id)s'; +""" + + with testing.db.begin() as conn: + for schema in ("", testing.config.test_schema): + dd = { + "schema": f"{schema}." if schema else "", + "schema_id": "sch" if schema else "", + } + for stmt in (sql % dd).split(";"): + if stmt.strip(): + conn.exec_driver_sql(stmt) + + @classmethod + def teardown_test_class(cls): + sql = """ +drop table %(schema)smy_table; +drop table %(schema)sparent; +""" + with testing.db.begin() as conn: + for schema in ("", testing.config.test_schema): + dd = {"schema": f"{schema}." 
if schema else ""} + for stmt in (sql % dd).split(";"): + if stmt.strip(): + try: + conn.exec_driver_sql(stmt) + except: + pass + + def setup_test(self): + self.dblink = config.file_config.get("sqla_testing", "oracle_db_link") + self.dblink2 = config.file_config.get( + "sqla_testing", "oracle_db_link2" + ) + self.columns = {} + self.indexes = {} + self.primary_keys = {} + self.comments = {} + self.uniques = {} + self.checks = {} + self.foreign_keys = {} + self.options = {} + self.allDicts = [ + self.columns, + self.indexes, + self.primary_keys, + self.comments, + self.uniques, + self.checks, + self.foreign_keys, + self.options, + ] + for schema in (None, testing.config.test_schema): + suffix = "sch" if schema else "" + + self.columns[schema] = { + (schema, "my_table"): [ + { + "name": "id", + "nullable": False, + "type": eq_compile_type("INTEGER"), + "default": None, + "comment": None, + }, + { + "name": "name", + "nullable": True, + "type": eq_compile_type("VARCHAR(125)"), + "default": None, + "comment": f"my table.name comment {suffix}", + }, + { + "name": "related", + "nullable": True, + "type": eq_compile_type("INTEGER"), + "default": None, + "comment": None, + }, + { + "name": f"data{suffix}", + "nullable": False, + "type": eq_compile_type("NUMBER"), + "default": None, + "comment": None, + }, + ], + (schema, "parent"): [ + { + "name": "id", + "nullable": False, + "type": eq_compile_type("INTEGER"), + "default": None, + "comment": None, + }, + { + "name": "data", + "nullable": True, + "type": eq_compile_type("VARCHAR(50)"), + "default": None, + "comment": None, + }, + ], + } + self.indexes[schema] = { + (schema, "my_table"): [ + { + "name": f"data_unique{suffix}", + "column_names": [f"data{suffix}"], + "dialect_options": {}, + "unique": True, + }, + { + "name": f"my_table_index_{suffix}", + "column_names": ["id", "name"], + "dialect_options": {}, + "unique": False, + }, + ], + (schema, "parent"): [], + } + self.primary_keys[schema] = { + (schema, "my_table"): { + "name": f"my_table_pk_{suffix}", + "constrained_columns": ["id"], + }, + (schema, "parent"): { + "name": f"parent_pk_{suffix}", + "constrained_columns": ["id"], + }, + } + self.comments[schema] = { + (schema, "my_table"): {"text": f"my table comment {suffix}"}, + (schema, "parent"): {"text": None}, + } + self.foreign_keys[schema] = { + (schema, "my_table"): [ + { + "name": f"my_table_fk_{suffix}", + "constrained_columns": ["related"], + "referred_schema": schema, + "referred_table": "parent", + "referred_columns": ["id"], + "options": {}, + } + ], + (schema, "parent"): [], + } + self.checks[schema] = { + (schema, "my_table"): [ + { + "name": f"my_table_check_{suffix}", + "sqltext": f"data{suffix} > 42", + } + ], + (schema, "parent"): [], + } + self.uniques[schema] = { + (schema, "my_table"): [ + { + "name": f"data_unique{suffix}", + "column_names": [f"data{suffix}"], + "duplicates_index": f"data_unique{suffix}", + } + ], + (schema, "parent"): [], + } + self.options[schema] = { + (schema, "my_table"): {}, + (schema, "parent"): {}, + } + + def test_tables(self, connection): + insp = inspect(connection) + + eq_(sorted(insp.get_table_names()), ["my_table", "parent"]) + + def _check_reflection(self, conn, schema, res_schema=False, **kw): + if res_schema is False: + res_schema = schema + insp = inspect(conn) + eq_( + insp.get_multi_columns(schema=schema, **kw), + self.columns[res_schema], + ) + eq_( + insp.get_multi_indexes(schema=schema, **kw), + self.indexes[res_schema], + ) + eq_( + insp.get_multi_pk_constraint(schema=schema, 
**kw), + self.primary_keys[res_schema], + ) + eq_( + insp.get_multi_table_comment(schema=schema, **kw), + self.comments[res_schema], + ) + eq_( + insp.get_multi_foreign_keys(schema=schema, **kw), + self.foreign_keys[res_schema], + ) + eq_( + insp.get_multi_check_constraints(schema=schema, **kw), + self.checks[res_schema], + ) + eq_( + insp.get_multi_unique_constraints(schema=schema, **kw), + self.uniques[res_schema], + ) + eq_( + insp.get_multi_table_options(schema=schema, **kw), + self.options[res_schema], + ) + + @testing.combinations(True, False, argnames="schema") + def test_schema_translate_map(self, connection, schema): + schema = testing.config.test_schema if schema else None + c = connection.execution_options( + schema_translate_map={ + None: "foo", + testing.config.test_schema: "bar", + } + ) + self._check_reflection(c, schema) + + @testing.requires.oracle_test_dblink + def test_db_link(self, connection): + self._check_reflection(connection, schema=None, dblink=self.dblink) + self._check_reflection( + connection, + schema=testing.config.test_schema, + dblink=self.dblink, + ) + + def test_no_synonyms(self, connection): + # oracle_resolve_synonyms is ignored if there are no matching synonym + self._check_reflection( + connection, schema=None, oracle_resolve_synonyms=True + ) + connection.exec_driver_sql("CREATE SYNONYM tmp FOR parent") + for dict_ in self.allDicts: + dict_["tmp"] = {(None, "parent"): dict_[None][(None, "parent")]} + try: + self._check_reflection( + connection, + schema=None, + res_schema="tmp", + oracle_resolve_synonyms=True, + filter_names=["parent"], + ) + finally: + connection.exec_driver_sql("DROP SYNONYM tmp") + + @testing.requires.oracle_test_dblink + @testing.requires.oracle_test_dblink2 + def test_multi_dblink_synonyms(self, connection): + # oracle_resolve_synonyms handles multiple dblink at once + connection.exec_driver_sql( + f"CREATE SYNONYM s1 FOR my_table@{self.dblink}" + ) + connection.exec_driver_sql( + f"CREATE SYNONYM s2 FOR {testing.config.test_schema}." 
+ f"my_table@{self.dblink2}" + ) + connection.exec_driver_sql("CREATE SYNONYM s3 FOR parent") + for dict_ in self.allDicts: + dict_["tmp"] = { + (None, "s1"): dict_[None][(None, "my_table")], + (None, "s2"): dict_[testing.config.test_schema][ + (testing.config.test_schema, "my_table") + ], + (None, "s3"): dict_[None][(None, "parent")], + } + fk = self.foreign_keys["tmp"][(None, "s1")][0] + fk["referred_table"] = "s3" + try: + self._check_reflection( + connection, + schema=None, + res_schema="tmp", + oracle_resolve_synonyms=True, + ) + finally: + connection.exec_driver_sql("DROP SYNONYM s1") + connection.exec_driver_sql("DROP SYNONYM s2") + connection.exec_driver_sql("DROP SYNONYM s3") diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 0093eb5ba7..d55aa82039 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -887,19 +887,12 @@ class MiscBackendTest( ) @testing.combinations( - ((8, 1), False, False), - ((8, 1), None, False), - ((11, 5), True, False), - ((11, 5), False, True), + (True, False), + (False, True), ) - def test_backslash_escapes_detection( - self, version, explicit_setting, expected - ): + def test_backslash_escapes_detection(self, explicit_setting, expected): engine = engines.testing_engine() - def _server_version(conn): - return version - if explicit_setting is not None: @event.listens_for(engine, "connect", insert=True) @@ -912,11 +905,8 @@ class MiscBackendTest( ) dbapi_connection.commit() - with mock.patch.object( - engine.dialect, "_get_server_version_info", _server_version - ): - with engine.connect(): - eq_(engine.dialect._backslash_escapes, expected) + with engine.connect(): + eq_(engine.dialect._backslash_escapes, expected) def test_dbapi_autocommit_attribute(self): """all the supported DBAPIs have an .autocommit attribute. 
make diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index cbb1809e4d..00e5dc5b9c 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -27,17 +27,21 @@ from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import INTEGER from sqlalchemy.dialects.postgresql import INTERVAL +from sqlalchemy.dialects.postgresql import pg_catalog from sqlalchemy.dialects.postgresql import TSRANGE +from sqlalchemy.engine import ObjectKind +from sqlalchemy.engine import ObjectScope from sqlalchemy.schema import CreateIndex from sqlalchemy.sql.schema import CheckConstraint from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock -from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_warns from sqlalchemy.testing.assertions import AssertsExecutionResults from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.assertions import expect_raises from sqlalchemy.testing.assertions import is_ +from sqlalchemy.testing.assertions import is_false from sqlalchemy.testing.assertions import is_true @@ -231,15 +235,34 @@ class MaterializedViewReflectionTest( connection.execute(target.insert(), {"id": 89, "data": "d1"}) materialized_view = sa.DDL( - "CREATE MATERIALIZED VIEW test_mview AS " "SELECT * FROM testtable" + "CREATE MATERIALIZED VIEW test_mview AS SELECT * FROM testtable" ) plain_view = sa.DDL( - "CREATE VIEW test_regview AS " "SELECT * FROM testtable" + "CREATE VIEW test_regview AS SELECT data FROM testtable" ) sa.event.listen(testtable, "after_create", plain_view) sa.event.listen(testtable, "after_create", materialized_view) + sa.event.listen( + testtable, + "after_create", + sa.DDL("COMMENT ON VIEW test_regview IS 'regular view comment'"), + ) + sa.event.listen( + testtable, + "after_create", + sa.DDL( + "COMMENT ON MATERIALIZED VIEW test_mview " + "IS 'materialized view comment'" + ), + ) + sa.event.listen( + testtable, + "after_create", + sa.DDL("CREATE INDEX mat_index ON test_mview(data DESC)"), + ) + sa.event.listen( testtable, "before_drop", @@ -249,6 +272,12 @@ class MaterializedViewReflectionTest( testtable, "before_drop", sa.DDL("DROP VIEW test_regview") ) + def test_has_type(self, connection): + insp = inspect(connection) + is_true(insp.has_type("test_mview")) + is_true(insp.has_type("test_regview")) + is_true(insp.has_type("testtable")) + def test_mview_is_reflected(self, connection): metadata = MetaData() table = Table("test_mview", metadata, autoload_with=connection) @@ -265,49 +294,99 @@ class MaterializedViewReflectionTest( def test_get_view_names(self, inspect_fixture): insp, conn = inspect_fixture - eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) + eq_(set(insp.get_view_names()), set(["test_regview"])) - def test_get_view_names_plain(self, connection): + def test_get_materialized_view_names(self, inspect_fixture): + insp, conn = inspect_fixture + eq_(set(insp.get_materialized_view_names()), set(["test_mview"])) + + def test_get_view_names_reflection_cache_ok(self, connection): insp = inspect(connection) + eq_(set(insp.get_view_names()), set(["test_regview"])) eq_( - set(insp.get_view_names(include=("plain",))), set(["test_regview"]) + set(insp.get_materialized_view_names()), + set(["test_mview"]), + ) + eq_( + 
set(insp.get_view_names()).union( + insp.get_materialized_view_names() + ), + set(["test_regview", "test_mview"]), ) - def test_get_view_names_plain_string(self, connection): + def test_get_view_definition(self, connection): insp = inspect(connection) - eq_(set(insp.get_view_names(include="plain")), set(["test_regview"])) - def test_get_view_names_materialized(self, connection): - insp = inspect(connection) + def normalize(definition): + return re.sub(r"[\n\t ]+", " ", definition.strip()) + eq_( - set(insp.get_view_names(include=("materialized",))), - set(["test_mview"]), + normalize(insp.get_view_definition("test_mview")), + "SELECT testtable.id, testtable.data FROM testtable;", + ) + eq_( + normalize(insp.get_view_definition("test_regview")), + "SELECT testtable.data FROM testtable;", ) - def test_get_view_names_reflection_cache_ok(self, connection): + def test_get_view_comment(self, connection): insp = inspect(connection) eq_( - set(insp.get_view_names(include=("plain",))), set(["test_regview"]) + insp.get_table_comment("test_regview"), + {"text": "regular view comment"}, ) eq_( - set(insp.get_view_names(include=("materialized",))), - set(["test_mview"]), + insp.get_table_comment("test_mview"), + {"text": "materialized view comment"}, ) - eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) - def test_get_view_names_empty(self, connection): + def test_get_multi_view_comment(self, connection): insp = inspect(connection) - assert_raises(ValueError, insp.get_view_names, include=()) + eq_( + insp.get_multi_table_comment(), + {(None, "testtable"): {"text": None}}, + ) + plain = {(None, "test_regview"): {"text": "regular view comment"}} + mat = {(None, "test_mview"): {"text": "materialized view comment"}} + eq_(insp.get_multi_table_comment(kind=ObjectKind.VIEW), plain) + eq_( + insp.get_multi_table_comment(kind=ObjectKind.MATERIALIZED_VIEW), + mat, + ) + eq_( + insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW), + {**plain, **mat}, + ) + eq_( + insp.get_multi_table_comment( + kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY + ), + {}, + ) - def test_get_view_definition(self, connection): + def test_get_multi_view_indexes(self, connection): insp = inspect(connection) + eq_(insp.get_multi_indexes(), {(None, "testtable"): []}) + + exp = { + "name": "mat_index", + "unique": False, + "column_names": ["data"], + "column_sorting": {"data": ("desc",)}, + } + if connection.dialect.server_version_info >= (11, 0): + exp["include_columns"] = [] + exp["dialect_options"] = {"postgresql_include": []} + plain = {(None, "test_regview"): []} + mat = {(None, "test_mview"): [exp]} + eq_(insp.get_multi_indexes(kind=ObjectKind.VIEW), plain) + eq_(insp.get_multi_indexes(kind=ObjectKind.MATERIALIZED_VIEW), mat) + eq_(insp.get_multi_indexes(kind=ObjectKind.ANY_VIEW), {**plain, **mat}) eq_( - re.sub( - r"[\n\t ]+", - " ", - insp.get_view_definition("test_mview").strip(), + insp.get_multi_indexes( + kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY ), - "SELECT testtable.id, testtable.data FROM testtable;", + {}, ) @@ -993,9 +1072,9 @@ class ReflectionTest( go, [ "Skipped unsupported reflection of " - "expression-based index idx1", + "expression-based index idx1 of table party", "Skipped unsupported reflection of " - "expression-based index idx3", + "expression-based index idx3 of table party", ], ) @@ -1016,7 +1095,7 @@ class ReflectionTest( metadata.create_all(connection) - ind = connection.dialect.get_indexes(connection, t1, None) + ind = connection.dialect.get_indexes(connection, 
t1.name, None) partial_definitions = [] for ix in ind: @@ -1337,6 +1416,9 @@ class ReflectionTest( } ], ) + is_true(inspector.has_type("mood", "test_schema")) + is_true(inspector.has_type("mood", "*")) + is_false(inspector.has_type("mood")) def test_inspect_enums(self, metadata, inspect_fixture): @@ -1345,30 +1427,49 @@ class ReflectionTest( enum_type = postgresql.ENUM( "cat", "dog", "rat", name="pet", metadata=metadata ) + enum_type.create(conn) + conn.commit() - with conn.begin(): - enum_type.create(conn) - - eq_( - inspector.get_enums(), - [ - { - "visible": True, - "labels": ["cat", "dog", "rat"], - "name": "pet", - "schema": "public", - } - ], - ) - - def test_get_table_oid(self, metadata, inspect_fixture): - - inspector, conn = inspect_fixture + res = [ + { + "visible": True, + "labels": ["cat", "dog", "rat"], + "name": "pet", + "schema": "public", + } + ] + eq_(inspector.get_enums(), res) + is_true(inspector.has_type("pet", "*")) + is_true(inspector.has_type("pet")) + is_false(inspector.has_type("pet", "test_schema")) + + enum_type.drop(conn) + conn.commit() + eq_(inspector.get_enums(), res) + is_true(inspector.has_type("pet")) + inspector.clear_cache() + eq_(inspector.get_enums(), []) + is_false(inspector.has_type("pet")) + + def test_get_table_oid(self, metadata, connection): + Table("t1", metadata, Column("col", Integer)) + Table("t1", metadata, Column("col", Integer), schema="test_schema") + metadata.create_all(connection) + insp = inspect(connection) + oid = insp.get_table_oid("t1") + oid_schema = insp.get_table_oid("t1", schema="test_schema") + is_true(isinstance(oid, int)) + is_true(isinstance(oid_schema, int)) + is_true(oid != oid_schema) - with conn.begin(): - Table("some_table", metadata, Column("q", Integer)).create(conn) + with expect_raises(exc.NoSuchTableError): + insp.get_table_oid("does_not_exist") - assert inspector.get_table_oid("some_table") is not None + metadata.tables["t1"].drop(connection) + eq_(insp.get_table_oid("t1"), oid) + insp.clear_cache() + with expect_raises(exc.NoSuchTableError): + insp.get_table_oid("t1") def test_inspect_enums_case_sensitive(self, metadata, connection): sa.event.listen( @@ -1707,77 +1808,146 @@ class ReflectionTest( ) def test_reflect_check_warning(self): - rows = [("some name", "NOTCHECK foobar")] + rows = [("foo", "some name", "NOTCHECK foobar")] conn = mock.Mock( execute=lambda *arg, **kw: mock.MagicMock( fetchall=lambda: rows, __iter__=lambda self: iter(rows) ) ) - with mock.patch.object( - testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1 + with testing.expect_warnings( + "Could not parse CHECK constraint text: 'NOTCHECK foobar'" ): - with testing.expect_warnings( - "Could not parse CHECK constraint text: 'NOTCHECK foobar'" - ): - testing.db.dialect.get_check_constraints(conn, "foo") + testing.db.dialect.get_check_constraints(conn, "foo") def test_reflect_extra_newlines(self): rows = [ - ("some name", "CHECK (\n(a \nIS\n NOT\n\n NULL\n)\n)"), - ("some other name", "CHECK ((b\nIS\nNOT\nNULL))"), - ("some CRLF name", "CHECK ((c\r\n\r\nIS\r\nNOT\r\nNULL))"), - ("some name", "CHECK (c != 'hi\nim a name\n')"), + ("foo", "some name", "CHECK (\n(a \nIS\n NOT\n\n NULL\n)\n)"), + ("foo", "some other name", "CHECK ((b\nIS\nNOT\nNULL))"), + ("foo", "some CRLF name", "CHECK ((c\r\n\r\nIS\r\nNOT\r\nNULL))"), + ("foo", "some name", "CHECK (c != 'hi\nim a name\n')"), ] conn = mock.Mock( execute=lambda *arg, **kw: mock.MagicMock( fetchall=lambda: rows, __iter__=lambda self: iter(rows) ) ) - with mock.patch.object( - 
testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1 - ): - check_constraints = testing.db.dialect.get_check_constraints( - conn, "foo" - ) - eq_( - check_constraints, - [ - { - "name": "some name", - "sqltext": "a \nIS\n NOT\n\n NULL\n", - }, - {"name": "some other name", "sqltext": "b\nIS\nNOT\nNULL"}, - { - "name": "some CRLF name", - "sqltext": "c\r\n\r\nIS\r\nNOT\r\nNULL", - }, - {"name": "some name", "sqltext": "c != 'hi\nim a name\n'"}, - ], - ) + check_constraints = testing.db.dialect.get_check_constraints( + conn, "foo" + ) + eq_( + check_constraints, + [ + { + "name": "some name", + "sqltext": "a \nIS\n NOT\n\n NULL\n", + }, + {"name": "some other name", "sqltext": "b\nIS\nNOT\nNULL"}, + { + "name": "some CRLF name", + "sqltext": "c\r\n\r\nIS\r\nNOT\r\nNULL", + }, + {"name": "some name", "sqltext": "c != 'hi\nim a name\n'"}, + ], + ) def test_reflect_with_not_valid_check_constraint(self): - rows = [("some name", "CHECK ((a IS NOT NULL)) NOT VALID")] + rows = [("foo", "some name", "CHECK ((a IS NOT NULL)) NOT VALID")] conn = mock.Mock( execute=lambda *arg, **kw: mock.MagicMock( fetchall=lambda: rows, __iter__=lambda self: iter(rows) ) ) - with mock.patch.object( - testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1 - ): - check_constraints = testing.db.dialect.get_check_constraints( - conn, "foo" + check_constraints = testing.db.dialect.get_check_constraints( + conn, "foo" + ) + eq_( + check_constraints, + [ + { + "name": "some name", + "sqltext": "a IS NOT NULL", + "dialect_options": {"not_valid": True}, + } + ], + ) + + def _apply_stm(self, connection, use_map): + if use_map: + return connection.execution_options( + schema_translate_map={ + None: "foo", + testing.config.test_schema: "bar", + } ) - eq_( - check_constraints, - [ - { - "name": "some name", - "sqltext": "a IS NOT NULL", - "dialect_options": {"not_valid": True}, - } - ], + else: + return connection + + @testing.combinations(True, False, argnames="use_map") + @testing.combinations(True, False, argnames="schema") + def test_schema_translate_map(self, metadata, connection, use_map, schema): + schema = testing.config.test_schema if schema else None + Table( + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("a", Integer, index=True), + Column( + "b", + ForeignKey(f"{schema}.foo.id" if schema else "foo.id"), + unique=True, + ), + CheckConstraint("a>10", name="foo_check"), + comment="comm", + schema=schema, + ) + metadata.create_all(connection) + if use_map: + connection = connection.execution_options( + schema_translate_map={ + None: "foo", + testing.config.test_schema: "bar", + } ) + insp = inspect(connection) + eq_( + [c["name"] for c in insp.get_columns("foo", schema=schema)], + ["id", "a", "b"], + ) + eq_( + [ + i["column_names"] + for i in insp.get_indexes("foo", schema=schema) + ], + [["b"], ["a"]], + ) + eq_( + insp.get_pk_constraint("foo", schema=schema)[ + "constrained_columns" + ], + ["id"], + ) + eq_(insp.get_table_comment("foo", schema=schema), {"text": "comm"}) + eq_( + [ + f["constrained_columns"] + for f in insp.get_foreign_keys("foo", schema=schema) + ], + [["b"]], + ) + eq_( + [ + c["name"] + for c in insp.get_check_constraints("foo", schema=schema) + ], + ["foo_check"], + ) + eq_( + [ + u["column_names"] + for u in insp.get_unique_constraints("foo", schema=schema) + ], + [["b"]], + ) class CustomTypeReflectionTest(fixtures.TestBase): @@ -1804,9 +1974,23 @@ class CustomTypeReflectionTest(fixtures.TestBase): ("my_custom_type(ARG1)", ("ARG1", None)), ("my_custom_type(ARG1, 
ARG2)", ("ARG1", "ARG2")), ]: - column_info = dialect._get_column_info( - "colname", sch, None, False, {}, {}, "public", None, "", None + row_dict = { + "name": "colname", + "table_name": "tblname", + "format_type": sch, + "default": None, + "not_null": False, + "comment": None, + "generated": "", + "identity_options": None, + } + column_info = dialect._get_columns_info( + [row_dict], {}, {}, "public" ) + assert ("public", "tblname") in column_info + column_info = column_info[("public", "tblname")] + assert len(column_info) == 1 + column_info = column_info[0] assert isinstance(column_info["type"], self.CustomType) eq_(column_info["type"].arg1, args[0]) eq_(column_info["type"].arg2, args[1]) @@ -1951,3 +2135,64 @@ class IdentityReflectionTest(fixtures.TablesTest): exp = default.copy() exp.update(maxvalue=2**15 - 1) eq_(col["identity"], exp) + + +class TestReflectDifficultColTypes(fixtures.TablesTest): + __only_on__ = "postgresql" + __backend__ = True + + def define_tables(metadata): + Table( + "sample_table", + metadata, + Column("c1", Integer, primary_key=True), + Column("c2", Integer, unique=True), + Column("c3", Integer), + Index("sample_table_index", "c2", "c3"), + ) + + def check_int_list(self, row, key): + value = row[key] + is_true(isinstance(value, list)) + is_true(len(value) > 0) + is_true(all(isinstance(v, int) for v in value)) + + def test_pg_index(self, connection): + insp = inspect(connection) + + pgc_oid = insp.get_table_oid("sample_table") + cols = [ + col + for col in pg_catalog.pg_index.c + if testing.db.dialect.server_version_info + >= col.info.get("server_version", (0,)) + ] + + stmt = sa.select(*cols).filter_by(indrelid=pgc_oid) + rows = connection.execute(stmt).mappings().all() + is_true(len(rows) > 0) + cols = [ + col + for col in ["indkey", "indoption", "indclass", "indcollation"] + if testing.db.dialect.server_version_info + >= pg_catalog.pg_index.c[col].info.get("server_version", (0,)) + ] + for row in rows: + for col in cols: + self.check_int_list(row, col) + + def test_pg_constraint(self, connection): + insp = inspect(connection) + + pgc_oid = insp.get_table_oid("sample_table") + cols = [ + col + for col in pg_catalog.pg_constraint.c + if testing.db.dialect.server_version_info + >= col.info.get("server_version", (0,)) + ] + stmt = sa.select(*cols).filter_by(conrelid=pgc_oid) + rows = connection.execute(stmt).mappings().all() + is_true(len(rows) > 0) + for row in rows: + self.check_int_list(row, "conkey") diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 266263d5fb..8b6532ce5d 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -451,11 +451,16 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): asserter.assert_( # check for table RegexSQL( - "select relname from pg_class c join pg_namespace.*", + "SELECT pg_catalog.pg_class.relname FROM pg_catalog." + "pg_class JOIN pg_catalog.pg_namespace.*", dialect="postgresql", ), # check for enum, just once - RegexSQL(r".*SELECT EXISTS ", dialect="postgresql"), + RegexSQL( + r"SELECT pg_catalog.pg_type.typname .* WHERE " + "pg_catalog.pg_type.typname = ", + dialect="postgresql", + ), RegexSQL("CREATE TYPE myenum AS ENUM .*", dialect="postgresql"), RegexSQL(r"CREATE TABLE t .*", dialect="postgresql"), ) @@ -465,11 +470,16 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): asserter.assert_( RegexSQL( - "select relname from pg_class c join pg_namespace.*", + "SELECT pg_catalog.pg_class.relname FROM pg_catalog." 
+ "pg_class JOIN pg_catalog.pg_namespace.*", dialect="postgresql", ), RegexSQL("DROP TABLE t", dialect="postgresql"), - RegexSQL(r".*SELECT EXISTS ", dialect="postgresql"), + RegexSQL( + r"SELECT pg_catalog.pg_type.typname .* WHERE " + "pg_catalog.pg_type.typname = ", + dialect="postgresql", + ), RegexSQL("DROP TYPE myenum", dialect="postgresql"), ) @@ -691,23 +701,6 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): connection, "fourfivesixtype" ) - def test_no_support(self, testing_engine): - def server_version_info(self): - return (8, 2) - - e = testing_engine() - dialect = e.dialect - dialect._get_server_version_info = server_version_info - - assert dialect.supports_native_enum - e.connect() - assert not dialect.supports_native_enum - - # initialize is called again on new pool - e.dispose() - e.connect() - assert not dialect.supports_native_enum - def test_reflection(self, metadata, connection): etype = Enum( "four", "five", "six", name="fourfivesixtype", metadata=metadata diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index fb43319982..ed9d67612d 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -50,6 +50,7 @@ from sqlalchemy.testing import combinations from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -856,27 +857,15 @@ class AttachedDBTest(fixtures.TestBase): ["foo", "bar"], ) - eq_( - [ - d["name"] - for d in insp.get_columns("nonexistent", schema="test_schema") - ], - [], - ) - eq_( - [ - d["name"] - for d in insp.get_columns("another_created", schema=None) - ], - [], - ) - eq_( - [ - d["name"] - for d in insp.get_columns("local_only", schema="test_schema") - ], - [], - ) + with expect_raises(exc.NoSuchTableError): + insp.get_columns("nonexistent", schema="test_schema") + + with expect_raises(exc.NoSuchTableError): + insp.get_columns("another_created", schema=None) + + with expect_raises(exc.NoSuchTableError): + insp.get_columns("local_only", schema="test_schema") + eq_([d["name"] for d in insp.get_columns("local_only")], ["q", "p"]) def test_table_names_present(self): diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 76099e8637..2f6c06aceb 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -2,6 +2,7 @@ import unicodedata import sqlalchemy as sa from sqlalchemy import Computed +from sqlalchemy import Connection from sqlalchemy import DefaultClause from sqlalchemy import event from sqlalchemy import FetchedValue @@ -17,6 +18,7 @@ from sqlalchemy import sql from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import UniqueConstraint +from sqlalchemy.engine import Inspector from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -1254,12 +1256,13 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): m2 = MetaData() t2 = Table("x", m2, autoload_with=connection) - ck = [ + cks = [ const for const in t2.constraints if isinstance(const, sa.CheckConstraint) - ][0] - + ] + eq_(len(cks), 1) + ck = cks[0] eq_regex(ck.sqltext.text, r"[\(`]*q[\)`]* > 10") eq_(ck.name, "ck1") @@ -1268,11 +1271,17 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): sa.Index("x_ix", t.c.a, t.c.b) 
metadata.create_all(connection) - def mock_get_columns(self, connection, table_name, **kw): - return [{"name": "b", "type": Integer, "primary_key": False}] + gri = Inspector._get_reflection_info + + def mock_gri(self, *a, **kw): + res = gri(self, *a, **kw) + res.columns[(None, "x")] = [ + col for col in res.columns[(None, "x")] if col["name"] == "b" + ] + return res with testing.mock.patch.object( - connection.dialect, "get_columns", mock_get_columns + Inspector, "_get_reflection_info", mock_gri ): m = MetaData() with testing.expect_warnings( @@ -1409,38 +1418,49 @@ class CreateDropTest(fixtures.TablesTest): eq_(ua, ["users", "email_addresses"]) eq_(oi, ["orders", "items"]) - def test_checkfirst(self, connection): + def test_checkfirst(self, connection: Connection) -> None: insp = inspect(connection) + users = self.tables.users is_false(insp.has_table("users")) users.create(connection) + insp.clear_cache() is_true(insp.has_table("users")) users.create(connection, checkfirst=True) users.drop(connection) users.drop(connection, checkfirst=True) + insp.clear_cache() is_false(insp.has_table("users")) users.create(connection, checkfirst=True) users.drop(connection) - def test_createdrop(self, connection): + def test_createdrop(self, connection: Connection) -> None: insp = inspect(connection) metadata = self.tables_test_metadata + assert metadata is not None metadata.create_all(connection) is_true(insp.has_table("items")) is_true(insp.has_table("email_addresses")) metadata.create_all(connection) + insp.clear_cache() is_true(insp.has_table("items")) metadata.drop_all(connection) + insp.clear_cache() is_false(insp.has_table("items")) is_false(insp.has_table("email_addresses")) metadata.drop_all(connection) + insp.clear_cache() is_false(insp.has_table("items")) - def test_tablenames(self, connection): + def test_has_table_and_table_names(self, connection): + """establish that has_table and get_table_names are consistent w/ + each other with regard to caching + + """ metadata = self.tables_test_metadata metadata.create_all(bind=connection) insp = inspect(connection) @@ -1448,6 +1468,19 @@ class CreateDropTest(fixtures.TablesTest): # ensure all tables we created are in the list. 
is_true(set(insp.get_table_names()).issuperset(metadata.tables)) + assert insp.has_table("items") + assert "items" in insp.get_table_names() + + self.tables.items.drop(connection) + + # cached + assert insp.has_table("items") + assert "items" in insp.get_table_names() + + insp = inspect(connection) + assert not insp.has_table("items") + assert "items" not in insp.get_table_names() + class SchemaManipulationTest(fixtures.TestBase): __backend__ = True @@ -1602,13 +1635,7 @@ class SchemaTest(fixtures.TestBase): __backend__ = True @testing.requires.schemas - @testing.requires.cross_schema_fk_reflection def test_has_schema(self): - if not hasattr(testing.db.dialect, "has_schema"): - testing.config.skip_test( - "dialect %s doesn't have a has_schema method" - % testing.db.dialect.name - ) with testing.db.connect() as conn: eq_( testing.db.dialect.has_schema( diff --git a/test/perf/many_table_reflection.py b/test/perf/many_table_reflection.py new file mode 100644 index 0000000000..8749df5c2c --- /dev/null +++ b/test/perf/many_table_reflection.py @@ -0,0 +1,617 @@ +from argparse import ArgumentDefaultsHelpFormatter +from argparse import ArgumentParser +from collections import defaultdict +from contextlib import contextmanager +from functools import wraps +from pprint import pprint +import random +import time + +import sqlalchemy as sa +from sqlalchemy.engine import Inspector + +types = (sa.Integer, sa.BigInteger, sa.String(200), sa.DateTime) +USE_CONNECTION = False + + +def generate_table(meta: sa.MetaData, min_cols, max_cols, dialect_name): + col_number = random.randint(min_cols, max_cols) + table_num = len(meta.tables) + add_identity = random.random() > 0.90 + identity = sa.Identity( + always=random.randint(0, 1), + start=random.randint(1, 100), + increment=random.randint(1, 7), + ) + is_mssql = dialect_name == "mssql" + cols = [] + for i in range(col_number - (0 if is_mssql else add_identity)): + args = [] + if random.random() < 0.95 or table_num == 0: + if is_mssql and add_identity and i == 0: + args.append(sa.Integer) + args.append(identity) + else: + args.append(random.choice(types)) + else: + args.append( + sa.ForeignKey(f"table_{table_num-1}.table_{table_num-1}_col_1") + ) + cols.append( + sa.Column( + f"table_{table_num}_col_{i+1}", + *args, + primary_key=i == 0, + comment=f"primary key of table_{table_num}" + if i == 0 + else None, + index=random.random() > 0.9 and i > 0, + unique=random.random() > 0.95 and i > 0, + ) + ) + if add_identity and not is_mssql: + cols.append( + sa.Column( + f"table_{table_num}_col_{col_number}", + sa.Integer, + identity, + ) + ) + args = () + if table_num % 3 == 0: + # mysql can't do check constraint on PK col + args = (sa.CheckConstraint(cols[1].is_not(None)),) + return sa.Table( + f"table_{table_num}", + meta, + *cols, + *args, + comment=f"comment for table_{table_num}" if table_num % 2 else None, + ) + + +def generate_meta(schema_name, table_number, min_cols, max_cols, dialect_name): + meta = sa.MetaData(schema=schema_name) + log = defaultdict(int) + for _ in range(table_number): + t = generate_table(meta, min_cols, max_cols, dialect_name) + log["tables"] += 1 + log["columns"] += len(t.columns) + log["index"] += len(t.indexes) + log["check_con"] += len( + [c for c in t.constraints if isinstance(c, sa.CheckConstraint)] + ) + log["foreign_keys_con"] += len( + [ + c + for c in t.constraints + if isinstance(c, sa.ForeignKeyConstraint) + ] + ) + log["unique_con"] += len( + [c for c in t.constraints if isinstance(c, sa.UniqueConstraint)] + ) + log["identity"] += 
len([c for c in t.columns if c.identity]) + + print("Meta info", dict(log)) + return meta + + +def log(fn): + @wraps(fn) + def wrap(*a, **kw): + print("Running ", fn.__name__, "...", flush=True, end="") + try: + r = fn(*a, **kw) + except NotImplementedError: + print(" [not implemented]", flush=True) + r = None + else: + print("... done", flush=True) + return r + + return wrap + + +tests = {} + + +def define_test(fn): + name: str = fn.__name__ + if name.startswith("reflect_"): + name = name[8:] + tests[name] = wfn = log(fn) + return wfn + + +@log +def create_tables(engine, meta): + tables = list(meta.tables.values()) + for i in range(0, len(tables), 500): + meta.create_all(engine, tables[i : i + 500]) + + +@log +def drop_tables(engine, meta, schema_name, table_names: list): + tables = list(meta.tables.values())[::-1] + for i in range(0, len(tables), 500): + meta.drop_all(engine, tables[i : i + 500]) + + remaining = sa.inspect(engine).get_table_names(schema=schema_name) + suffix = "" + if engine.dialect.name.startswith("postgres"): + suffix = "CASCADE" + + remaining = sorted( + remaining, key=lambda tn: int(tn.partition("_")[2]), reverse=True + ) + with engine.connect() as conn: + for i, tn in enumerate(remaining): + if engine.dialect.requires_name_normalize: + name = engine.dialect.denormalize_name(tn) + else: + name = tn + if schema_name: + conn.execute( + sa.schema.DDL( + f'DROP TABLE {schema_name}."{name}" {suffix}' + ) + ) + else: + conn.execute(sa.schema.DDL(f'DROP TABLE "{name}" {suffix}')) + if i % 500 == 0: + conn.commit() + conn.commit() + + +@log +def reflect_tables(engine, schema_name): + ref_meta = sa.MetaData(schema=schema_name) + ref_meta.reflect(engine) + + +def verify_dict(multi, single, str_compare=False): + if single is None or multi is None: + return + if single != multi: + keys = set(single) | set(multi) + diff = [] + for key in sorted(keys): + se, me = single.get(key), multi.get(key) + if str(se) != str(me) if str_compare else se != me: + diff.append((key, single.get(key), multi.get(key))) + if diff: + print("\nfound different result:") + pprint(diff) + + +def _single_test( + singe_fn_name, + multi_fn_name, + engine, + schema_name, + table_names, + timing, + mode, +): + single = None + if "single" in mode: + singe_fn = getattr(Inspector, singe_fn_name) + + def go(bind): + insp = sa.inspect(bind) + single = {} + with timing(singe_fn.__name__): + for t in table_names: + single[(schema_name, t)] = singe_fn( + insp, t, schema=schema_name + ) + return single + + if USE_CONNECTION: + with engine.connect() as c: + single = go(c) + else: + single = go(engine) + + multi = None + if "multi" in mode: + insp = sa.inspect(engine) + multi_fn = getattr(Inspector, multi_fn_name) + with timing(multi_fn.__name__): + multi = multi_fn(insp, schema=schema_name) + return (multi, single) + + +@define_test +def reflect_columns( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_columns", + "get_multi_columns", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single, str_compare=True) + + +@define_test +def reflect_table_options( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_table_options", + "get_multi_table_options", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_pk(engine, schema_name, table_names, timing, mode, ignore_diff): + 
multi, single = _single_test( + "get_pk_constraint", + "get_multi_pk_constraint", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_comment( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_table_comment", + "get_multi_table_comment", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_whole_tables( + engine, schema_name, table_names, timing, mode, ignore_diff +): + single = None + meta = sa.MetaData(schema=schema_name) + + if "single" in mode: + + def go(bind): + single = {} + with timing("Table_autoload_with"): + for name in table_names: + single[(None, name)] = sa.Table( + name, meta, autoload_with=bind + ) + return single + + if USE_CONNECTION: + with engine.connect() as c: + single = go(c) + else: + single = go(engine) + + multi_meta = sa.MetaData(schema=schema_name) + if "multi" in mode: + with timing("MetaData_reflect"): + multi_meta.reflect(engine, only=table_names) + return (multi_meta, single) + + +@define_test +def reflect_check_constraints( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_check_constraints", + "get_multi_check_constraints", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_indexes( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_indexes", + "get_multi_indexes", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_foreign_keys( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_foreign_keys", + "get_multi_foreign_keys", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_unique_constraints( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_unique_constraints", + "get_multi_unique_constraints", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +def _apply_events(engine): + queries = defaultdict(list) + + now = 0 + + @sa.event.listens_for(engine, "before_cursor_execute") + def before_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + + nonlocal now + now = time.time() + + @sa.event.listens_for(engine, "after_cursor_execute") + def after_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + total = time.time() - now + + if context and context.compiled: + statement_str = context.compiled.string + else: + statement_str = statement + queries[statement_str].append(total) + + return queries + + +def _print_query_stats(queries): + number_of_queries = sum( + len(query_times) for query_times in queries.values() + ) + print("-" * 50) + q_list = list(queries.items()) + q_list.sort(key=lambda rec: -sum(rec[1])) + total = sum([sum(t) for _, t in q_list]) + print(f"total number of queries: {number_of_queries}. 
Total time {total}") + print("-" * 50) + + for stmt, times in q_list: + total_t = sum(times) + max_t = max(times) + min_t = min(times) + avg_t = total_t / len(times) + times.sort() + median_t = times[len(times) // 2] + + print( + f"Query times: {total_t=}, {max_t=}, {min_t=}, {avg_t=}, " + f"{median_t=} Number of calls: {len(times)}" + ) + print(stmt.strip(), "\n") + + +def main(db, schema_name, table_number, min_cols, max_cols, args): + timing = timer() + if args.pool_class: + engine = sa.create_engine( + db, echo=args.echo, poolclass=getattr(sa.pool, args.pool_class) + ) + else: + engine = sa.create_engine(db, echo=args.echo) + + if engine.name == "oracle": + # clear out oracle caches so that we get the real-world time the + # queries would normally take for scripts that aren't run repeatedly + with engine.connect() as conn: + # https://stackoverflow.com/questions/2147456/how-to-clear-all-cached-items-in-oracle + conn.exec_driver_sql("alter system flush buffer_cache") + conn.exec_driver_sql("alter system flush shared_pool") + if not args.no_create: + print( + f"Generating {table_number} using engine {engine} in " + f"schema {schema_name or 'default'}", + ) + meta = sa.MetaData() + table_names = [] + stats = {} + try: + if not args.no_create: + with timing("populate-meta"): + meta = generate_meta( + schema_name, table_number, min_cols, max_cols, engine.name + ) + with timing("create-tables"): + create_tables(engine, meta) + + with timing("get_table_names"): + with engine.connect() as conn: + table_names = engine.dialect.get_table_names( + conn, schema=schema_name + ) + print( + f"Reflected table number {len(table_names)} in " + f"schema {schema_name or 'default'}" + ) + mode = {"single", "multi"} + if args.multi_only: + mode.discard("single") + if args.single_only: + mode.discard("multi") + + if args.sqlstats: + print("starting stats for subsequent tests") + stats = _apply_events(engine) + for test_name, test_fn in tests.items(): + if test_name in args.test or "all" in args.test: + test_fn( + engine, + schema_name, + table_names, + timing, + mode, + args.ignore_diff, + ) + + if args.reflect: + with timing("reflect-tables"): + reflect_tables(engine, schema_name) + finally: + # copy stats to new dict + if args.sqlstats: + stats = dict(stats) + try: + if not args.no_drop: + with timing("drop-tables"): + drop_tables(engine, meta, schema_name, table_names) + finally: + pprint(timing.timing, sort_dicts=False) + if args.sqlstats: + _print_query_stats(stats) + + +def timer(): + timing = {} + + @contextmanager + def track_time(name): + s = time.time() + yield + timing[name] = time.time() - s + + track_time.timing = timing + return track_time + + +if __name__ == "__main__": + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--db", help="Database url", default="sqlite:///many-table.db" + ) + parser.add_argument( + "--schema-name", + help="optional schema name", + type=str, + default=None, + ) + parser.add_argument( + "--table-number", + help="Number of table to generate.", + type=int, + default=250, + ) + parser.add_argument( + "--min-cols", + help="Min number of column per table.", + type=int, + default=15, + ) + parser.add_argument( + "--max-cols", + help="Max number of column per table.", + type=int, + default=250, + ) + parser.add_argument( + "--no-create", help="Do not run create tables", action="store_true" + ) + parser.add_argument( + "--no-drop", help="Do not run drop tables", action="store_true" + ) + parser.add_argument("--reflect", help="Run 
reflect", action="store_true") + parser.add_argument( + "--test", + help="Run these tests. 'all' runs all tests", + nargs="+", + choices=tuple(tests) + ("all", "none"), + default=["all"], + ) + parser.add_argument( + "--sqlstats", + help="count and time individual queries", + action="store_true", + ) + parser.add_argument( + "--multi-only", help="Only run multi table tests", action="store_true" + ) + parser.add_argument( + "--single-only", + help="Only run single table tests", + action="store_true", + ) + parser.add_argument( + "--echo", action="store_true", help="Enable echo on the engine" + ) + parser.add_argument( + "--ignore-diff", + action="store_true", + help="Ignores differences in the single/multi reflections", + ) + parser.add_argument( + "--single-inspect-conn", + action="store_true", + help="Uses inspect on a connection instead of on the engine when " + "using single reflections. Mainly for sqlite.", + ) + parser.add_argument("--pool-class", help="The pool class to use") + + args = parser.parse_args() + min_cols = args.min_cols + max_cols = args.max_cols + USE_CONNECTION = args.single_inspect_conn + assert min_cols <= max_cols and min_cols >= 1 + assert not (args.multi_only and args.single_only) + main( + args.db, args.schema_name, args.table_number, min_cols, max_cols, args + ) diff --git a/test/requirements.py b/test/requirements.py index 2d0876158d..6bea3ddc9e 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -75,6 +75,18 @@ class DefaultRequirements(SuiteRequirements): return skip_if(no_support("sqlite", "not supported by database")) + @property + def foreign_keys_reflect_as_index(self): + return only_on(["mysql", "mariadb"]) + + @property + def unique_index_reflect_as_unique_constraints(self): + return only_on(["mysql", "mariadb"]) + + @property + def unique_constraints_reflect_as_index(self): + return only_on(["mysql", "mariadb", "oracle", "postgresql", "mssql"]) + @property def foreign_key_constraint_name_reflection(self): return fails_if( @@ -83,6 +95,10 @@ class DefaultRequirements(SuiteRequirements): and not self._mariadb_105(config) ) + @property + def reflect_indexes_with_ascdesc(self): + return fails_if(["oracle"]) + @property def table_ddl_if_exists(self): """target platform supports IF NOT EXISTS / IF EXISTS for tables.""" @@ -507,6 +523,12 @@ class DefaultRequirements(SuiteRequirements): return exclusions.open() + @property + def schema_create_delete(self): + """target database supports schema create and dropped with + 'CREATE SCHEMA' and 'DROP SCHEMA'""" + return exclusions.skip_if(["sqlite", "oracle"]) + @property def cross_schema_fk_reflection(self): """target system must support reflection of inter-schema foreign @@ -547,11 +569,13 @@ class DefaultRequirements(SuiteRequirements): @property def check_constraint_reflection(self): - return fails_on_everything_except( - "postgresql", - "sqlite", - "oracle", - self._mysql_and_check_constraints_exist, + return only_on( + [ + "postgresql", + "sqlite", + "oracle", + self._mysql_and_check_constraints_exist, + ] ) @property @@ -562,7 +586,9 @@ class DefaultRequirements(SuiteRequirements): def temp_table_names(self): """target dialect supports listing of temporary table names""" - return only_on(["sqlite", "oracle"]) + skip_if(self._sqlite_file_db) + return only_on(["sqlite", "oracle", "postgresql"]) + skip_if( + self._sqlite_file_db + ) @property def temporary_views(self): @@ -792,8 +818,7 @@ class DefaultRequirements(SuiteRequirements): @property def views(self): """Target database must support VIEWs.""" 
- - return skip_if("drizzle", "no VIEW support") + return exclusions.open() @property def empty_strings_varchar(self): @@ -1336,14 +1361,28 @@ class DefaultRequirements(SuiteRequirements): ) ) + def _has_oracle_test_dblink(self, key): + def check(config): + assert config.db.dialect.name == "oracle" + name = config.file_config.get("sqla_testing", key) + if not name: + return False + with config.db.connect() as conn: + links = config.db.dialect._list_dblinks(conn) + return config.db.dialect.normalize_name(name) in links + + return only_on(["oracle"]) + only_if( + check, + f"{key} option not specified in config or dblink not found in db", + ) + @property def oracle_test_dblink(self): - return skip_if( - lambda config: not config.file_config.has_option( - "sqla_testing", "oracle_db_link" - ), - "oracle_db_link option not specified in config", - ) + return self._has_oracle_test_dblink("oracle_db_link") + + @property + def oracle_test_dblink2(self): + return self._has_oracle_test_dblink("oracle_db_link2") @property def postgresql_test_dblink(self): @@ -1781,6 +1820,19 @@ class DefaultRequirements(SuiteRequirements): return only_on(["mssql"]) + only_if(check) + @property + def reflect_table_options(self): + return only_on(["mysql", "mariadb", "oracle"]) + + @property + def materialized_views(self): + """Target database must support MATERIALIZED VIEWs.""" + return only_on(["postgresql", "oracle"]) + + @property + def materialized_views_reflect_pk(self): + return only_on(["oracle"]) + @property def uuid_data_type(self): """Return databases that support the UUID datatype.""" -- 2.47.2
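For readers new to the inspector API that the tests above exercise, the following is a minimal usage sketch of the batched reflection calls introduced by this patch (the get_multi_* methods, the ObjectKind / ObjectScope filters, and Inspector.clear_cache()). The engine URL and the "my_table" filter name are placeholders, not values taken from the patch:

    import sqlalchemy as sa
    from sqlalchemy.engine import ObjectKind
    from sqlalchemy.engine import ObjectScope

    # The URL is a placeholder for any configured database.
    engine = sa.create_engine("postgresql://scott:tiger@localhost/test")
    insp = sa.inspect(engine)

    # Each get_multi_* call reflects many tables at once and returns a dict
    # keyed by (schema, table_name); the schema element is None for the
    # default schema.
    columns = insp.get_multi_columns(schema=None)
    pks = insp.get_multi_pk_constraint(filter_names=["my_table"])

    # Views and materialized views are requested explicitly via ObjectKind;
    # ObjectScope.TEMPORARY restricts the result to temporary objects.
    view_comments = insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW)
    temp_indexes = insp.get_multi_indexes(
        kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY
    )

    # Results are cached on the Inspector; clear_cache() forces a re-query.
    insp.clear_cache()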