>> sqlplus system/tiger@//localhost/XEPDB1 <<EOF
CREATE USER test_schema IDENTIFIED BY tiger;
GRANT DBA TO SCOTT;
+ GRANT CREATE TABLE TO scott;
+ GRANT CREATE TABLE TO test_schema;
GRANT UNLIMITED TABLESPACE TO scott;
GRANT UNLIMITED TABLESPACE TO test_schema;
+ GRANT CREATE SESSION TO test_schema;
+ CREATE PUBLIC DATABASE LINK test_link CONNECT TO scott IDENTIFIED BY tiger USING 'XEPDB1';
+ CREATE PUBLIC DATABASE LINK test_link2 CONNECT TO test_schema IDENTIFIED BY tiger USING 'XEPDB1';
EOF
# To stop the container; this will also remove it.
--- /dev/null
+.. change::
+ :tags: performance, schema
+ :tickets: 4379
+
+ Rearchitected the schema reflection API to allow some dialects to make
+ use of high-performing batch queries that reflect the schemas of many
+ tables at once, using far fewer queries. The new performance features
+ are targeted first at the PostgreSQL and Oracle backends, and may be
+ applied to any dialect that makes use of SELECT queries against system
+ catalog tables to reflect tables (this currently omits the MySQL and
+ SQLite dialects, which instead parse the "CREATE TABLE" statement and
+ do not have a pre-existing performance issue with reflection; MS SQL
+ Server is still a TODO).
+
+ The new API is backwards compatible with the previous system and should
+ require no changes to third-party dialects to retain compatibility;
+ third-party dialects can also opt into the new system by implementing
+ batched queries for schema reflection.
+
+ Along with this change, the reflection API is now fully :pep:`484`
+ typed and features many new methods as well as some behavioral changes.
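+
+ A minimal sketch of the batched usage, assuming an existing
+ :class:`_engine.Engine` named ``engine`` (names here are illustrative
+ only)::
+
+     from sqlalchemy import inspect
+
+     insp = inspect(engine)
+
+     # one round of batched queries on dialects that implement the
+     # new API; the result is keyed by (schema, table name)
+     for (schema, name), cols in insp.get_multi_columns().items():
+         print(name, [col["name"] for col in cols])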
+
+.. change::
+ :tags: bug, schema
+
+ For SQLAlchemy-included dialects for SQLite, PostgreSQL, MySQL/MariaDB,
+ Oracle, and SQL Server, the :meth:`.Inspector.has_table`,
+ :meth:`.Inspector.has_sequence`, :meth:`.Inspector.has_index`,
+ :meth:`.Inspector.get_table_names` and
+ :meth:`.Inspector.get_sequence_names` now all behave consistently in terms
+ of caching: they all fully cache their result after being called the first
+ time for a particular :class:`.Inspector` object. Programs that create or
+ drop tables/sequences while calling upon the same :class:`.Inspector`
+ object will not receive updated status after the state of the database has
+ changed. A call to :meth:`.Inspector.clear_cache` or a new
+ :class:`.Inspector` should be used when DDL changes are to be executed.
+ Previously, the :meth:`.Inspector.has_table` and
+ :meth:`.Inspector.has_sequence` methods did not implement caching, nor
+ did the :class:`.Inspector` support caching for these methods, while
+ the :meth:`.Inspector.get_table_names` and
+ :meth:`.Inspector.get_sequence_names` methods were cached, leading to
+ inconsistent results between the two kinds of method.
+
+ Behavior for third-party dialects depends on whether or not they
+ implement the "reflection cache" decorator for the dialect-level
+ implementation of these methods.
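+
+ A minimal sketch of the recommended pattern when DDL is emitted in
+ between inspections (assuming an :class:`.Inspector` named ``insp``;
+ names are illustrative)::
+
+     insp.has_table("some_table")   # result is cached from here on
+
+     # ... "some_table" is created or dropped elsewhere ...
+
+     insp.clear_cache()             # discard cached reflection data
+     insp.has_table("some_table")   # queries the database again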
+
+.. change::
+ :tags: change, schema
+ :tickets: 4379
+
+ Improvements to the :class:`.Inspector` object:
+
+ * added a method :meth:`.Inspector.has_schema` that returns whether a
+ schema is present in the target database
+ * added a method :meth:`.Inspector.has_index` that returns whether a
+ table has a particular index (a short sketch of both methods follows
+ this list).
+ * Inspection methods such as :meth:`.Inspector.get_columns` that work
+ on a single table at a time should now all consistently
+ raise :class:`_exc.NoSuchTableError` if a
+ table or view is not found; this change is specific to individual
+ dialects, so may not be the case for existing third-party dialects.
+ * Separated the handling of "views" and "materialized views", as in
+ real-world use cases these two constructs make use of different DDL
+ for CREATE and DROP; accordingly, there are now separate
+ :meth:`.Inspector.get_view_names` and
+ :meth:`.Inspector.get_materialized_view_names` methods.
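+
+ A short sketch of the first two methods, assuming an
+ :class:`.Inspector` named ``insp`` (names are illustrative only)::
+
+     if insp.has_schema("test_schema"):
+         print(insp.get_table_names(schema="test_schema"))
+
+     # True if "some_table" has an index named "ix_some_table_x"
+     insp.has_index("some_table", "ix_some_table_x")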
+
--- /dev/null
+.. change::
+ :tags: change, oracle
+ :tickets: 4379
+
+ Materialized views on Oracle are now reflected as views.
+ In previous versions of SQLAlchemy they were returned among
+ the table names, not among the view names. As a side effect of
+ this change they are not reflected by default by
+ :meth:`_sql.MetaData.reflect` unless ``views=True`` is set.
+ To get a list of materialized views, use the new
+ inspection method :meth:`.Inspector.get_materialized_view_names`.
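+
+ An illustrative sketch, assuming an :class:`_engine.Engine` named
+ ``engine`` connected to an Oracle database::
+
+     from sqlalchemy import MetaData, inspect
+
+     insp = inspect(engine)
+     print(insp.get_materialized_view_names())
+
+     metadata = MetaData()
+     # materialized views are only included when views=True is passed
+     metadata.reflect(engine, views=True)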
--- /dev/null
+.. change::
+ :tags: change, postgresql
+
+ SQLAlchemy now requires PostgreSQL version 9 or greater.
+ Older versions may still work in some limited use cases.
{opensql}BEGIN (implicit)
PRAGMA main.table_...info("some_table")
[raw sql] ()
- SELECT sql FROM (SELECT * FROM sqlite_master UNION ALL SELECT * FROM sqlite_temp_master) WHERE name = ? AND type = 'table'
+ SELECT sql FROM (SELECT * FROM sqlite_master UNION ALL SELECT * FROM sqlite_temp_master) WHERE name = ? AND type in ('table', 'view')
[raw sql] ('some_table',)
PRAGMA main.foreign_key_list("some_table")
...
from ...engine import cursor as _cursor
from ...engine import default
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
return self.schema_name
@_db_plus_owner
- def has_table(self, connection, tablename, dbname, owner, schema):
+ def has_table(self, connection, tablename, dbname, owner, schema, **kw):
self._ensure_has_table_connection(connection)
- if tablename.startswith("#"): # temporary table
- # mssql does not support temporary views
- # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed
- tables = ischema.mssql_temp_table_columns
- s = sql.select(tables.c.table_name).where(
- tables.c.table_name.like(
- self._temp_table_name_like_pattern(tablename)
- )
- )
-
- # #7168: fetch all (not just first match) in case some other #temp
- # table with the same name happens to appear first
- table_names = connection.execute(s).scalars().fetchall()
- # #6910: verify it's not a temp table from another session
- for table_name in table_names:
- if bool(
- connection.scalar(
- text("SELECT object_id(:table_name)"),
- {"table_name": "tempdb.dbo.[{}]".format(table_name)},
- )
- ):
- return True
- else:
- return False
- else:
- tables = ischema.tables
-
- s = sql.select(tables.c.table_name).where(
- sql.and_(
- sql.or_(
- tables.c.table_type == "BASE TABLE",
- tables.c.table_type == "VIEW",
- ),
- tables.c.table_name == tablename,
- )
- )
-
- if owner:
- s = s.where(tables.c.table_schema == owner)
-
- c = connection.execute(s)
-
- return c.first() is not None
+ return self._internal_has_table(connection, tablename, owner, **kw)
+ @reflection.cache
@_db_plus_owner
- def has_sequence(self, connection, sequencename, dbname, owner, schema):
+ def has_sequence(
+ self, connection, sequencename, dbname, owner, schema, **kw
+ ):
sequences = ischema.sequences
s = sql.select(sequences.c.sequence_name).where(
view_names = [r[0] for r in connection.execute(s)]
return view_names
+ @reflection.cache
+ def _internal_has_table(self, connection, tablename, owner, **kw):
+ if tablename.startswith("#"): # temporary table
+ # mssql does not support temporary views
+ # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed
+ tables = ischema.mssql_temp_table_columns
+
+ s = sql.select(tables.c.table_name).where(
+ tables.c.table_name.like(
+ self._temp_table_name_like_pattern(tablename)
+ )
+ )
+
+ # #7168: fetch all (not just first match) in case some other #temp
+ # table with the same name happens to appear first
+ table_names = connection.scalars(s).all()
+ # #6910: verify it's not a temp table from another session
+ for table_name in table_names:
+ if bool(
+ connection.scalar(
+ text("SELECT object_id(:table_name)"),
+ {"table_name": "tempdb.dbo.[{}]".format(table_name)},
+ )
+ ):
+ return True
+ else:
+ return False
+ else:
+ tables = ischema.tables
+
+ s = sql.select(tables.c.table_name).where(
+ sql.and_(
+ sql.or_(
+ tables.c.table_type == "BASE TABLE",
+ tables.c.table_type == "VIEW",
+ ),
+ tables.c.table_name == tablename,
+ )
+ )
+
+ if owner:
+ s = s.where(tables.c.table_schema == owner)
+
+ c = connection.execute(s)
+
+ return c.first() is not None
+
+ def _default_or_error(self, connection, tablename, owner, method, **kw):
+ # TODO: try to avoid having to run a separate query here
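+ # if a per-table reflection query returned no rows, distinguish a
+ # table that exists but has nothing to report (return the default)
+ # from a table that does not exist at all (raise NoSuchTableError)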
+ if self._internal_has_table(connection, tablename, owner, **kw):
+ return method()
+ else:
+ raise exc.NoSuchTableError(f"{owner}.{tablename}")
+
@reflection.cache
@_db_plus_owner
def get_indexes(self, connection, tablename, dbname, owner, schema, **kw):
rp = connection.execution_options(future_result=True).execute(
sql.text(
"select ind.index_id, ind.is_unique, ind.name, "
- "%s "
+ f"{filter_definition} "
"from sys.indexes as ind join sys.tables as tab on "
"ind.object_id=tab.object_id "
"join sys.schemas as sch on sch.schema_id=tab.schema_id "
"where tab.name = :tabname "
"and sch.name=:schname "
- "and ind.is_primary_key=0 and ind.type != 0"
- % filter_definition
+ "and ind.is_primary_key=0 and ind.type != 0 "
+ "order by ind.name "
)
.bindparams(
sql.bindparam("tabname", tablename, ischema.CoerceUnicode()),
"mssql_include"
] = index_info["include_columns"]
- return list(indexes.values())
+ if indexes:
+ return list(indexes.values())
+ else:
+ return self._default_or_error(
+ connection, tablename, owner, ReflectionDefaults.indexes, **kw
+ )
@reflection.cache
@_db_plus_owner
def get_view_definition(
self, connection, viewname, dbname, owner, schema, **kw
):
- rp = connection.execute(
+ view_def = connection.execute(
sql.text(
- "select definition from sys.sql_modules as mod, "
- "sys.views as views, "
- "sys.schemas as sch"
- " where "
- "mod.object_id=views.object_id and "
- "views.schema_id=sch.schema_id and "
- "views.name=:viewname and sch.name=:schname"
+ "select mod.definition "
+ "from sys.sql_modules as mod "
+ "join sys.views as views on mod.object_id = views.object_id "
+ "join sys.schemas as sch on views.schema_id = sch.schema_id "
+ "where views.name=:viewname and sch.name=:schname"
).bindparams(
sql.bindparam("viewname", viewname, ischema.CoerceUnicode()),
sql.bindparam("schname", owner, ischema.CoerceUnicode()),
)
- )
-
- if rp:
- view_def = rp.scalar()
+ ).scalar()
+ if view_def:
return view_def
+ else:
+ raise exc.NoSuchTableError(f"{owner}.{viewname}")
def _temp_table_name_like_pattern(self, tablename):
# LIKE uses '%' to match zero or more characters and '_' to match any
cols.append(cdict)
- return cols
+ if cols:
+ return cols
+ else:
+ return self._default_or_error(
+ connection, tablename, owner, ReflectionDefaults.columns, **kw
+ )
@reflection.cache
@_db_plus_owner
pkeys.append(row["COLUMN_NAME"])
if constraint_name is None:
constraint_name = row[C.c.constraint_name.name]
- return {"constrained_columns": pkeys, "name": constraint_name}
+ if pkeys:
+ return {"constrained_columns": pkeys, "name": constraint_name}
+ else:
+ return self._default_or_error(
+ connection,
+ tablename,
+ owner,
+ ReflectionDefaults.pk_constraint,
+ **kw,
+ )
@reflection.cache
@_db_plus_owner
fkeys = util.defaultdict(fkey_rec)
- for r in connection.execute(s).fetchall():
+ for r in connection.execute(s).all():
(
_, # constraint schema
rfknm,
local_cols.append(scol)
remote_cols.append(rcol)
- return list(fkeys.values())
+ if fkeys:
+ return list(fkeys.values())
+ else:
+ return self._default_or_error(
+ connection,
+ tablename,
+ owner,
+ ReflectionDefaults.foreign_keys,
+ **kw,
+ )
from ... import util
from ...engine import default
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
def _get_default_schema_name(self, connection):
return connection.exec_driver_sql("SELECT DATABASE()").scalar()
- def has_table(self, connection, table_name, schema=None):
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
self._ensure_has_table_connection(connection)
if schema is None:
)
return bool(rs.scalar())
- def has_sequence(self, connection, sequence_name, schema=None):
+ @reflection.cache
+ def has_sequence(self, connection, sequence_name, schema=None, **kw):
if not self.supports_sequences:
self._sequences_not_supported()
if not schema:
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- return parsed_state.table_options
+ if parsed_state.table_options:
+ return parsed_state.table_options
+ else:
+ return ReflectionDefaults.table_options()
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- return parsed_state.columns
+ if parsed_state.columns:
+ return parsed_state.columns
+ else:
+ return ReflectionDefaults.columns()
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
# There can be only one.
cols = [s[0] for s in key["columns"]]
return {"constrained_columns": cols, "name": None}
- return {"constrained_columns": [], "name": None}
+ return ReflectionDefaults.pk_constraint()
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
if self._needs_correct_for_88718_96365:
self._correct_for_mysql_bugs_88718_96365(fkeys, connection)
- return fkeys
+ return fkeys if fkeys else ReflectionDefaults.foreign_keys()
def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection):
# Foreign key is always in lower case (MySQL 8.0)
connection, table_name, schema, **kw
)
- return [
+ cks = [
{"name": spec["name"], "sqltext": spec["sqltext"]}
for spec in parsed_state.ck_constraints
]
+ return cks if cks else ReflectionDefaults.check_constraints()
@reflection.cache
def get_table_comment(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- return {
- "text": parsed_state.table_options.get(
- "%s_comment" % self.name, None
- )
- }
+ comment = parsed_state.table_options.get(f"{self.name}_comment", None)
+ if comment is not None:
+ return {"text": comment}
+ else:
+ return ReflectionDefaults.table_comment()
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
if flavor:
index_d["type"] = flavor
indexes.append(index_d)
- return indexes
+ indexes.sort(key=lambda d: d["name"] or "~") # sort None as last
+ return indexes if indexes else ReflectionDefaults.indexes()
@reflection.cache
def get_unique_constraints(
connection, table_name, schema, **kw
)
- return [
+ ucs = [
{
"name": key["name"],
"column_names": [col[0] for col in key["columns"]],
for key in parsed_state.keys
if key["type"] == "UNIQUE"
]
+ ucs.sort(key=lambda d: d["name"] or "~") # sort None as last
+ if ucs:
+ return ucs
+ else:
+ return ReflectionDefaults.unique_constraints()
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
sql = self._show_create_table(
connection, None, charset, full_name=full_name
)
+ if sql.upper().startswith("CREATE TABLE"):
+ # it's a table, not a view
+ raise exc.NoSuchTableError(full_name)
return sql
def _parsed_state_or_create(
""" # noqa
-from itertools import groupby
+from __future__ import annotations
+
+from collections import defaultdict
+from functools import lru_cache
+from functools import wraps
import re
+from . import dictionary
+from .types import _OracleBoolean
+from .types import _OracleDate
+from .types import BFILE
+from .types import BINARY_DOUBLE
+from .types import BINARY_FLOAT
+from .types import DATE
+from .types import FLOAT
+from .types import INTERVAL
+from .types import LONG
+from .types import NCLOB
+from .types import NUMBER
+from .types import NVARCHAR2 # noqa
+from .types import OracleRaw # noqa
+from .types import RAW
+from .types import ROWID # noqa
+from .types import VARCHAR2 # noqa
from ... import Computed
from ... import exc
from ... import schema as sa_schema
from ... import sql
from ... import util
from ...engine import default
+from ...engine import ObjectKind
+from ...engine import ObjectScope
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
+from ...sql import and_
+from ...sql import bindparam
from ...sql import compiler
from ...sql import expression
+from ...sql import func
+from ...sql import null
+from ...sql import or_
+from ...sql import select
from ...sql import sqltypes
from ...sql import util as sql_util
from ...sql import visitors
+from ...sql.visitors import InternalTraversal
from ...types import BLOB
from ...types import CHAR
from ...types import CLOB
)
-class RAW(sqltypes._Binary):
- __visit_name__ = "RAW"
-
-
-OracleRaw = RAW
-
-
-class NCLOB(sqltypes.Text):
- __visit_name__ = "NCLOB"
-
-
-class VARCHAR2(VARCHAR):
- __visit_name__ = "VARCHAR2"
-
-
-NVARCHAR2 = NVARCHAR
-
-
-class NUMBER(sqltypes.Numeric, sqltypes.Integer):
- __visit_name__ = "NUMBER"
-
- def __init__(self, precision=None, scale=None, asdecimal=None):
- if asdecimal is None:
- asdecimal = bool(scale and scale > 0)
-
- super(NUMBER, self).__init__(
- precision=precision, scale=scale, asdecimal=asdecimal
- )
-
- def adapt(self, impltype):
- ret = super(NUMBER, self).adapt(impltype)
- # leave a hint for the DBAPI handler
- ret._is_oracle_number = True
- return ret
-
- @property
- def _type_affinity(self):
- if bool(self.scale and self.scale > 0):
- return sqltypes.Numeric
- else:
- return sqltypes.Integer
-
-
-class FLOAT(sqltypes.FLOAT):
- """Oracle FLOAT.
-
- This is the same as :class:`_sqltypes.FLOAT` except that
- an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
- parameter is accepted, and
- the :paramref:`_sqltypes.Float.precision` parameter is not accepted.
-
- Oracle FLOAT types indicate precision in terms of "binary precision", which
- defaults to 126. For a REAL type, the value is 63. This parameter does not
- cleanly map to a specific number of decimal places but is roughly
- equivalent to the desired number of decimal places divided by 0.3103.
-
- .. versionadded:: 2.0
-
- """
-
- __visit_name__ = "FLOAT"
-
- def __init__(
- self,
- binary_precision=None,
- asdecimal=False,
- decimal_return_scale=None,
- ):
- r"""
- Construct a FLOAT
-
- :param binary_precision: Oracle binary precision value to be rendered
- in DDL. This may be approximated to the number of decimal characters
- using the formula "decimal precision = 0.30103 * binary precision".
- The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.
-
- :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`
-
- :param decimal_return_scale: See
- :paramref:`_sqltypes.Float.decimal_return_scale`
-
- """
- super().__init__(
- asdecimal=asdecimal, decimal_return_scale=decimal_return_scale
- )
- self.binary_precision = binary_precision
-
-
-class BINARY_DOUBLE(sqltypes.Float):
- __visit_name__ = "BINARY_DOUBLE"
-
-
-class BINARY_FLOAT(sqltypes.Float):
- __visit_name__ = "BINARY_FLOAT"
-
-
-class BFILE(sqltypes.LargeBinary):
- __visit_name__ = "BFILE"
-
-
-class LONG(sqltypes.Text):
- __visit_name__ = "LONG"
-
-
-class _OracleDateLiteralRender:
- def _literal_processor_datetime(self, dialect):
- def process(value):
- if value is not None:
- if getattr(value, "microsecond", None):
- value = (
- f"""TO_TIMESTAMP"""
- f"""('{value.isoformat().replace("T", " ")}', """
- """'YYYY-MM-DD HH24:MI:SS.FF')"""
- )
- else:
- value = (
- f"""TO_DATE"""
- f"""('{value.isoformat().replace("T", " ")}', """
- """'YYYY-MM-DD HH24:MI:SS')"""
- )
- return value
-
- return process
-
- def _literal_processor_date(self, dialect):
- def process(value):
- if value is not None:
- if getattr(value, "microsecond", None):
- value = (
- f"""TO_TIMESTAMP"""
- f"""('{value.isoformat().split("T")[0]}', """
- """'YYYY-MM-DD')"""
- )
- else:
- value = (
- f"""TO_DATE"""
- f"""('{value.isoformat().split("T")[0]}', """
- """'YYYY-MM-DD')"""
- )
- return value
-
- return process
-
-
-class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
- """Provide the oracle DATE type.
-
- This type has no special Python behavior, except that it subclasses
- :class:`_types.DateTime`; this is to suit the fact that the Oracle
- ``DATE`` type supports a time value.
-
- .. versionadded:: 0.9.4
-
- """
-
- __visit_name__ = "DATE"
-
- def literal_processor(self, dialect):
- return self._literal_processor_datetime(dialect)
-
- def _compare_type_affinity(self, other):
- return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
-
-
-class _OracleDate(_OracleDateLiteralRender, sqltypes.Date):
- def literal_processor(self, dialect):
- return self._literal_processor_date(dialect)
-
-
-class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
- __visit_name__ = "INTERVAL"
-
- def __init__(self, day_precision=None, second_precision=None):
- """Construct an INTERVAL.
-
- Note that only DAY TO SECOND intervals are currently supported.
- This is due to a lack of support for YEAR TO MONTH intervals
- within available DBAPIs.
-
- :param day_precision: the day precision value. this is the number of
- digits to store for the day field. Defaults to "2"
- :param second_precision: the second precision value. this is the
- number of digits to store for the fractional seconds field.
- Defaults to "6".
-
- """
- self.day_precision = day_precision
- self.second_precision = second_precision
-
- @classmethod
- def _adapt_from_generic_interval(cls, interval):
- return INTERVAL(
- day_precision=interval.day_precision,
- second_precision=interval.second_precision,
- )
-
- @property
- def _type_affinity(self):
- return sqltypes.Interval
-
- def as_generic(self, allow_nulltype=False):
- return sqltypes.Interval(
- native=True,
- second_precision=self.second_precision,
- day_precision=self.day_precision,
- )
-
-
-class ROWID(sqltypes.TypeEngine):
- """Oracle ROWID type.
-
- When used in a cast() or similar, generates ROWID.
-
- """
-
- __visit_name__ = "ROWID"
-
-
-class _OracleBoolean(sqltypes.Boolean):
- def get_dbapi_type(self, dbapi):
- return dbapi.NUMBER
-
-
colspecs = {
sqltypes.Boolean: _OracleBoolean,
sqltypes.Interval: INTERVAL,
type_,
)
+ def pre_exec(self):
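+ # reflection queries against the Oracle data dictionary contain a
+ # placeholder token for an optional "@dblink" suffix; substitute the
+ # actual link name (or an empty string) right before execution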
+ if self.statement and "_oracle_dblink" in self.execution_options:
+ self.statement = self.statement.replace(
+ dictionary.DB_LINK_PLACEHOLDER,
+ self.execution_options["_oracle_dblink"],
+ )
+
class OracleDialect(default.DefaultDialect):
name = "oracle"
# it may also work on versions before 18
return self.server_version_info and self.server_version_info >= (18,)
+ @property
+ def _supports_except_all(self):
+ return self.server_version_info and self.server_version_info >= (21,)
+
def do_release_savepoint(self, connection, name):
# Oracle does not support RELEASE SAVEPOINT
pass
except:
return "READ COMMITTED"
- def has_table(self, connection, table_name, schema=None):
+ def _execute_reflection(
+ self, connection, query, dblink, returns_long, params=None
+ ):
+ if dblink and not dblink.startswith("@"):
+ dblink = f"@{dblink}"
+ execution_options = {
+ # handle db links
+ "_oracle_dblink": dblink or "",
+ # override any schema translate map
+ "schema_translate_map": None,
+ }
+
+ if dblink and returns_long:
+ # Oracle seems to error with
+ # "ORA-00997: illegal use of LONG datatype" when returning
+ # LONG columns via a dblink in a query with bind params
+ # This type seems to be very hard to cast into something else
+ # so it seems easier to just render the bind params as literals here
+ def visit_bindparam(bindparam):
+ bindparam.literal_execute = True
+
+ query = visitors.cloned_traverse(
+ query, {}, {"bindparam": visit_bindparam}
+ )
+ return connection.execute(
+ query, params, execution_options=execution_options
+ )
+
+ @util.memoized_property
+ def _has_table_query(self):
+ # materialized views are returned by all_tables
+ tables = (
+ select(
+ dictionary.all_tables.c.table_name,
+ dictionary.all_tables.c.owner,
+ )
+ .union_all(
+ select(
+ dictionary.all_views.c.view_name.label("table_name"),
+ dictionary.all_views.c.owner,
+ )
+ )
+ .subquery("tables_and_views")
+ )
+
+ query = select(tables.c.table_name).where(
+ tables.c.table_name == bindparam("table_name"),
+ tables.c.owner == bindparam("owner"),
+ )
+ return query
+
+ @reflection.cache
+ def has_table(
+ self, connection, table_name, schema=None, dblink=None, **kw
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
self._ensure_has_table_connection(connection)
if not schema:
schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- """SELECT table_name FROM all_tables
- WHERE table_name = CAST(:name AS VARCHAR2(128))
- AND owner = CAST(:schema_name AS VARCHAR2(128))
- UNION ALL
- SELECT view_name FROM all_views
- WHERE view_name = CAST(:name AS VARCHAR2(128))
- AND owner = CAST(:schema_name AS VARCHAR2(128))
- """
- ),
- dict(
- name=self.denormalize_name(table_name),
- schema_name=self.denormalize_name(schema),
- ),
+ params = {
+ "table_name": self.denormalize_name(table_name),
+ "owner": self.denormalize_name(schema),
+ }
+ cursor = self._execute_reflection(
+ connection,
+ self._has_table_query,
+ dblink,
+ returns_long=False,
+ params=params,
)
- return cursor.first() is not None
+ return bool(cursor.scalar())
- def has_sequence(self, connection, sequence_name, schema=None):
+ @reflection.cache
+ def has_sequence(
+ self, connection, sequence_name, schema=None, dblink=None, **kw
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
if not schema:
schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT sequence_name FROM all_sequences "
- "WHERE sequence_name = :name AND "
- "sequence_owner = :schema_name"
- ),
- dict(
- name=self.denormalize_name(sequence_name),
- schema_name=self.denormalize_name(schema),
- ),
+
+ query = select(dictionary.all_sequences.c.sequence_name).where(
+ dictionary.all_sequences.c.sequence_name
+ == self.denormalize_name(sequence_name),
+ dictionary.all_sequences.c.sequence_owner
+ == self.denormalize_name(schema),
)
- return cursor.first() is not None
+
+ cursor = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ )
+ return bool(cursor.scalar())
def _get_default_schema_name(self, connection):
return self.normalize_name(
).scalar()
)
- def _resolve_synonym(
- self,
- connection,
- desired_owner=None,
- desired_synonym=None,
- desired_table=None,
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("filter_names", InternalTraversal.dp_string_list),
+ ("dblink", InternalTraversal.dp_string),
+ )
+ def _get_synonyms(self, connection, schema, filter_names, dblink, **kw):
+ owner = self.denormalize_name(schema or self.default_schema_name)
+
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = select(
+ dictionary.all_synonyms.c.synonym_name,
+ dictionary.all_synonyms.c.table_name,
+ dictionary.all_synonyms.c.table_owner,
+ dictionary.all_synonyms.c.db_link,
+ ).where(dictionary.all_synonyms.c.owner == owner)
+ if has_filter_names:
+ query = query.where(
+ dictionary.all_synonyms.c.synonym_name.in_(
+ params["filter_names"]
+ )
+ )
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).mappings()
+ return result.all()
+
+ @lru_cache()
+ def _all_objects_query(
+ self, owner, scope, kind, has_filter_names, has_mat_views
):
- """search for a local synonym matching the given desired owner/name.
-
- if desired_owner is None, attempts to locate a distinct owner.
-
- returns the actual name, owner, dblink name, and synonym name if
- found.
- """
-
- q = (
- "SELECT owner, table_owner, table_name, db_link, "
- "synonym_name FROM all_synonyms WHERE "
+ query = (
+ select(dictionary.all_objects.c.object_name)
+ .select_from(dictionary.all_objects)
+ .where(dictionary.all_objects.c.owner == owner)
)
- clauses = []
- params = {}
- if desired_synonym:
- clauses.append(
- "synonym_name = CAST(:synonym_name AS VARCHAR2(128))"
+
+ # NOTE: materialized views are listed in all_objects twice;
+ # once as MATERIALIZED VIEW and once as TABLE
+ if kind is ObjectKind.ANY:
+ # materialized views are also listed as tables so there is no
+ # need to add them to the in_.
+ query = query.where(
+ dictionary.all_objects.c.object_type.in_(("TABLE", "VIEW"))
)
- params["synonym_name"] = desired_synonym
- if desired_owner:
- clauses.append("owner = CAST(:desired_owner AS VARCHAR2(128))")
- params["desired_owner"] = desired_owner
- if desired_table:
- clauses.append("table_name = CAST(:tname AS VARCHAR2(128))")
- params["tname"] = desired_table
-
- q += " AND ".join(clauses)
-
- result = connection.execution_options(future_result=True).execute(
- sql.text(q), params
- )
- if desired_owner:
- row = result.mappings().first()
- if row:
- return (
- row["table_name"],
- row["table_owner"],
- row["db_link"],
- row["synonym_name"],
- )
- else:
- return None, None, None, None
else:
- rows = result.mappings().all()
- if len(rows) > 1:
- raise AssertionError(
- "There are multiple tables visible to the schema, you "
- "must specify owner"
- )
- elif len(rows) == 1:
- row = rows[0]
- return (
- row["table_name"],
- row["table_owner"],
- row["db_link"],
- row["synonym_name"],
- )
- else:
- return None, None, None, None
+ object_type = []
+ if ObjectKind.VIEW in kind:
+ object_type.append("VIEW")
+ if (
+ ObjectKind.MATERIALIZED_VIEW in kind
+ and ObjectKind.TABLE not in kind
+ ):
+ # materialized views are also listed as tables so there is no
+ # need to add them to the in_ if also selecting tables.
+ object_type.append("MATERIALIZED VIEW")
+ if ObjectKind.TABLE in kind:
+ object_type.append("TABLE")
+ if has_mat_views and ObjectKind.MATERIALIZED_VIEW not in kind:
+ # materialized views are also listed as tables,
+ # so they need to be filtered out
+ # EXCEPT ALL / MINUS profiles as faster than using
+ # NOT EXISTS or NOT IN with a subquery, but it's in
+ # general faster to get the mat view names and exclude
+ # them only when needed
+ query = query.where(
+ dictionary.all_objects.c.object_name.not_in(
+ bindparam("mat_views")
+ )
+ )
+ query = query.where(
+ dictionary.all_objects.c.object_type.in_(object_type)
+ )
- @reflection.cache
- def _prepare_reflection_args(
- self,
- connection,
- table_name,
- schema=None,
- resolve_synonyms=False,
- dblink="",
- **kw,
- ):
+ # handles scope
+ if scope is ObjectScope.DEFAULT:
+ query = query.where(dictionary.all_objects.c.temporary == "N")
+ elif scope is ObjectScope.TEMPORARY:
+ query = query.where(dictionary.all_objects.c.temporary == "Y")
- if resolve_synonyms:
- actual_name, owner, dblink, synonym = self._resolve_synonym(
- connection,
- desired_owner=self.denormalize_name(schema),
- desired_synonym=self.denormalize_name(table_name),
+ if has_filter_names:
+ query = query.where(
+ dictionary.all_objects.c.object_name.in_(
+ bindparam("filter_names")
+ )
)
- else:
- actual_name, owner, dblink, synonym = None, None, None, None
- if not actual_name:
- actual_name = self.denormalize_name(table_name)
-
- if dblink:
- # using user_db_links here since all_db_links appears
- # to have more restricted permissions.
- # https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm
- # will need to hear from more users if we are doing
- # the right thing here. See [ticket:2619]
- owner = connection.scalar(
- sql.text(
- "SELECT username FROM user_db_links " "WHERE db_link=:link"
- ),
- dict(link=dblink),
+ return query
+
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("scope", InternalTraversal.dp_plain_obj),
+ ("kind", InternalTraversal.dp_plain_obj),
+ ("filter_names", InternalTraversal.dp_string_list),
+ ("dblink", InternalTraversal.dp_string),
+ )
+ def _get_all_objects(
+ self, connection, schema, scope, kind, filter_names, dblink, **kw
+ ):
+ owner = self.denormalize_name(schema or self.default_schema_name)
+
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ has_mat_views = False
+ if (
+ ObjectKind.TABLE in kind
+ and ObjectKind.MATERIALIZED_VIEW not in kind
+ ):
+ # see note in _all_objects_query
+ mat_views = self.get_materialized_view_names(
+ connection, schema, dblink, _normalize=False, **kw
)
- dblink = "@" + dblink
- elif not owner:
- owner = self.denormalize_name(schema or self.default_schema_name)
+ if mat_views:
+ params["mat_views"] = mat_views
+ has_mat_views = True
+
+ query = self._all_objects_query(
+ owner, scope, kind, has_filter_names, has_mat_views
+ )
- return (actual_name, owner, dblink or "", synonym)
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False, params=params
+ ).scalars()
- @reflection.cache
- def get_schema_names(self, connection, **kw):
- s = "SELECT username FROM all_users ORDER BY username"
- cursor = connection.exec_driver_sql(s)
- return [self.normalize_name(row[0]) for row in cursor]
+ return result.all()
+
+ def _handle_synonyms_decorator(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ return self._handle_synonyms(fn, *args, **kwargs)
+
+ return wrapper
+
+ def _handle_synonyms(self, fn, connection, *args, **kwargs):
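+ # when ``oracle_resolve_synonyms`` is enabled, look up the synonyms
+ # visible in the requested schema, group their targets by
+ # (db link, owning schema), invoke the wrapped reflection method once
+ # per group against the real objects, then re-key the results under
+ # the original synonym names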
+ if not kwargs.get("oracle_resolve_synonyms", False):
+ return fn(self, connection, *args, **kwargs)
+
+ original_kw = kwargs.copy()
+ schema = kwargs.pop("schema", None)
+ result = self._get_synonyms(
+ connection,
+ schema=schema,
+ filter_names=kwargs.pop("filter_names", None),
+ dblink=kwargs.pop("dblink", None),
+ info_cache=kwargs.get("info_cache", None),
+ )
+
+ dblinks_owners = defaultdict(dict)
+ for row in result:
+ key = row["db_link"], row["table_owner"]
+ tn = self.normalize_name(row["table_name"])
+ dblinks_owners[key][tn] = row["synonym_name"]
+
+ if not dblinks_owners:
+ # No synonym, do the plain thing
+ return fn(self, connection, *args, **original_kw)
+
+ data = {}
+ for (dblink, table_owner), mapping in dblinks_owners.items():
+ call_kw = {
+ **original_kw,
+ "schema": table_owner,
+ "dblink": self.normalize_name(dblink),
+ "filter_names": mapping.keys(),
+ }
+ call_result = fn(self, connection, *args, **call_kw)
+ for (_, tn), value in call_result:
+ synonym_name = self.normalize_name(mapping[tn])
+ data[(schema, synonym_name)] = value
+ return data.items()
@reflection.cache
- def get_table_names(self, connection, schema=None, **kw):
- schema = self.denormalize_name(schema or self.default_schema_name)
+ def get_schema_names(self, connection, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
+ query = select(dictionary.all_users.c.username).order_by(
+ dictionary.all_users.c.username
+ )
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
# note that table_names() isn't loading DBLINKed or synonym'ed tables
if schema is None:
schema = self.default_schema_name
- sql_str = "SELECT table_name FROM all_tables WHERE "
+ den_schema = self.denormalize_name(schema)
+ if kw.get("oracle_resolve_synonyms", False):
+ tables = (
+ select(
+ dictionary.all_tables.c.table_name,
+ dictionary.all_tables.c.owner,
+ dictionary.all_tables.c.iot_name,
+ dictionary.all_tables.c.duration,
+ dictionary.all_tables.c.tablespace_name,
+ )
+ .union_all(
+ select(
+ dictionary.all_synonyms.c.synonym_name.label(
+ "table_name"
+ ),
+ dictionary.all_synonyms.c.owner,
+ dictionary.all_tables.c.iot_name,
+ dictionary.all_tables.c.duration,
+ dictionary.all_tables.c.tablespace_name,
+ )
+ .select_from(dictionary.all_tables)
+ .join(
+ dictionary.all_synonyms,
+ and_(
+ dictionary.all_tables.c.table_name
+ == dictionary.all_synonyms.c.table_name,
+ dictionary.all_tables.c.owner
+ == func.coalesce(
+ dictionary.all_synonyms.c.table_owner,
+ dictionary.all_synonyms.c.owner,
+ ),
+ ),
+ )
+ )
+ .subquery("available_tables")
+ )
+ else:
+ tables = dictionary.all_tables
+
+ query = select(tables.c.table_name)
if self.exclude_tablespaces:
- sql_str += (
- "nvl(tablespace_name, 'no tablespace') "
- "NOT IN (%s) AND "
- % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
+ query = query.where(
+ func.coalesce(
+ tables.c.tablespace_name, "no tablespace"
+ ).not_in(self.exclude_tablespaces)
)
- sql_str += (
- "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL"
+ query = query.where(
+ tables.c.owner == den_schema,
+ tables.c.iot_name.is_(null()),
+ tables.c.duration.is_(null()),
)
- cursor = connection.execute(sql.text(sql_str), dict(owner=schema))
- return [self.normalize_name(row[0]) for row in cursor]
+ # remove materialized views
+ mat_query = select(
+ dictionary.all_mviews.c.mview_name.label("table_name")
+ ).where(dictionary.all_mviews.c.owner == den_schema)
+
+ query = (
+ query.except_all(mat_query)
+ if self._supports_except_all
+ else query.except_(mat_query)
+ )
+
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
@reflection.cache
- def get_temp_table_names(self, connection, **kw):
+ def get_temp_table_names(self, connection, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
schema = self.denormalize_name(self.default_schema_name)
- sql_str = "SELECT table_name FROM all_tables WHERE "
+ query = select(dictionary.all_tables.c.table_name)
if self.exclude_tablespaces:
- sql_str += (
- "nvl(tablespace_name, 'no tablespace') "
- "NOT IN (%s) AND "
- % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
+ query = query.where(
+ func.coalesce(
+ dictionary.all_tables.c.tablespace_name, "no tablespace"
+ ).not_in(self.exclude_tablespaces)
)
- sql_str += (
- "OWNER = :owner "
- "AND IOT_NAME IS NULL "
- "AND DURATION IS NOT NULL"
+ query = query.where(
+ dictionary.all_tables.c.owner == schema,
+ dictionary.all_tables.c.iot_name.is_(null()),
+ dictionary.all_tables.c.duration.is_not(null()),
)
- cursor = connection.execute(sql.text(sql_str), dict(owner=schema))
- return [self.normalize_name(row[0]) for row in cursor]
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
@reflection.cache
- def get_view_names(self, connection, schema=None, **kw):
- schema = self.denormalize_name(schema or self.default_schema_name)
- s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner")
- cursor = connection.execute(
- s, dict(owner=self.denormalize_name(schema))
+ def get_materialized_view_names(
+ self, connection, schema=None, dblink=None, _normalize=True, **kw
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
+ if not schema:
+ schema = self.default_schema_name
+
+ query = select(dictionary.all_mviews.c.mview_name).where(
+ dictionary.all_mviews.c.owner == self.denormalize_name(schema)
)
- return [self.normalize_name(row[0]) for row in cursor]
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ if _normalize:
+ return [self.normalize_name(row) for row in result]
+ else:
+ return result.all()
@reflection.cache
- def get_sequence_names(self, connection, schema=None, **kw):
+ def get_view_names(self, connection, schema=None, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
if not schema:
schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT sequence_name FROM all_sequences "
- "WHERE sequence_owner = :schema_name"
- ),
- dict(schema_name=self.denormalize_name(schema)),
+
+ query = select(dictionary.all_views.c.view_name).where(
+ dictionary.all_views.c.owner == self.denormalize_name(schema)
)
- return [self.normalize_name(row[0]) for row in cursor]
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
@reflection.cache
- def get_table_options(self, connection, table_name, schema=None, **kw):
- options = {}
+ def get_sequence_names(self, connection, schema=None, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
+ if not schema:
+ schema = self.default_schema_name
+ query = select(dictionary.all_sequences.c.sequence_name).where(
+ dictionary.all_sequences.c.sequence_owner
+ == self.denormalize_name(schema)
+ )
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ def _value_or_raise(self, data, table, schema):
+ table = self.normalize_name(str(table))
+ try:
+ return dict(data)[(schema, table)]
+ except KeyError:
+ raise exc.NoSuchTableError(
+ f"{schema}.{table}" if schema else table
+ ) from None
+
+ def _prepare_filter_names(self, filter_names):
+ if filter_names:
+ fn = [self.denormalize_name(name) for name in filter_names]
+ return True, {"filter_names": fn}
+ else:
+ return False, {}
+
+ @reflection.cache
+ def get_table_options(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_table_options(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- params = {"table_name": table_name}
+ @lru_cache()
+ def _table_options_query(
+ self, owner, scope, kind, has_filter_names, has_mat_views
+ ):
+ query = select(
+ dictionary.all_tables.c.table_name,
+ dictionary.all_tables.c.compression,
+ dictionary.all_tables.c.compress_for,
+ ).where(dictionary.all_tables.c.owner == owner)
+ if has_filter_names:
+ query = query.where(
+ dictionary.all_tables.c.table_name.in_(
+ bindparam("filter_names")
+ )
+ )
+ if scope is ObjectScope.DEFAULT:
+ query = query.where(dictionary.all_tables.c.duration.is_(null()))
+ elif scope is ObjectScope.TEMPORARY:
+ query = query.where(
+ dictionary.all_tables.c.duration.is_not(null())
+ )
- columns = ["table_name"]
- if self._supports_table_compression:
- columns.append("compression")
- if self._supports_table_compress_for:
- columns.append("compress_for")
+ if (
+ has_mat_views
+ and ObjectKind.TABLE in kind
+ and ObjectKind.MATERIALIZED_VIEW not in kind
+ ):
+ # can't use EXCEPT ALL / MINUS here because we don't have an
+ # excludable row vs. the query above;
+ # outerjoin + where null works better on Oracle 21 but 11 does
+ # not like it at all, so this is the next best thing
+
+ query = query.where(
+ dictionary.all_tables.c.table_name.not_in(
+ bindparam("mat_views")
+ )
+ )
+ elif (
+ ObjectKind.TABLE not in kind
+ and ObjectKind.MATERIALIZED_VIEW in kind
+ ):
+ query = query.where(
+ dictionary.all_tables.c.table_name.in_(bindparam("mat_views"))
+ )
+ return query
- text = (
- "SELECT %(columns)s "
- "FROM ALL_TABLES%(dblink)s "
- "WHERE table_name = CAST(:table_name AS VARCHAR(128))"
- )
+ @_handle_synonyms_decorator
+ def get_multi_table_options(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ owner = self.denormalize_name(schema or self.default_schema_name)
- if schema is not None:
- params["owner"] = schema
- text += " AND owner = CAST(:owner AS VARCHAR(128)) "
- text = text % {"dblink": dblink, "columns": ", ".join(columns)}
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ has_mat_views = False
- result = connection.execute(sql.text(text), params)
+ if (
+ ObjectKind.TABLE in kind
+ and ObjectKind.MATERIALIZED_VIEW not in kind
+ ):
+ # see note in _table_options_query
+ mat_views = self.get_materialized_view_names(
+ connection, schema, dblink, _normalize=False, **kw
+ )
+ if mat_views:
+ params["mat_views"] = mat_views
+ has_mat_views = True
+ elif (
+ ObjectKind.TABLE not in kind
+ and ObjectKind.MATERIALIZED_VIEW in kind
+ ):
+ mat_views = self.get_materialized_view_names(
+ connection, schema, dblink, _normalize=False, **kw
+ )
+ params["mat_views"] = mat_views
- enabled = dict(DISABLED=False, ENABLED=True)
+ options = {}
+ default = ReflectionDefaults.table_options
- row = result.first()
- if row:
- if "compression" in row._fields and enabled.get(
- row.compression, False
- ):
- if "compress_for" in row._fields:
- options["oracle_compress"] = row.compress_for
+ if ObjectKind.TABLE in kind or ObjectKind.MATERIALIZED_VIEW in kind:
+ query = self._table_options_query(
+ owner, scope, kind, has_filter_names, has_mat_views
+ )
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False, params=params
+ )
+
+ for table, compression, compress_for in result:
+ if compression == "ENABLED":
+ data = {"oracle_compress": compress_for}
else:
- options["oracle_compress"] = True
+ data = default()
+ options[(schema, self.normalize_name(table))] = data
+ if ObjectKind.VIEW in kind and ObjectScope.DEFAULT in scope:
+ # add the views (no temporary views)
+ for view in self.get_view_names(connection, schema, dblink, **kw):
+ if not filter_names or view in filter_names:
+ options[(schema, view)] = default()
- return options
+ return options.items()
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
- kw arguments can be:
+ data = self.get_multi_columns(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
+ )
+ return self._value_or_raise(data, table_name, schema)
+
+ def _run_batches(
+ self, connection, query, dblink, returns_long, mappings, all_objects
+ ):
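+ # run the dictionary query over the name list in chunks; 500 per
+ # statement presumably keeps the IN list well below Oracle's
+ # 1000-expression limit (ORA-01795) while bounding statement size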
+ each_batch = 500
+ batches = list(all_objects)
+ while batches:
+ batch = batches[0:each_batch]
+ batches[0:each_batch] = []
+
+ result = self._execute_reflection(
+ connection,
+ query,
+ dblink,
+ returns_long=returns_long,
+ params={"all_objects": batch},
+ )
+ if mappings:
+ yield from result.mappings()
+ else:
+ yield from result
+
+ @lru_cache()
+ def _column_query(self, owner):
+ all_cols = dictionary.all_tab_cols
+ all_comments = dictionary.all_col_comments
+ all_ids = dictionary.all_tab_identity_cols
- oracle_resolve_synonyms
+ if self.server_version_info >= (12,):
+ add_cols = (
+ all_cols.c.default_on_null,
+ sql.case(
+ (all_ids.c.table_name.is_(None), sql.null()),
+ else_=all_ids.c.generation_type
+ + ","
+ + all_ids.c.identity_options,
+ ).label("identity_options"),
+ )
+ join_identity_cols = True
+ else:
+ add_cols = (
+ sql.null().label("default_on_null"),
+ sql.null().label("identity_options"),
+ )
+ join_identity_cols = False
+
+ # NOTE: Oracle cannot create tables/views without columns, and
+ # a table cannot have all of its columns hidden:
+ # ORA-54039: table must have at least one column that is not invisible
+ # all_tab_cols returns data for tables/views/mat-views.
+ # all_tab_cols does not return recycled tables
+
+ query = (
+ select(
+ all_cols.c.table_name,
+ all_cols.c.column_name,
+ all_cols.c.data_type,
+ all_cols.c.char_length,
+ all_cols.c.data_precision,
+ all_cols.c.data_scale,
+ all_cols.c.nullable,
+ all_cols.c.data_default,
+ all_comments.c.comments,
+ all_cols.c.virtual_column,
+ *add_cols,
+ ).select_from(all_cols)
+ # NOTE: all_col_comments has a row for each column even if no
+ # comment is present, so a join could be performed, but there
+ # seems to be no difference compared to an outer join
+ .outerjoin(
+ all_comments,
+ and_(
+ all_cols.c.table_name == all_comments.c.table_name,
+ all_cols.c.column_name == all_comments.c.column_name,
+ all_cols.c.owner == all_comments.c.owner,
+ ),
+ )
+ )
+ if join_identity_cols:
+ query = query.outerjoin(
+ all_ids,
+ and_(
+ all_cols.c.table_name == all_ids.c.table_name,
+ all_cols.c.column_name == all_ids.c.column_name,
+ all_cols.c.owner == all_ids.c.owner,
+ ),
+ )
- dblink
+ query = query.where(
+ all_cols.c.table_name.in_(bindparam("all_objects")),
+ all_cols.c.hidden_column == "NO",
+ all_cols.c.owner == owner,
+ ).order_by(all_cols.c.table_name, all_cols.c.column_id)
+ return query
+ @_handle_synonyms_decorator
+ def get_multi_columns(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ query = self._column_query(owner)
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
+ if (
+ filter_names
+ and kind is ObjectKind.ANY
+ and scope is ObjectScope.ANY
+ ):
+ all_objects = [self.denormalize_name(n) for n in filter_names]
+ else:
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
+ )
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ columns = defaultdict(list)
+
+ # all_tab_cols.data_default is LONG
+ result = self._run_batches(
connection,
- table_name,
- schema,
- resolve_synonyms,
+ query,
dblink,
- info_cache=info_cache,
+ returns_long=True,
+ mappings=True,
+ all_objects=all_objects,
)
- columns = []
- if self._supports_char_length:
- char_length_col = "char_length"
- else:
- char_length_col = "data_length"
- if self.server_version_info >= (12,):
- identity_cols = """\
- col.default_on_null,
- (
- SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS
- FROM ALL_TAB_IDENTITY_COLS%(dblink)s id
- WHERE col.table_name = id.table_name
- AND col.column_name = id.column_name
- AND col.owner = id.owner
- ) AS identity_options""" % {
- "dblink": dblink
- }
- else:
- identity_cols = "NULL as default_on_null, NULL as identity_options"
-
- params = {"table_name": table_name}
-
- text = """
- SELECT
- col.column_name,
- col.data_type,
- col.%(char_length_col)s,
- col.data_precision,
- col.data_scale,
- col.nullable,
- col.data_default,
- com.comments,
- col.virtual_column,
- %(identity_cols)s
- FROM all_tab_cols%(dblink)s col
- LEFT JOIN all_col_comments%(dblink)s com
- ON col.table_name = com.table_name
- AND col.column_name = com.column_name
- AND col.owner = com.owner
- WHERE col.table_name = CAST(:table_name AS VARCHAR2(128))
- AND col.hidden_column = 'NO'
- """
- if schema is not None:
- params["owner"] = schema
- text += " AND col.owner = :owner "
- text += " ORDER BY col.column_id"
- text = text % {
- "dblink": dblink,
- "char_length_col": char_length_col,
- "identity_cols": identity_cols,
- }
-
- c = connection.execute(sql.text(text), params)
-
- for row in c:
- colname = self.normalize_name(row[0])
- orig_colname = row[0]
- coltype = row[1]
- length = row[2]
- precision = row[3]
- scale = row[4]
- nullable = row[5] == "Y"
- default = row[6]
- comment = row[7]
- generated = row[8]
- default_on_nul = row[9]
- identity_options = row[10]
+ for row_dict in result:
+ table_name = self.normalize_name(row_dict["table_name"])
+ orig_colname = row_dict["column_name"]
+ colname = self.normalize_name(orig_colname)
+ coltype = row_dict["data_type"]
+ precision = row_dict["data_precision"]
if coltype == "NUMBER":
+ scale = row_dict["data_scale"]
if precision is None and scale == 0:
coltype = INTEGER()
else:
coltype = FLOAT(binary_precision=precision)
elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"):
- coltype = self.ischema_names.get(coltype)(length)
+ coltype = self.ischema_names.get(coltype)(
+ row_dict["char_length"]
+ )
elif "WITH TIME ZONE" in coltype:
coltype = TIMESTAMP(timezone=True)
else:
)
coltype = sqltypes.NULLTYPE
- if generated == "YES":
+ default = row_dict["data_default"]
+ if row_dict["virtual_column"] == "YES":
computed = dict(sqltext=default)
default = None
else:
computed = None
+ identity_options = row_dict["identity_options"]
if identity_options is not None:
identity = self._parse_identity_options(
- identity_options, default_on_nul
+ identity_options, row_dict["default_on_null"]
)
default = None
else:
cdict = {
"name": colname,
"type": coltype,
- "nullable": nullable,
+ "nullable": row_dict["nullable"] == "Y",
"default": default,
- "autoincrement": "auto",
- "comment": comment,
+ "comment": row_dict["comments"],
}
if orig_colname.lower() == orig_colname:
cdict["quote"] = True
if identity is not None:
cdict["identity"] = identity
- columns.append(cdict)
- return columns
+ columns[(schema, table_name)].append(cdict)
- def _parse_identity_options(self, identity_options, default_on_nul):
+ # NOTE: default not needed since all tables have columns
+ # default = ReflectionDefaults.columns
+ # return (
+ # (key, value if value else default())
+ # for key, value in columns.items()
+ # )
+ return columns.items()
+
+ def _parse_identity_options(self, identity_options, default_on_null):
# identity_options is a string that starts with 'ALWAYS,' or
# 'BY DEFAULT,' and continues with
# START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1,
parts = [p.strip() for p in identity_options.split(",")]
identity = {
"always": parts[0] == "ALWAYS",
- "on_null": default_on_nul == "YES",
+ "on_null": default_on_null == "YES",
}
for part in parts[1:]:
return identity
@reflection.cache
- def get_table_comment(
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_table_comment(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
+ )
+ return self._value_or_raise(data, table_name, schema)
+
+ @lru_cache()
+ def _comment_query(self, owner, scope, kind, has_filter_names):
+ # NOTE: all_tab_comments / all_mview_comments have a row for all
+ # objects even if they don't have comments
+ queries = []
+ if ObjectKind.TABLE in kind or ObjectKind.VIEW in kind:
+ # all_tab_comments also returns plain views
+ tbl_view = select(
+ dictionary.all_tab_comments.c.table_name,
+ dictionary.all_tab_comments.c.comments,
+ ).where(
+ dictionary.all_tab_comments.c.owner == owner,
+ dictionary.all_tab_comments.c.table_name.not_like("BIN$%"),
+ )
+ if ObjectKind.VIEW not in kind:
+ tbl_view = tbl_view.where(
+ dictionary.all_tab_comments.c.table_type == "TABLE"
+ )
+ elif ObjectKind.TABLE not in kind:
+ tbl_view = tbl_view.where(
+ dictionary.all_tab_comments.c.table_type == "VIEW"
+ )
+ queries.append(tbl_view)
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ mat_view = select(
+ dictionary.all_mview_comments.c.mview_name.label("table_name"),
+ dictionary.all_mview_comments.c.comments,
+ ).where(
+ dictionary.all_mview_comments.c.owner == owner,
+ dictionary.all_mview_comments.c.mview_name.not_like("BIN$%"),
+ )
+ queries.append(mat_view)
+ if len(queries) == 1:
+ query = queries[0]
+ else:
+ union = sql.union_all(*queries).subquery("tables_and_views")
+ query = select(union.c.table_name, union.c.comments)
+
+ name_col = query.selected_columns.table_name
+
+ if scope in (ObjectScope.DEFAULT, ObjectScope.TEMPORARY):
+ temp = "Y" if scope is ObjectScope.TEMPORARY else "N"
+ # need distinct since materialized views are also listed
+ # as tables in all_objects
+ query = query.distinct().join(
+ dictionary.all_objects,
+ and_(
+ dictionary.all_objects.c.owner == owner,
+ dictionary.all_objects.c.object_name == name_col,
+ dictionary.all_objects.c.temporary == temp,
+ ),
+ )
+ if has_filter_names:
+ query = query.where(name_col.in_(bindparam("filter_names")))
+ return query
+
+ @_handle_synonyms_decorator
+ def get_multi_table_comment(
self,
connection,
- table_name,
- schema=None,
- resolve_synonyms=False,
- dblink="",
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
**kw,
):
-
- info_cache = kw.get("info_cache")
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
- connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
- )
-
- if not schema:
- schema = self.default_schema_name
-
- COMMENT_SQL = """
- SELECT comments
- FROM all_tab_comments
- WHERE table_name = CAST(:table_name AS VARCHAR(128))
- AND owner = CAST(:schema_name AS VARCHAR(128))
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._comment_query(owner, scope, kind, has_filter_names)
- c = connection.execute(
- sql.text(COMMENT_SQL),
- dict(table_name=table_name, schema_name=schema),
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False, params=params
+ )
+ default = ReflectionDefaults.table_comment
+ # materialized views by default seem to have a comment like
+ # "snapshot table for snapshot owner.mat_view_name"
+ ignore_mat_view = "snapshot table for snapshot "
+ return (
+ (
+ (schema, self.normalize_name(table)),
+ {"text": comment}
+ if comment is not None
+ and not comment.startswith(ignore_mat_view)
+ else default(),
+ )
+ for table, comment in result
)
- return {"text": c.scalar()}
@reflection.cache
- def get_indexes(
- self,
- connection,
- table_name,
- schema=None,
- resolve_synonyms=False,
- dblink="",
- **kw,
- ):
-
- info_cache = kw.get("info_cache")
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_indexes(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
- indexes = []
-
- params = {"table_name": table_name}
- text = (
- "SELECT a.index_name, a.column_name, "
- "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "
- "\nFROM ALL_IND_COLUMNS%(dblink)s a, "
- "\nALL_INDEXES%(dblink)s b "
- "\nWHERE "
- "\na.index_name = b.index_name "
- "\nAND a.table_owner = b.table_owner "
- "\nAND a.table_name = b.table_name "
- "\nAND a.table_name = CAST(:table_name AS VARCHAR(128))"
+ return self._value_or_raise(data, table_name, schema)
+
+ @lru_cache()
+ def _index_query(self, owner):
+ return (
+ select(
+ dictionary.all_ind_columns.c.table_name,
+ dictionary.all_ind_columns.c.index_name,
+ dictionary.all_ind_columns.c.column_name,
+ dictionary.all_indexes.c.index_type,
+ dictionary.all_indexes.c.uniqueness,
+ dictionary.all_indexes.c.compression,
+ dictionary.all_indexes.c.prefix_length,
+ )
+ .select_from(dictionary.all_ind_columns)
+ .join(
+ dictionary.all_indexes,
+ sql.and_(
+ dictionary.all_ind_columns.c.index_name
+ == dictionary.all_indexes.c.index_name,
+ dictionary.all_ind_columns.c.table_owner
+ == dictionary.all_indexes.c.table_owner,
+ # NOTE: this condition on table_name is not required
+ # but it improves the query performance noticeably
+ dictionary.all_ind_columns.c.table_name
+ == dictionary.all_indexes.c.table_name,
+ ),
+ )
+ .where(
+ dictionary.all_ind_columns.c.table_owner == owner,
+ dictionary.all_ind_columns.c.table_name.in_(
+ bindparam("all_objects")
+ ),
+ )
+ .order_by(
+ dictionary.all_ind_columns.c.index_name,
+ dictionary.all_ind_columns.c.column_position,
+ )
)
- if schema is not None:
- params["schema"] = schema
- text += "AND a.table_owner = :schema "
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("dblink", InternalTraversal.dp_string),
+ ("all_objects", InternalTraversal.dp_string_list),
+ )
+ def _get_indexes_rows(self, connection, schema, dblink, all_objects, **kw):
+ owner = self.denormalize_name(schema or self.default_schema_name)
- text += "ORDER BY a.index_name, a.column_position"
+ query = self._index_query(owner)
- text = text % {"dblink": dblink}
+ pks = {
+ row_dict["constraint_name"]
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ )
+ if row_dict["constraint_type"] == "P"
+ }
- q = sql.text(text)
- rp = connection.execute(q, params)
- indexes = []
- last_index_name = None
- pk_constraint = self.get_pk_constraint(
+ result = self._run_batches(
connection,
- table_name,
- schema,
- resolve_synonyms=resolve_synonyms,
- dblink=dblink,
- info_cache=kw.get("info_cache"),
+ query,
+ dblink,
+ returns_long=False,
+ mappings=True,
+ all_objects=all_objects,
)
- uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
- enabled = dict(DISABLED=False, ENABLED=True)
+ return [
+ row_dict
+ for row_dict in result
+ if row_dict["index_name"] not in pks
+ ]
- oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE)
+ @_handle_synonyms_decorator
+ def get_multi_indexes(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
+ )
- index = None
- for rset in rp:
- index_name_normalized = self.normalize_name(rset.index_name)
+ uniqueness = {"NONUNIQUE": False, "UNIQUE": True}
+ enabled = {"DISABLED": False, "ENABLED": True}
+ is_bitmap = {"BITMAP", "FUNCTION-BASED BITMAP"}
- # skip primary key index. This is refined as of
- # [ticket:5421]. Note that ALL_INDEXES.GENERATED will by "Y"
- # if the name of this index was generated by Oracle, however
- # if a named primary key constraint was created then this flag
- # is false.
- if (
- pk_constraint
- and index_name_normalized == pk_constraint["name"]
- ):
- continue
+ oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE)
- if rset.index_name != last_index_name:
- index = dict(
- name=index_name_normalized,
- column_names=[],
- dialect_options={},
- )
- indexes.append(index)
- index["unique"] = uniqueness.get(rset.uniqueness, False)
+ indexes = defaultdict(dict)
+
+ for row_dict in self._get_indexes_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ index_name = self.normalize_name(row_dict["index_name"])
+ table_name = self.normalize_name(row_dict["table_name"])
+ table_indexes = indexes[(schema, table_name)]
+
+ if index_name not in table_indexes:
+ table_indexes[index_name] = index_dict = {
+ "name": index_name,
+ "column_names": [],
+ "dialect_options": {},
+ "unique": uniqueness.get(row_dict["uniqueness"], False),
+ }
+ do = index_dict["dialect_options"]
+ if row_dict["index_type"] in is_bitmap:
+ do["oracle_bitmap"] = True
+ if enabled.get(row_dict["compression"], False):
+ do["oracle_compress"] = row_dict["prefix_length"]
- if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"):
- index["dialect_options"]["oracle_bitmap"] = True
- if enabled.get(rset.compression, False):
- index["dialect_options"][
- "oracle_compress"
- ] = rset.prefix_length
+ else:
+ index_dict = table_indexes[index_name]
# filter out Oracle SYS_NC names. could also do an outer join
- # to the all_tab_columns table and check for real col names there.
- if not oracle_sys_col.match(rset.column_name):
- index["column_names"].append(
- self.normalize_name(rset.column_name)
+ # to the all_tab_columns table and check for real col names
+ # there.
+ if not oracle_sys_col.match(row_dict["column_name"]):
+ index_dict["column_names"].append(
+ self.normalize_name(row_dict["column_name"])
)
- last_index_name = rset.index_name
- return indexes
+ default = ReflectionDefaults.indexes
- @reflection.cache
- def _get_constraint_data(
- self, connection, table_name, schema=None, dblink="", **kw
- ):
-
- params = {"table_name": table_name}
-
- text = (
- "SELECT"
- "\nac.constraint_name," # 0
- "\nac.constraint_type," # 1
- "\nloc.column_name AS local_column," # 2
- "\nrem.table_name AS remote_table," # 3
- "\nrem.column_name AS remote_column," # 4
- "\nrem.owner AS remote_owner," # 5
- "\nloc.position as loc_pos," # 6
- "\nrem.position as rem_pos," # 7
- "\nac.search_condition," # 8
- "\nac.delete_rule" # 9
- "\nFROM all_constraints%(dblink)s ac,"
- "\nall_cons_columns%(dblink)s loc,"
- "\nall_cons_columns%(dblink)s rem"
- "\nWHERE ac.table_name = CAST(:table_name AS VARCHAR2(128))"
- "\nAND ac.constraint_type IN ('R','P', 'U', 'C')"
- )
-
- if schema is not None:
- params["owner"] = schema
- text += "\nAND ac.owner = CAST(:owner AS VARCHAR2(128))"
-
- text += (
- "\nAND ac.owner = loc.owner"
- "\nAND ac.constraint_name = loc.constraint_name"
- "\nAND ac.r_owner = rem.owner(+)"
- "\nAND ac.r_constraint_name = rem.constraint_name(+)"
- "\nAND (rem.position IS NULL or loc.position=rem.position)"
- "\nORDER BY ac.constraint_name, loc.position"
+ return (
+ (key, list(indexes[key].values()) if key in indexes else default())
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
)
- text = text % {"dblink": dblink}
- rp = connection.execute(sql.text(text), params)
- constraint_data = rp.fetchall()
- return constraint_data
-
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_pk_constraint(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
- pkeys = []
- constraint_name = None
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ return self._value_or_raise(data, table_name, schema)
+
+ @lru_cache()
+ def _constraint_query(self, owner):
+ local = dictionary.all_cons_columns.alias("local")
+ remote = dictionary.all_cons_columns.alias("remote")
+ return (
+ select(
+ dictionary.all_constraints.c.table_name,
+ dictionary.all_constraints.c.constraint_type,
+ dictionary.all_constraints.c.constraint_name,
+ local.c.column_name.label("local_column"),
+ remote.c.table_name.label("remote_table"),
+ remote.c.column_name.label("remote_column"),
+ remote.c.owner.label("remote_owner"),
+ dictionary.all_constraints.c.search_condition,
+ dictionary.all_constraints.c.delete_rule,
+ )
+ .select_from(dictionary.all_constraints)
+ .join(
+ local,
+ and_(
+ local.c.owner == dictionary.all_constraints.c.owner,
+ dictionary.all_constraints.c.constraint_name
+ == local.c.constraint_name,
+ ),
+ )
+ .outerjoin(
+ remote,
+ and_(
+ dictionary.all_constraints.c.r_owner == remote.c.owner,
+ dictionary.all_constraints.c.r_constraint_name
+ == remote.c.constraint_name,
+ or_(
+ remote.c.position.is_(sql.null()),
+ local.c.position == remote.c.position,
+ ),
+ ),
+ )
+ .where(
+ dictionary.all_constraints.c.owner == owner,
+ dictionary.all_constraints.c.table_name.in_(
+ bindparam("all_objects")
+ ),
+ dictionary.all_constraints.c.constraint_type.in_(
+ ("R", "P", "U", "C")
+ ),
+ )
+ .order_by(
+ dictionary.all_constraints.c.constraint_name, local.c.position
+ )
)
- for row in constraint_data:
- (
- cons_name,
- cons_type,
- local_column,
- remote_table,
- remote_column,
- remote_owner,
- ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
- if cons_type == "P":
- if constraint_name is None:
- constraint_name = self.normalize_name(cons_name)
- pkeys.append(local_column)
- return {"constrained_columns": pkeys, "name": constraint_name}
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("dblink", InternalTraversal.dp_string),
+ ("all_objects", InternalTraversal.dp_string_list),
+ )
+ def _get_all_constraint_rows(
+ self, connection, schema, dblink, all_objects, **kw
+ ):
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ query = self._constraint_query(owner)
- @reflection.cache
- def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ # since the result is cached a list must be created
+ values = list(
+ self._run_batches(
+ connection,
+ query,
+ dblink,
+ returns_long=False,
+ mappings=True,
+ all_objects=all_objects,
+ )
+ )
+ return values
+
+ @_handle_synonyms_decorator
+ def get_multi_pk_constraint(
+ self,
+ connection,
+ *,
+ scope,
+ schema,
+ filter_names,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
+ )
- kw arguments can be:
+ primary_keys = defaultdict(dict)
+ default = ReflectionDefaults.pk_constraint
- oracle_resolve_synonyms
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "P":
+ continue
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name = self.normalize_name(row_dict["constraint_name"])
+ column_name = self.normalize_name(row_dict["local_column"])
+
+ table_pk = primary_keys[(schema, table_name)]
+ if not table_pk:
+ table_pk["name"] = constraint_name
+ table_pk["constrained_columns"] = [column_name]
+ else:
+ table_pk["constrained_columns"].append(column_name)
- dblink
+ return (
+ (key, primary_keys[key] if key in primary_keys else default())
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
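# Hedged sketch (not part of this change) of the shape produced by the
# get_multi_* generators above: each element is a ((schema, table_name),
# reflected_data) pair, with the ReflectionDefaults callable supplying the
# value for objects that had nothing to reflect. The table and constraint
# names below are hypothetical.
def _example_collect_pk_constraints(dialect, connection, kind, scope):
    result = dialect.get_multi_pk_constraint(
        connection,
        schema=None,
        filter_names=["my_table"],
        kind=kind,
        scope=scope,
    )
    # e.g. {(None, "my_table"):
    #       {"name": "my_table_pk", "constrained_columns": ["id"]}}
    return dict(result)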
+ @reflection.cache
+ def get_foreign_keys(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
- requested_schema = schema # to check later on
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ data = self.get_multi_foreign_keys(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ @_handle_synonyms_decorator
+ def get_multi_foreign_keys(
+ self,
+ connection,
+ *,
+ scope,
+ schema,
+ filter_names,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
)
- def fkey_rec():
- return {
- "name": None,
- "constrained_columns": [],
- "referred_schema": None,
- "referred_table": None,
- "referred_columns": [],
- "options": {},
- }
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- fkeys = util.defaultdict(fkey_rec)
+ owner = self.denormalize_name(schema or self.default_schema_name)
- for row in constraint_data:
- (
- cons_name,
- cons_type,
- local_column,
- remote_table,
- remote_column,
- remote_owner,
- ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
-
- cons_name = self.normalize_name(cons_name)
-
- if cons_type == "R":
- if remote_table is None:
- # ticket 363
- util.warn(
- (
- "Got 'None' querying 'table_name' from "
- "all_cons_columns%(dblink)s - does the user have "
- "proper rights to the table?"
- )
- % {"dblink": dblink}
- )
- continue
+ all_remote_owners = set()
+ fkeys = defaultdict(dict)
+
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "R":
+ continue
+
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name = self.normalize_name(row_dict["constraint_name"])
+ table_fkey = fkeys[(schema, table_name)]
+
+ assert constraint_name is not None
- rec = fkeys[cons_name]
- rec["name"] = cons_name
- local_cols, remote_cols = (
- rec["constrained_columns"],
- rec["referred_columns"],
+ local_column = self.normalize_name(row_dict["local_column"])
+ remote_table = self.normalize_name(row_dict["remote_table"])
+ remote_column = self.normalize_name(row_dict["remote_column"])
+ remote_owner_orig = row_dict["remote_owner"]
+ remote_owner = self.normalize_name(remote_owner_orig)
+ if remote_owner_orig is not None:
+ all_remote_owners.add(remote_owner_orig)
+
+ if remote_table is None:
+ # ticket 363
+ if dblink and not dblink.startswith("@"):
+ dblink = f"@{dblink}"
+ util.warn(
+ "Got 'None' querying 'table_name' from "
+ f"all_cons_columns{dblink or ''} - does the user have "
+ "proper rights to the table?"
)
+ continue
- if not rec["referred_table"]:
- if resolve_synonyms:
- (
- ref_remote_name,
- ref_remote_owner,
- ref_dblink,
- ref_synonym,
- ) = self._resolve_synonym(
- connection,
- desired_owner=self.denormalize_name(remote_owner),
- desired_table=self.denormalize_name(remote_table),
- )
- if ref_synonym:
- remote_table = self.normalize_name(ref_synonym)
- remote_owner = self.normalize_name(
- ref_remote_owner
- )
+ if constraint_name not in table_fkey:
+ table_fkey[constraint_name] = fkey = {
+ "name": constraint_name,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": remote_table,
+ "referred_columns": [],
+ "options": {},
+ }
- rec["referred_table"] = remote_table
+ if resolve_synonyms:
+ # will be removed below
+ fkey["_ref_schema"] = remote_owner
- if (
- requested_schema is not None
- or self.denormalize_name(remote_owner) != schema
- ):
- rec["referred_schema"] = remote_owner
+ if schema is not None or remote_owner_orig != owner:
+ fkey["referred_schema"] = remote_owner
+
+ delete_rule = row_dict["delete_rule"]
+ if delete_rule != "NO ACTION":
+ fkey["options"]["ondelete"] = delete_rule
+
+ else:
+ fkey = table_fkey[constraint_name]
+
+ fkey["constrained_columns"].append(local_column)
+ fkey["referred_columns"].append(remote_column)
+
+ if resolve_synonyms and all_remote_owners:
+ query = select(
+ dictionary.all_synonyms.c.owner,
+ dictionary.all_synonyms.c.table_name,
+ dictionary.all_synonyms.c.table_owner,
+ dictionary.all_synonyms.c.synonym_name,
+ ).where(dictionary.all_synonyms.c.owner.in_(all_remote_owners))
+
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).mappings()
- if row[9] != "NO ACTION":
- rec["options"]["ondelete"] = row[9]
+ remote_owners_lut = {}
+ for row in result:
+ synonym_owner = self.normalize_name(row["owner"])
+ table_name = self.normalize_name(row["table_name"])
- local_cols.append(local_column)
- remote_cols.append(remote_column)
+ remote_owners_lut[(synonym_owner, table_name)] = (
+ row["table_owner"],
+ row["synonym_name"],
+ )
+
+ empty = (None, None)
+ for table_fkeys in fkeys.values():
+ for table_fkey in table_fkeys.values():
+ key = (
+ table_fkey.pop("_ref_schema"),
+ table_fkey["referred_table"],
+ )
+ remote_owner, syn_name = remote_owners_lut.get(key, empty)
+ if syn_name:
+ sn = self.normalize_name(syn_name)
+ table_fkey["referred_table"] = sn
+ if schema is not None or remote_owner != owner:
+ ro = self.normalize_name(remote_owner)
+ table_fkey["referred_schema"] = ro
+ else:
+ table_fkey["referred_schema"] = None
+ default = ReflectionDefaults.foreign_keys
- return list(fkeys.values())
+ return (
+ (key, list(fkeys[key].values()) if key in fkeys else default())
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
@reflection.cache
def get_unique_constraints(
self, connection, table_name, schema=None, **kw
):
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_unique_constraints(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ @_handle_synonyms_decorator
+ def get_multi_unique_constraints(
+ self,
+ connection,
+ *,
+ scope,
+ schema,
+ filter_names,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
)
- unique_keys = filter(lambda x: x[1] == "U", constraint_data)
- uniques_group = groupby(unique_keys, lambda x: x[0])
+ unique_cons = defaultdict(dict)
index_names = {
- ix["name"]
- for ix in self.get_indexes(connection, table_name, schema=schema)
+ row_dict["index_name"]
+ for row_dict in self._get_indexes_rows(
+ connection, schema, dblink, all_objects, **kw
+ )
}
- return [
- {
- "name": name,
- "column_names": cols,
- "duplicates_index": name if name in index_names else None,
- }
- for name, cols in [
- [
- self.normalize_name(i[0]),
- [self.normalize_name(x[2]) for x in i[1]],
- ]
- for i in uniques_group
- ]
- ]
+
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "U":
+ continue
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name_orig = row_dict["constraint_name"]
+ constraint_name = self.normalize_name(constraint_name_orig)
+ column_name = self.normalize_name(row_dict["local_column"])
+ table_uc = unique_cons[(schema, table_name)]
+
+ assert constraint_name is not None
+
+ if constraint_name not in table_uc:
+ table_uc[constraint_name] = uc = {
+ "name": constraint_name,
+ "column_names": [],
+ "duplicates_index": constraint_name
+ if constraint_name_orig in index_names
+ else None,
+ }
+ else:
+ uc = table_uc[constraint_name]
+
+ uc["column_names"].append(column_name)
+
+ default = ReflectionDefaults.unique_constraints
+
+ return (
+ (
+ key,
+ list(unique_cons[key].values())
+ if key in unique_cons
+ else default(),
+ )
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
@reflection.cache
def get_view_definition(
self,
connection,
view_name,
schema=None,
- resolve_synonyms=False,
- dblink="",
+ dblink=None,
**kw,
):
- info_cache = kw.get("info_cache")
- (view_name, schema, dblink, synonym) = self._prepare_reflection_args(
- connection,
- view_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ if kw.get("oracle_resolve_synonyms", False):
+ synonyms = self._get_synonyms(
+ connection, schema, filter_names=[view_name], dblink=dblink
+ )
+ if synonyms:
+ assert len(synonyms) == 1
+ row_dict = synonyms[0]
+ dblink = self.normalize_name(row_dict["db_link"])
+ schema = row_dict["table_owner"]
+ view_name = row_dict["table_name"]
+
+ name = self.denormalize_name(view_name)
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ query = (
+ select(dictionary.all_views.c.text)
+ .where(
+ dictionary.all_views.c.view_name == name,
+ dictionary.all_views.c.owner == owner,
+ )
+ .union_all(
+ select(dictionary.all_mviews.c.query).where(
+ dictionary.all_mviews.c.mview_name == name,
+ dictionary.all_mviews.c.owner == owner,
+ )
+ )
)
- params = {"view_name": view_name}
- text = "SELECT text FROM all_views WHERE view_name=:view_name"
-
- if schema is not None:
- text += " AND owner = :schema"
- params["schema"] = schema
-
- rp = connection.execute(sql.text(text), params).scalar()
- if rp:
- return rp
+ rp = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalar()
+ if rp is None:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
else:
- return None
+ return rp
@reflection.cache
def get_check_constraints(
self, connection, table_name, schema=None, include_all=False, **kw
):
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_check_constraints(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ include_all=include_all,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ @_handle_synonyms_decorator
+ def get_multi_check_constraints(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ dblink=None,
+ scope,
+ kind,
+ include_all=False,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
)
- check_constraints = filter(lambda x: x[1] == "C", constraint_data)
+ not_null = re.compile(r"..+?. IS NOT NULL$")
- return [
- {"name": self.normalize_name(cons[0]), "sqltext": cons[8]}
- for cons in check_constraints
- if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8])
- ]
+ check_constraints = defaultdict(list)
+
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "C":
+ continue
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name = self.normalize_name(row_dict["constraint_name"])
+ search_condition = row_dict["search_condition"]
+
+ table_checks = check_constraints[(schema, table_name)]
+ if constraint_name is not None and (
+ include_all or not not_null.match(search_condition)
+ ):
+ table_checks.append(
+ {"name": constraint_name, "sqltext": search_condition}
+ )
+
+ default = ReflectionDefaults.check_constraints
+
+ return (
+ (
+ key,
+ check_constraints[key]
+ if key in check_constraints
+ else default(),
+ )
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
+
+ def _list_dblinks(self, connection, dblink=None):
+ query = select(dictionary.all_db_links.c.db_link)
+ links = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(link) for link in links]
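# Hedged usage sketch for the reflection keywords documented in the
# docstrings above; the table, schema and db link names are placeholders,
# not values taken from this change. The extra keywords are forwarded by
# the Inspector to the dialect-level methods.
def _example_reflect_via_dblink(engine):
    from sqlalchemy import inspect

    insp = inspect(engine)
    return insp.get_indexes(
        "some_table",
        schema="SCOTT",
        dblink="some_db_link",
        oracle_resolve_synonyms=True,
    )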
class _OuterJoinColumn(sql.ClauseElement):
from .base import OracleCompiler
from .base import OracleDialect
from .base import OracleExecutionContext
+from .types import _OracleDateLiteralRender
from ... import exc
from ... import util
from ...engine import cursor as _cursor
return process
-class _CXOracleTIMESTAMP(oracle._OracleDateLiteralRender, sqltypes.TIMESTAMP):
+class _CXOracleTIMESTAMP(_OracleDateLiteralRender, sqltypes.TIMESTAMP):
def literal_processor(self, dialect):
return self._literal_processor_datetime(dialect)
return None
def pre_exec(self):
+ super().pre_exec()
if not getattr(self.compiled, "_oracle_cx_sql_compiler", False):
return
--- /dev/null
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from .types import DATE
+from .types import LONG
+from .types import NUMBER
+from .types import RAW
+from .types import VARCHAR2
+from ... import Column
+from ... import MetaData
+from ... import Table
+from ... import table
+from ...sql.sqltypes import CHAR
+
+# constants
+DB_LINK_PLACEHOLDER = "__$sa_dblink$__"
+# tables
+dual = table("dual")
+dictionary_meta = MetaData()
+
+# NOTE: all the tables in dictionary_meta are aliased, because Oracle does
+# not like using the full table@dblink name for every column in a query, and
+# complains with ORA-00960: ambiguous column naming in select list
+all_tables = Table(
+ "all_tables" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("tablespace_name", VARCHAR2(30)),
+ Column("cluster_name", VARCHAR2(128)),
+ Column("iot_name", VARCHAR2(128)),
+ Column("status", VARCHAR2(8)),
+ Column("pct_free", NUMBER),
+ Column("pct_used", NUMBER),
+ Column("ini_trans", NUMBER),
+ Column("max_trans", NUMBER),
+ Column("initial_extent", NUMBER),
+ Column("next_extent", NUMBER),
+ Column("min_extents", NUMBER),
+ Column("max_extents", NUMBER),
+ Column("pct_increase", NUMBER),
+ Column("freelists", NUMBER),
+ Column("freelist_groups", NUMBER),
+ Column("logging", VARCHAR2(3)),
+ Column("backed_up", VARCHAR2(1)),
+ Column("num_rows", NUMBER),
+ Column("blocks", NUMBER),
+ Column("empty_blocks", NUMBER),
+ Column("avg_space", NUMBER),
+ Column("chain_cnt", NUMBER),
+ Column("avg_row_len", NUMBER),
+ Column("avg_space_freelist_blocks", NUMBER),
+ Column("num_freelist_blocks", NUMBER),
+ Column("degree", VARCHAR2(10)),
+ Column("instances", VARCHAR2(10)),
+ Column("cache", VARCHAR2(5)),
+ Column("table_lock", VARCHAR2(8)),
+ Column("sample_size", NUMBER),
+ Column("last_analyzed", DATE),
+ Column("partitioned", VARCHAR2(3)),
+ Column("iot_type", VARCHAR2(12)),
+ Column("temporary", VARCHAR2(1)),
+ Column("secondary", VARCHAR2(1)),
+ Column("nested", VARCHAR2(3)),
+ Column("buffer_pool", VARCHAR2(7)),
+ Column("flash_cache", VARCHAR2(7)),
+ Column("cell_flash_cache", VARCHAR2(7)),
+ Column("row_movement", VARCHAR2(8)),
+ Column("global_stats", VARCHAR2(3)),
+ Column("user_stats", VARCHAR2(3)),
+ Column("duration", VARCHAR2(15)),
+ Column("skip_corrupt", VARCHAR2(8)),
+ Column("monitoring", VARCHAR2(3)),
+ Column("cluster_owner", VARCHAR2(128)),
+ Column("dependencies", VARCHAR2(8)),
+ Column("compression", VARCHAR2(8)),
+ Column("compress_for", VARCHAR2(30)),
+ Column("dropped", VARCHAR2(3)),
+ Column("read_only", VARCHAR2(3)),
+ Column("segment_created", VARCHAR2(3)),
+ Column("result_cache", VARCHAR2(7)),
+ Column("clustering", VARCHAR2(3)),
+ Column("activity_tracking", VARCHAR2(23)),
+ Column("dml_timestamp", VARCHAR2(25)),
+ Column("has_identity", VARCHAR2(3)),
+ Column("container_data", VARCHAR2(3)),
+ Column("inmemory", VARCHAR2(8)),
+ Column("inmemory_priority", VARCHAR2(8)),
+ Column("inmemory_distribute", VARCHAR2(15)),
+ Column("inmemory_compression", VARCHAR2(17)),
+ Column("inmemory_duplicate", VARCHAR2(13)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("duplicated", VARCHAR2(1)),
+ Column("sharded", VARCHAR2(1)),
+ Column("externally_sharded", VARCHAR2(1)),
+ Column("externally_duplicated", VARCHAR2(1)),
+ Column("external", VARCHAR2(3)),
+ Column("hybrid", VARCHAR2(3)),
+ Column("cellmemory", VARCHAR2(24)),
+ Column("containers_default", VARCHAR2(3)),
+ Column("container_map", VARCHAR2(3)),
+ Column("extended_data_link", VARCHAR2(3)),
+ Column("extended_data_link_map", VARCHAR2(3)),
+ Column("inmemory_service", VARCHAR2(12)),
+ Column("inmemory_service_name", VARCHAR2(1000)),
+ Column("container_map_object", VARCHAR2(3)),
+ Column("memoptimize_read", VARCHAR2(8)),
+ Column("memoptimize_write", VARCHAR2(8)),
+ Column("has_sensitive_column", VARCHAR2(3)),
+ Column("admit_null", VARCHAR2(3)),
+ Column("data_link_dml_enabled", VARCHAR2(3)),
+ Column("logical_replication", VARCHAR2(8)),
+).alias("a_tables")
+
+all_views = Table(
+ "all_views" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("view_name", VARCHAR2(128), nullable=False),
+ Column("text_length", NUMBER),
+ Column("text", LONG),
+ Column("text_vc", VARCHAR2(4000)),
+ Column("type_text_length", NUMBER),
+ Column("type_text", VARCHAR2(4000)),
+ Column("oid_text_length", NUMBER),
+ Column("oid_text", VARCHAR2(4000)),
+ Column("view_type_owner", VARCHAR2(128)),
+ Column("view_type", VARCHAR2(128)),
+ Column("superview_name", VARCHAR2(128)),
+ Column("editioning_view", VARCHAR2(1)),
+ Column("read_only", VARCHAR2(1)),
+ Column("container_data", VARCHAR2(1)),
+ Column("bequeath", VARCHAR2(12)),
+ Column("origin_con_id", VARCHAR2(256)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("containers_default", VARCHAR2(3)),
+ Column("container_map", VARCHAR2(3)),
+ Column("extended_data_link", VARCHAR2(3)),
+ Column("extended_data_link_map", VARCHAR2(3)),
+ Column("has_sensitive_column", VARCHAR2(3)),
+ Column("admit_null", VARCHAR2(3)),
+ Column("pdb_local_only", VARCHAR2(3)),
+).alias("a_views")
+
+all_sequences = Table(
+ "all_sequences" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("sequence_owner", VARCHAR2(128), nullable=False),
+ Column("sequence_name", VARCHAR2(128), nullable=False),
+ Column("min_value", NUMBER),
+ Column("max_value", NUMBER),
+ Column("increment_by", NUMBER, nullable=False),
+ Column("cycle_flag", VARCHAR2(1)),
+ Column("order_flag", VARCHAR2(1)),
+ Column("cache_size", NUMBER, nullable=False),
+ Column("last_number", NUMBER, nullable=False),
+ Column("scale_flag", VARCHAR2(1)),
+ Column("extend_flag", VARCHAR2(1)),
+ Column("sharded_flag", VARCHAR2(1)),
+ Column("session_flag", VARCHAR2(1)),
+ Column("keep_value", VARCHAR2(1)),
+).alias("a_sequences")
+
+all_users = Table(
+ "all_users" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("username", VARCHAR2(128), nullable=False),
+ Column("user_id", NUMBER, nullable=False),
+ Column("created", DATE, nullable=False),
+ Column("common", VARCHAR2(3)),
+ Column("oracle_maintained", VARCHAR2(1)),
+ Column("inherited", VARCHAR2(3)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("implicit", VARCHAR2(3)),
+ Column("all_shard", VARCHAR2(3)),
+ Column("external_shard", VARCHAR2(3)),
+).alias("a_users")
+
+all_mviews = Table(
+ "all_mviews" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("mview_name", VARCHAR2(128), nullable=False),
+ Column("container_name", VARCHAR2(128), nullable=False),
+ Column("query", LONG),
+ Column("query_len", NUMBER(38)),
+ Column("updatable", VARCHAR2(1)),
+ Column("update_log", VARCHAR2(128)),
+ Column("master_rollback_seg", VARCHAR2(128)),
+ Column("master_link", VARCHAR2(128)),
+ Column("rewrite_enabled", VARCHAR2(1)),
+ Column("rewrite_capability", VARCHAR2(9)),
+ Column("refresh_mode", VARCHAR2(6)),
+ Column("refresh_method", VARCHAR2(8)),
+ Column("build_mode", VARCHAR2(9)),
+ Column("fast_refreshable", VARCHAR2(18)),
+ Column("last_refresh_type", VARCHAR2(8)),
+ Column("last_refresh_date", DATE),
+ Column("last_refresh_end_time", DATE),
+ Column("staleness", VARCHAR2(19)),
+ Column("after_fast_refresh", VARCHAR2(19)),
+ Column("unknown_prebuilt", VARCHAR2(1)),
+ Column("unknown_plsql_func", VARCHAR2(1)),
+ Column("unknown_external_table", VARCHAR2(1)),
+ Column("unknown_consider_fresh", VARCHAR2(1)),
+ Column("unknown_import", VARCHAR2(1)),
+ Column("unknown_trusted_fd", VARCHAR2(1)),
+ Column("compile_state", VARCHAR2(19)),
+ Column("use_no_index", VARCHAR2(1)),
+ Column("stale_since", DATE),
+ Column("num_pct_tables", NUMBER),
+ Column("num_fresh_pct_regions", NUMBER),
+ Column("num_stale_pct_regions", NUMBER),
+ Column("segment_created", VARCHAR2(3)),
+ Column("evaluation_edition", VARCHAR2(128)),
+ Column("unusable_before", VARCHAR2(128)),
+ Column("unusable_beginning", VARCHAR2(128)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("on_query_computation", VARCHAR2(1)),
+ Column("auto", VARCHAR2(3)),
+).alias("a_mviews")
+
+all_tab_identity_cols = Table(
+ "all_tab_identity_cols" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(128), nullable=False),
+ Column("generation_type", VARCHAR2(10)),
+ Column("sequence_name", VARCHAR2(128), nullable=False),
+ Column("identity_options", VARCHAR2(298)),
+).alias("a_tab_identity_cols")
+
+all_tab_cols = Table(
+ "all_tab_cols" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(128), nullable=False),
+ Column("data_type", VARCHAR2(128)),
+ Column("data_type_mod", VARCHAR2(3)),
+ Column("data_type_owner", VARCHAR2(128)),
+ Column("data_length", NUMBER, nullable=False),
+ Column("data_precision", NUMBER),
+ Column("data_scale", NUMBER),
+ Column("nullable", VARCHAR2(1)),
+ Column("column_id", NUMBER),
+ Column("default_length", NUMBER),
+ Column("data_default", LONG),
+ Column("num_distinct", NUMBER),
+ Column("low_value", RAW(1000)),
+ Column("high_value", RAW(1000)),
+ Column("density", NUMBER),
+ Column("num_nulls", NUMBER),
+ Column("num_buckets", NUMBER),
+ Column("last_analyzed", DATE),
+ Column("sample_size", NUMBER),
+ Column("character_set_name", VARCHAR2(44)),
+ Column("char_col_decl_length", NUMBER),
+ Column("global_stats", VARCHAR2(3)),
+ Column("user_stats", VARCHAR2(3)),
+ Column("avg_col_len", NUMBER),
+ Column("char_length", NUMBER),
+ Column("char_used", VARCHAR2(1)),
+ Column("v80_fmt_image", VARCHAR2(3)),
+ Column("data_upgraded", VARCHAR2(3)),
+ Column("hidden_column", VARCHAR2(3)),
+ Column("virtual_column", VARCHAR2(3)),
+ Column("segment_column_id", NUMBER),
+ Column("internal_column_id", NUMBER, nullable=False),
+ Column("histogram", VARCHAR2(15)),
+ Column("qualified_col_name", VARCHAR2(4000)),
+ Column("user_generated", VARCHAR2(3)),
+ Column("default_on_null", VARCHAR2(3)),
+ Column("identity_column", VARCHAR2(3)),
+ Column("evaluation_edition", VARCHAR2(128)),
+ Column("unusable_before", VARCHAR2(128)),
+ Column("unusable_beginning", VARCHAR2(128)),
+ Column("collation", VARCHAR2(100)),
+ Column("collated_column_id", NUMBER),
+).alias("a_tab_cols")
+
+all_tab_comments = Table(
+ "all_tab_comments" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("table_type", VARCHAR2(11)),
+ Column("comments", VARCHAR2(4000)),
+ Column("origin_con_id", NUMBER),
+).alias("a_tab_comments")
+
+all_col_comments = Table(
+ "all_col_comments" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(128), nullable=False),
+ Column("comments", VARCHAR2(4000)),
+ Column("origin_con_id", NUMBER),
+).alias("a_col_comments")
+
+all_mview_comments = Table(
+ "all_mview_comments" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("mview_name", VARCHAR2(128), nullable=False),
+ Column("comments", VARCHAR2(4000)),
+).alias("a_mview_comments")
+
+all_ind_columns = Table(
+ "all_ind_columns" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("index_owner", VARCHAR2(128), nullable=False),
+ Column("index_name", VARCHAR2(128), nullable=False),
+ Column("table_owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(4000)),
+ Column("column_position", NUMBER, nullable=False),
+ Column("column_length", NUMBER, nullable=False),
+ Column("char_length", NUMBER),
+ Column("descend", VARCHAR2(4)),
+ Column("collated_column_id", NUMBER),
+).alias("a_ind_columns")
+
+all_indexes = Table(
+ "all_indexes" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("index_name", VARCHAR2(128), nullable=False),
+ Column("index_type", VARCHAR2(27)),
+ Column("table_owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("table_type", CHAR(11)),
+ Column("uniqueness", VARCHAR2(9)),
+ Column("compression", VARCHAR2(13)),
+ Column("prefix_length", NUMBER),
+ Column("tablespace_name", VARCHAR2(30)),
+ Column("ini_trans", NUMBER),
+ Column("max_trans", NUMBER),
+ Column("initial_extent", NUMBER),
+ Column("next_extent", NUMBER),
+ Column("min_extents", NUMBER),
+ Column("max_extents", NUMBER),
+ Column("pct_increase", NUMBER),
+ Column("pct_threshold", NUMBER),
+ Column("include_column", NUMBER),
+ Column("freelists", NUMBER),
+ Column("freelist_groups", NUMBER),
+ Column("pct_free", NUMBER),
+ Column("logging", VARCHAR2(3)),
+ Column("blevel", NUMBER),
+ Column("leaf_blocks", NUMBER),
+ Column("distinct_keys", NUMBER),
+ Column("avg_leaf_blocks_per_key", NUMBER),
+ Column("avg_data_blocks_per_key", NUMBER),
+ Column("clustering_factor", NUMBER),
+ Column("status", VARCHAR2(8)),
+ Column("num_rows", NUMBER),
+ Column("sample_size", NUMBER),
+ Column("last_analyzed", DATE),
+ Column("degree", VARCHAR2(40)),
+ Column("instances", VARCHAR2(40)),
+ Column("partitioned", VARCHAR2(3)),
+ Column("temporary", VARCHAR2(1)),
+ Column("generated", VARCHAR2(1)),
+ Column("secondary", VARCHAR2(1)),
+ Column("buffer_pool", VARCHAR2(7)),
+ Column("flash_cache", VARCHAR2(7)),
+ Column("cell_flash_cache", VARCHAR2(7)),
+ Column("user_stats", VARCHAR2(3)),
+ Column("duration", VARCHAR2(15)),
+ Column("pct_direct_access", NUMBER),
+ Column("ityp_owner", VARCHAR2(128)),
+ Column("ityp_name", VARCHAR2(128)),
+ Column("parameters", VARCHAR2(1000)),
+ Column("global_stats", VARCHAR2(3)),
+ Column("domidx_status", VARCHAR2(12)),
+ Column("domidx_opstatus", VARCHAR2(6)),
+ Column("funcidx_status", VARCHAR2(8)),
+ Column("join_index", VARCHAR2(3)),
+ Column("iot_redundant_pkey_elim", VARCHAR2(3)),
+ Column("dropped", VARCHAR2(3)),
+ Column("visibility", VARCHAR2(9)),
+ Column("domidx_management", VARCHAR2(14)),
+ Column("segment_created", VARCHAR2(3)),
+ Column("orphaned_entries", VARCHAR2(3)),
+ Column("indexing", VARCHAR2(7)),
+ Column("auto", VARCHAR2(3)),
+).alias("a_indexes")
+
+all_constraints = Table(
+ "all_constraints" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128)),
+ Column("constraint_name", VARCHAR2(128)),
+ Column("constraint_type", VARCHAR2(1)),
+ Column("table_name", VARCHAR2(128)),
+ Column("search_condition", LONG),
+ Column("search_condition_vc", VARCHAR2(4000)),
+ Column("r_owner", VARCHAR2(128)),
+ Column("r_constraint_name", VARCHAR2(128)),
+ Column("delete_rule", VARCHAR2(9)),
+ Column("status", VARCHAR2(8)),
+ Column("deferrable", VARCHAR2(14)),
+ Column("deferred", VARCHAR2(9)),
+ Column("validated", VARCHAR2(13)),
+ Column("generated", VARCHAR2(14)),
+ Column("bad", VARCHAR2(3)),
+ Column("rely", VARCHAR2(4)),
+ Column("last_change", DATE),
+ Column("index_owner", VARCHAR2(128)),
+ Column("index_name", VARCHAR2(128)),
+ Column("invalid", VARCHAR2(7)),
+ Column("view_related", VARCHAR2(14)),
+ Column("origin_con_id", VARCHAR2(256)),
+).alias("a_constraints")
+
+all_cons_columns = Table(
+ "all_cons_columns" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("constraint_name", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(4000)),
+ Column("position", NUMBER),
+).alias("a_cons_columns")
+
+# TODO: figure out if this is still relevant, since it is not mentioned at
+# https://docs.oracle.com/en/database/oracle/oracle-database/21/refrn/ALL_DB_LINKS.html
+# original note:
+# using user_db_links here since all_db_links appears
+# to have more restricted permissions.
+# https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm
+# will need to hear from more users if we are doing
+# the right thing here. See [ticket:2619]
+all_db_links = Table(
+ "all_db_links" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("db_link", VARCHAR2(128), nullable=False),
+ Column("username", VARCHAR2(128)),
+ Column("host", VARCHAR2(2000)),
+ Column("created", DATE, nullable=False),
+ Column("hidden", VARCHAR2(3)),
+ Column("shard_internal", VARCHAR2(3)),
+ Column("valid", VARCHAR2(3)),
+ Column("intra_cdb", VARCHAR2(3)),
+).alias("a_db_links")
+
+all_synonyms = Table(
+ "all_synonyms" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128)),
+ Column("synonym_name", VARCHAR2(128)),
+ Column("table_owner", VARCHAR2(128)),
+ Column("table_name", VARCHAR2(128)),
+ Column("db_link", VARCHAR2(128)),
+ Column("origin_con_id", VARCHAR2(256)),
+).alias("a_synonyms")
+
+all_objects = Table(
+ "all_objects" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("object_name", VARCHAR2(128), nullable=False),
+ Column("subobject_name", VARCHAR2(128)),
+ Column("object_id", NUMBER, nullable=False),
+ Column("data_object_id", NUMBER),
+ Column("object_type", VARCHAR2(23)),
+ Column("created", DATE, nullable=False),
+ Column("last_ddl_time", DATE, nullable=False),
+ Column("timestamp", VARCHAR2(19)),
+ Column("status", VARCHAR2(7)),
+ Column("temporary", VARCHAR2(1)),
+ Column("generated", VARCHAR2(1)),
+ Column("secondary", VARCHAR2(1)),
+ Column("namespace", NUMBER, nullable=False),
+ Column("edition_name", VARCHAR2(128)),
+ Column("sharing", VARCHAR2(13)),
+ Column("editionable", VARCHAR2(1)),
+ Column("oracle_maintained", VARCHAR2(1)),
+ Column("application", VARCHAR2(1)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("duplicated", VARCHAR2(1)),
+ Column("sharded", VARCHAR2(1)),
+ Column("created_appid", NUMBER),
+ Column("created_vsnid", NUMBER),
+ Column("modified_appid", NUMBER),
+ Column("modified_vsnid", NUMBER),
+).alias("a_objects")
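# Hedged illustration (not part of this module) of why each dictionary table
# name above carries the DB_LINK_PLACEHOLDER suffix: the rendered SQL can be
# string-replaced with "" or "@<dblink>" before execution. The replace() call
# is an assumption about the mechanism, shown only for illustration.
def _example_render_dictionary_query():
    from sqlalchemy import select

    query = select(all_tables.c.table_name).where(all_tables.c.owner == "SCOTT")
    rendered = str(query)
    # without a db link / reflecting through the hypothetical "some_db_link"
    return (
        rendered.replace(DB_LINK_PLACEHOLDER, ""),
        rendered.replace(DB_LINK_PLACEHOLDER, "@some_db_link"),
    )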
from ... import create_engine
from ... import exc
+from ... import inspect
from ...engine import url as sa_url
from ...testing.provision import configure_follower
from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_post_tables
+from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import follower_url_from_main
from ...testing.provision import log
conn.exec_driver_sql("grant unlimited tablespace to %s" % ident)
conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident)
conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident)
+ # these are needed to create materialized views
+ conn.exec_driver_sql("grant create table to %s" % ident)
+ conn.exec_driver_sql("grant create table to %s_ts1" % ident)
+ conn.exec_driver_sql("grant create table to %s_ts2" % ident)
@configure_follower.for_db("oracle")
return False
+@drop_all_schema_objects_pre_tables.for_db("oracle")
+def _ora_drop_all_schema_objects_pre_tables(cfg, eng):
+ _purge_recyclebin(eng)
+ _purge_recyclebin(eng, cfg.test_schema)
+
+
+@drop_all_schema_objects_post_tables.for_db("oracle")
+def _ora_drop_all_schema_objects_post_tables(cfg, eng):
+
+ with eng.begin() as conn:
+ for syn in conn.dialect._get_synonyms(conn, None, None, None):
+ conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}")
+
+ for syn in conn.dialect._get_synonyms(
+ conn, cfg.test_schema, None, None
+ ):
+ conn.exec_driver_sql(
+ f"drop synonym {cfg.test_schema}.{syn['synonym_name']}"
+ )
+
+ for tmp_table in inspect(conn).get_temp_table_names():
+ conn.exec_driver_sql(f"drop table {tmp_table}")
+
+
@drop_db.for_db("oracle")
def _oracle_drop_db(cfg, eng, ident):
with eng.begin() as conn:
@stop_test_class_outside_fixtures.for_db("oracle")
-def stop_test_class_outside_fixtures(config, db, cls):
+def _ora_stop_test_class_outside_fixtures(config, db, cls):
try:
- with db.begin() as conn:
- # run magic command to get rid of identity sequences
- # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
- conn.exec_driver_sql("purge recyclebin")
+ _purge_recyclebin(db)
except exc.DatabaseError as err:
log.warning("purge recyclebin command failed: %s", err)
_all_conns.clear()
+def _purge_recyclebin(eng, schema=None):
+ with eng.begin() as conn:
+ if schema is None:
+ # run magic command to get rid of identity sequences
+ # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
+ conn.exec_driver_sql("purge recyclebin")
+ else:
+ # per user: https://community.oracle.com/tech/developers/discussion/2255402/how-to-clear-dba-recyclebin-for-a-particular-user # noqa: E501
+ for owner, object_name, type_ in conn.exec_driver_sql(
+ "select owner, object_name,type from "
+ "dba_recyclebin where owner=:schema and type='TABLE'",
+ {"schema": conn.dialect.denormalize_name(schema)},
+ ).all():
+ conn.exec_driver_sql(f'purge {type_} {owner}."{object_name}"')
+
+
_all_conns = set()
--- /dev/null
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from ...sql import sqltypes
+from ...types import NVARCHAR
+from ...types import VARCHAR
+
+
+class RAW(sqltypes._Binary):
+ __visit_name__ = "RAW"
+
+
+OracleRaw = RAW
+
+
+class NCLOB(sqltypes.Text):
+ __visit_name__ = "NCLOB"
+
+
+class VARCHAR2(VARCHAR):
+ __visit_name__ = "VARCHAR2"
+
+
+NVARCHAR2 = NVARCHAR
+
+
+class NUMBER(sqltypes.Numeric, sqltypes.Integer):
+ __visit_name__ = "NUMBER"
+
+ def __init__(self, precision=None, scale=None, asdecimal=None):
+ if asdecimal is None:
+ asdecimal = bool(scale and scale > 0)
+
+ super(NUMBER, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal
+ )
+
+ def adapt(self, impltype):
+ ret = super(NUMBER, self).adapt(impltype)
+ # leave a hint for the DBAPI handler
+ ret._is_oracle_number = True
+ return ret
+
+ @property
+ def _type_affinity(self):
+ if bool(self.scale and self.scale > 0):
+ return sqltypes.Numeric
+ else:
+ return sqltypes.Integer
+
+
+class FLOAT(sqltypes.FLOAT):
+ """Oracle FLOAT.
+
+ This is the same as :class:`_sqltypes.FLOAT` except that
+ an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
+ parameter is accepted, and
+ the :paramref:`_sqltypes.Float.precision` parameter is not accepted.
+
+ Oracle FLOAT types indicate precision in terms of "binary precision", which
+ defaults to 126. For a REAL type, the value is 63. This parameter does not
+ cleanly map to a specific number of decimal places but is roughly
+ equivalent to the desired number of decimal places divided by 0.30103.
+
+ .. versionadded:: 2.0
+
+ """
+
+ __visit_name__ = "FLOAT"
+
+ def __init__(
+ self,
+ binary_precision=None,
+ asdecimal=False,
+ decimal_return_scale=None,
+ ):
+ r"""
+ Construct a FLOAT
+
+ :param binary_precision: Oracle binary precision value to be rendered
+ in DDL. This may be approximated to the number of decimal digits
+ using the formula "decimal precision = 0.30103 * binary precision".
+ The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.
+
+ :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`
+
+ :param decimal_return_scale: See
+ :paramref:`_sqltypes.Float.decimal_return_scale`
+
+ """
+ super().__init__(
+ asdecimal=asdecimal, decimal_return_scale=decimal_return_scale
+ )
+ self.binary_precision = binary_precision
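# A rough worked example of the conversion described above (illustrative, not
# part of this module): binary precision ~= decimal digits / 0.30103, so 7
# decimal digits need about 23 binary digits, and Oracle's default of 126
# binary digits corresponds to roughly 0.30103 * 126 ~= 38 decimal digits.
def _approx_binary_precision_example(decimal_digits):
    return round(decimal_digits / 0.30103)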
+
+
+class BINARY_DOUBLE(sqltypes.Float):
+ __visit_name__ = "BINARY_DOUBLE"
+
+
+class BINARY_FLOAT(sqltypes.Float):
+ __visit_name__ = "BINARY_FLOAT"
+
+
+class BFILE(sqltypes.LargeBinary):
+ __visit_name__ = "BFILE"
+
+
+class LONG(sqltypes.Text):
+ __visit_name__ = "LONG"
+
+
+class _OracleDateLiteralRender:
+ def _literal_processor_datetime(self, dialect):
+ def process(value):
+ if value is not None:
+ if getattr(value, "microsecond", None):
+ value = (
+ f"""TO_TIMESTAMP"""
+ f"""('{value.isoformat().replace("T", " ")}', """
+ """'YYYY-MM-DD HH24:MI:SS.FF')"""
+ )
+ else:
+ value = (
+ f"""TO_DATE"""
+ f"""('{value.isoformat().replace("T", " ")}', """
+ """'YYYY-MM-DD HH24:MI:SS')"""
+ )
+ return value
+
+ return process
+
+ def _literal_processor_date(self, dialect):
+ def process(value):
+ if value is not None:
+ if getattr(value, "microsecond", None):
+ value = (
+ f"""TO_TIMESTAMP"""
+ f"""('{value.isoformat().split("T")[0]}', """
+ """'YYYY-MM-DD')"""
+ )
+ else:
+ value = (
+ f"""TO_DATE"""
+ f"""('{value.isoformat().split("T")[0]}', """
+ """'YYYY-MM-DD')"""
+ )
+ return value
+
+ return process
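# Illustrative helper (not part of this module): renders two sample values
# through the datetime literal processor defined above, using the standard
# library datetime module.
def _example_date_literals():
    import datetime

    process = _OracleDateLiteralRender()._literal_processor_datetime(None)
    return (
        # no microseconds: rendered with TO_DATE
        process(datetime.datetime(2021, 5, 4, 12, 30, 45)),
        # microseconds present: rendered with TO_TIMESTAMP and an FF mask
        process(datetime.datetime(2021, 5, 4, 12, 30, 45, 123456)),
    )
# -> ("TO_DATE('2021-05-04 12:30:45', 'YYYY-MM-DD HH24:MI:SS')",
#     "TO_TIMESTAMP('2021-05-04 12:30:45.123456', 'YYYY-MM-DD HH24:MI:SS.FF')")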
+
+
+class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
+ """Provide the oracle DATE type.
+
+ This type has no special Python behavior, except that it subclasses
+ :class:`_types.DateTime`; this is to suit the fact that the Oracle
+ ``DATE`` type supports a time value.
+
+ .. versionadded:: 0.9.4
+
+ """
+
+ __visit_name__ = "DATE"
+
+ def literal_processor(self, dialect):
+ return self._literal_processor_datetime(dialect)
+
+ def _compare_type_affinity(self, other):
+ return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
+
+
+class _OracleDate(_OracleDateLiteralRender, sqltypes.Date):
+ def literal_processor(self, dialect):
+ return self._literal_processor_date(dialect)
+
+
+class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+ __visit_name__ = "INTERVAL"
+
+ def __init__(self, day_precision=None, second_precision=None):
+ """Construct an INTERVAL.
+
+ Note that only DAY TO SECOND intervals are currently supported.
+ This is due to a lack of support for YEAR TO MONTH intervals
+ within available DBAPIs.
+
+ :param day_precision: the day precision value. this is the number of
+ digits to store for the day field. Defaults to "2"
+ :param second_precision: the second precision value. this is the
+ number of digits to store for the fractional seconds field.
+ Defaults to "6".
+
+ """
+ self.day_precision = day_precision
+ self.second_precision = second_precision
+
+ @classmethod
+ def _adapt_from_generic_interval(cls, interval):
+ return INTERVAL(
+ day_precision=interval.day_precision,
+ second_precision=interval.second_precision,
+ )
+
+ @property
+ def _type_affinity(self):
+ return sqltypes.Interval
+
+ def as_generic(self, allow_nulltype=False):
+ return sqltypes.Interval(
+ native=True,
+ second_precision=self.second_precision,
+ day_precision=self.day_precision,
+ )
+
+
+class ROWID(sqltypes.TypeEngine):
+ """Oracle ROWID type.
+
+ When used in a cast() or similar, generates ROWID.
+
+ """
+
+ __visit_name__ = "ROWID"
+
+
+class _OracleBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
from .array import ARRAY
from .array import array
from .base import BIGINT
-from .base import BIT
from .base import BOOLEAN
-from .base import BYTEA
from .base import CHAR
-from .base import CIDR
-from .base import CreateEnumType
from .base import DATE
from .base import DOUBLE_PRECISION
-from .base import DropEnumType
-from .base import ENUM
from .base import FLOAT
-from .base import INET
from .base import INTEGER
-from .base import INTERVAL
-from .base import MACADDR
-from .base import MONEY
from .base import NUMERIC
-from .base import OID
from .base import REAL
-from .base import REGCLASS
from .base import SMALLINT
from .base import TEXT
-from .base import TIME
-from .base import TIMESTAMP
-from .base import TSVECTOR
from .base import UUID
from .base import VARCHAR
from .dml import Insert
from .ranges import NUMRANGE
from .ranges import TSRANGE
from .ranges import TSTZRANGE
-from ...util import compat
+from .types import BIT
+from .types import BYTEA
+from .types import CIDR
+from .types import CreateEnumType
+from .types import DropEnumType
+from .types import ENUM
+from .types import INET
+from .types import INTERVAL
+from .types import MACADDR
+from .types import MONEY
+from .types import OID
+from .types import REGCLASS
+from .types import TIME
+from .types import TIMESTAMP
+from .types import TSVECTOR
# Alias psycopg also as psycopg_async
psycopg_async = type(
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import decimal
from .base import PGDialect
from .base import PGExecutionContext
from .hstore import HSTORE
+from .pg_catalog import _SpaceVector
+from .pg_catalog import INT2VECTOR
+from .pg_catalog import OIDVECTOR
from ... import exc
from ... import types as sqltypes
from ... import util
render_bind_cast = True
+class _PsycopgINT2VECTOR(_SpaceVector, INT2VECTOR):
+ pass
+
+
+class _PsycopgOIDVECTOR(_SpaceVector, OIDVECTOR):
+ pass
+
+
class _PGExecutionContext_common_psycopg(PGExecutionContext):
def create_server_side_cursor(self):
# use server-side cursors:
sqltypes.Numeric: _PsycopgNumeric,
HSTORE: _PsycopgHStore,
sqltypes.ARRAY: _PsycopgARRAY,
+ INT2VECTOR: _PsycopgINT2VECTOR,
+ OIDVECTOR: _PsycopgOIDVECTOR,
},
)
render_bind_cast = True
+class AsyncpgCHAR(sqltypes.CHAR):
+ render_bind_cast = True
+
+
class PGExecutionContext_asyncpg(PGExecutionContext):
def handle_dbapi_exception(self, e):
if isinstance(
sqltypes.Enum: AsyncPgEnum,
OID: AsyncpgOID,
REGCLASS: AsyncpgREGCLASS,
+ sqltypes.CHAR: AsyncpgCHAR,
},
)
is_async = True
:name: PostgreSQL
:full_support: 9.6, 10, 11, 12, 13, 14
:normal_support: 9.6+
- :best_effort: 8+
+ :best_effort: 9+
.. _postgresql_sequences:
from __future__ import annotations
from collections import defaultdict
-import datetime as dt
+from functools import lru_cache
import re
-from typing import Any
from . import array as _array
from . import dml
from . import hstore as _hstore
from . import json as _json
+from . import pg_catalog
from . import ranges as _ranges
+from .types import _DECIMAL_TYPES # noqa
+from .types import _FLOAT_TYPES # noqa
+from .types import _INT_TYPES # noqa
+from .types import BIT
+from .types import BYTEA
+from .types import CIDR
+from .types import CreateEnumType # noqa
+from .types import DropEnumType # noqa
+from .types import ENUM
+from .types import INET
+from .types import INTERVAL
+from .types import MACADDR
+from .types import MONEY
+from .types import OID
+from .types import PGBit # noqa
+from .types import PGCidr # noqa
+from .types import PGInet # noqa
+from .types import PGInterval # noqa
+from .types import PGMacAddr # noqa
+from .types import PGUuid
+from .types import REGCLASS
+from .types import TIME
+from .types import TIMESTAMP
+from .types import TSVECTOR
from ... import exc
from ... import schema
+from ... import select
from ... import sql
from ... import util
from ...engine import characteristics
from ...engine import default
from ...engine import interfaces
+from ...engine import ObjectKind
+from ...engine import ObjectScope
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
+from ...sql import bindparam
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
from ...sql import roles
from ...sql import sqltypes
from ...sql import util as sql_util
-from ...sql.ddl import InvokeDDLBase
+from ...sql.visitors import InternalTraversal
from ...types import BIGINT
from ...types import BOOLEAN
from ...types import CHAR
]
)
-_DECIMAL_TYPES = (1231, 1700)
-_FLOAT_TYPES = (700, 701, 1021, 1022)
-_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
-
-
-class PGUuid(UUID):
- render_bind_cast = True
- render_literal_cast = True
-
-
-class BYTEA(sqltypes.LargeBinary[bytes]):
- __visit_name__ = "BYTEA"
-
-
-class INET(sqltypes.TypeEngine[str]):
- __visit_name__ = "INET"
-
-
-PGInet = INET
-
-
-class CIDR(sqltypes.TypeEngine[str]):
- __visit_name__ = "CIDR"
-
-
-PGCidr = CIDR
-
-
-class MACADDR(sqltypes.TypeEngine[str]):
- __visit_name__ = "MACADDR"
-
-
-PGMacAddr = MACADDR
-
-
-class MONEY(sqltypes.TypeEngine[str]):
-
- r"""Provide the PostgreSQL MONEY type.
-
- Depending on driver, result rows using this type may return a
- string value which includes currency symbols.
-
- For this reason, it may be preferable to provide conversion to a
- numerically-based currency datatype using :class:`_types.TypeDecorator`::
-
- import re
- import decimal
- from sqlalchemy import TypeDecorator
-
- class NumericMoney(TypeDecorator):
- impl = MONEY
-
- def process_result_value(self, value: Any, dialect: Any) -> None:
- if value is not None:
- # adjust this for the currency and numeric
- m = re.match(r"\$([\d.]+)", value)
- if m:
- value = decimal.Decimal(m.group(1))
- return value
-
- Alternatively, the conversion may be applied as a CAST using
- the :meth:`_types.TypeDecorator.column_expression` method as follows::
-
- import decimal
- from sqlalchemy import cast
- from sqlalchemy import TypeDecorator
-
- class NumericMoney(TypeDecorator):
- impl = MONEY
-
- def column_expression(self, column: Any):
- return cast(column, Numeric())
-
- .. versionadded:: 1.2
-
- """
-
- __visit_name__ = "MONEY"
-
-
-class OID(sqltypes.TypeEngine[int]):
-
- """Provide the PostgreSQL OID type.
-
- .. versionadded:: 0.9.5
-
- """
-
- __visit_name__ = "OID"
-
-
-class REGCLASS(sqltypes.TypeEngine[str]):
-
- """Provide the PostgreSQL REGCLASS type.
-
- .. versionadded:: 1.2.7
-
- """
-
- __visit_name__ = "REGCLASS"
-
-
-class TIMESTAMP(sqltypes.TIMESTAMP):
- def __init__(self, timezone=False, precision=None):
- super(TIMESTAMP, self).__init__(timezone=timezone)
- self.precision = precision
-
-
-class TIME(sqltypes.TIME):
- def __init__(self, timezone=False, precision=None):
- super(TIME, self).__init__(timezone=timezone)
- self.precision = precision
-
-
-class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
-
- """PostgreSQL INTERVAL type."""
-
- __visit_name__ = "INTERVAL"
- native = True
-
- def __init__(self, precision=None, fields=None):
- """Construct an INTERVAL.
-
- :param precision: optional integer precision value
- :param fields: string fields specifier. allows storage of fields
- to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
- etc.
-
- .. versionadded:: 1.2
-
- """
- self.precision = precision
- self.fields = fields
-
- @classmethod
- def adapt_emulated_to_native(cls, interval, **kw):
- return INTERVAL(precision=interval.second_precision)
-
- @property
- def _type_affinity(self):
- return sqltypes.Interval
-
- def as_generic(self, allow_nulltype=False):
- return sqltypes.Interval(native=True, second_precision=self.precision)
-
- @property
- def python_type(self):
- return dt.timedelta
-
-
-PGInterval = INTERVAL
-
-
-class BIT(sqltypes.TypeEngine[int]):
- __visit_name__ = "BIT"
-
- def __init__(self, length=None, varying=False):
- if not varying:
- # BIT without VARYING defaults to length 1
- self.length = length or 1
- else:
- # but BIT VARYING can be unlimited-length, so no default
- self.length = length
- self.varying = varying
-
-
-PGBit = BIT
-
-
-class TSVECTOR(sqltypes.TypeEngine[Any]):
-
- """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
- text search type TSVECTOR.
-
- It can be used to do full text queries on natural language
- documents.
-
- .. versionadded:: 0.9.0
-
- .. seealso::
-
- :ref:`postgresql_match`
-
- """
-
- __visit_name__ = "TSVECTOR"
-
-
-class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
-
- """PostgreSQL ENUM type.
-
- This is a subclass of :class:`_types.Enum` which includes
- support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
-
- When the builtin type :class:`_types.Enum` is used and the
- :paramref:`.Enum.native_enum` flag is left at its default of
- True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
- type as the implementation, so the special create/drop rules
- will be used.
-
- The create/drop behavior of ENUM is necessarily intricate, due to the
- awkward relationship the ENUM type has in relationship to the
- parent table, in that it may be "owned" by just a single table, or
- may be shared among many tables.
-
- When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
- in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
- corresponding to when the :meth:`_schema.Table.create` and
- :meth:`_schema.Table.drop`
- methods are called::
-
- table = Table('sometable', metadata,
- Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
- )
-
- table.create(engine) # will emit CREATE ENUM and CREATE TABLE
- table.drop(engine) # will emit DROP TABLE and DROP ENUM
-
- To use a common enumerated type between multiple tables, the best
- practice is to declare the :class:`_types.Enum` or
- :class:`_postgresql.ENUM` independently, and associate it with the
- :class:`_schema.MetaData` object itself::
-
- my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
-
- t1 = Table('sometable_one', metadata,
- Column('some_enum', myenum)
- )
-
- t2 = Table('sometable_two', metadata,
- Column('some_enum', myenum)
- )
-
- When this pattern is used, care must still be taken at the level
- of individual table creates. Emitting CREATE TABLE without also
- specifying ``checkfirst=True`` will still cause issues::
-
- t1.create(engine) # will fail: no such type 'myenum'
-
- If we specify ``checkfirst=True``, the individual table-level create
- operation will check for the ``ENUM`` and create if not exists::
-
- # will check if enum exists, and emit CREATE TYPE if not
- t1.create(engine, checkfirst=True)
-
- When using a metadata-level ENUM type, the type will always be created
- and dropped if either the metadata-wide create/drop is called::
-
- metadata.create_all(engine) # will emit CREATE TYPE
- metadata.drop_all(engine) # will emit DROP TYPE
-
- The type can also be created and dropped directly::
-
- my_enum.create(engine)
- my_enum.drop(engine)
-
- .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
- now behaves more strictly with regards to CREATE/DROP. A metadata-level
- ENUM type will only be created and dropped at the metadata level,
- not the table level, with the exception of
- ``table.create(checkfirst=True)``.
- The ``table.drop()`` call will now emit a DROP TYPE for a table-level
- enumerated type.
-
- """
-
- native_enum = True
-
- def __init__(self, *enums, **kw):
- """Construct an :class:`_postgresql.ENUM`.
-
- Arguments are the same as that of
- :class:`_types.Enum`, but also including
- the following parameters.
-
- :param create_type: Defaults to True.
- Indicates that ``CREATE TYPE`` should be
- emitted, after optionally checking for the
- presence of the type, when the parent
- table is being created; and additionally
- that ``DROP TYPE`` is called when the table
- is dropped. When ``False``, no check
- will be performed and no ``CREATE TYPE``
- or ``DROP TYPE`` is emitted, unless
- :meth:`~.postgresql.ENUM.create`
- or :meth:`~.postgresql.ENUM.drop`
- are called directly.
- Setting to ``False`` is helpful
- when invoking a creation scheme to a SQL file
- without access to the actual database -
- the :meth:`~.postgresql.ENUM.create` and
- :meth:`~.postgresql.ENUM.drop` methods can
- be used to emit SQL to a target bind.
-
- """
- native_enum = kw.pop("native_enum", None)
- if native_enum is False:
- util.warn(
- "the native_enum flag does not apply to the "
- "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
- "always refers to ENUM. Use sqlalchemy.types.Enum for "
- "non-native enum."
- )
- self.create_type = kw.pop("create_type", True)
- super(ENUM, self).__init__(*enums, **kw)
-
- @classmethod
- def adapt_emulated_to_native(cls, impl, **kw):
- """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
- :class:`.Enum`.
-
- """
- kw.setdefault("validate_strings", impl.validate_strings)
- kw.setdefault("name", impl.name)
- kw.setdefault("schema", impl.schema)
- kw.setdefault("inherit_schema", impl.inherit_schema)
- kw.setdefault("metadata", impl.metadata)
- kw.setdefault("_create_events", False)
- kw.setdefault("values_callable", impl.values_callable)
- kw.setdefault("omit_aliases", impl._omit_aliases)
- return cls(**kw)
-
- def create(self, bind=None, checkfirst=True):
- """Emit ``CREATE TYPE`` for this
- :class:`_postgresql.ENUM`.
-
- If the underlying dialect does not support
- PostgreSQL CREATE TYPE, no action is taken.
-
- :param bind: a connectable :class:`_engine.Engine`,
- :class:`_engine.Connection`, or similar object to emit
- SQL.
- :param checkfirst: if ``True``, a query against
- the PG catalog will be first performed to see
- if the type does not exist already before
- creating.
-
- """
- if not bind.dialect.supports_native_enum:
- return
-
- bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
-
- def drop(self, bind=None, checkfirst=True):
- """Emit ``DROP TYPE`` for this
- :class:`_postgresql.ENUM`.
-
- If the underlying dialect does not support
- PostgreSQL DROP TYPE, no action is taken.
-
- :param bind: a connectable :class:`_engine.Engine`,
- :class:`_engine.Connection`, or similar object to emit
- SQL.
- :param checkfirst: if ``True``, a query against
- the PG catalog will be first performed to see
- if the type actually exists before dropping.
-
- """
- if not bind.dialect.supports_native_enum:
- return
-
- bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
-
- class EnumGenerator(InvokeDDLBase):
- def __init__(self, dialect, connection, checkfirst=False, **kwargs):
- super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
-
- def _can_create_enum(self, enum):
- if not self.checkfirst:
- return True
-
- effective_schema = self.connection.schema_for_object(enum)
-
- return not self.connection.dialect.has_type(
- self.connection, enum.name, schema=effective_schema
- )
-
- def visit_enum(self, enum):
- if not self._can_create_enum(enum):
- return
-
- self.connection.execute(CreateEnumType(enum))
-
- class EnumDropper(InvokeDDLBase):
- def __init__(self, dialect, connection, checkfirst=False, **kwargs):
- super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
-
- def _can_drop_enum(self, enum):
- if not self.checkfirst:
- return True
-
- effective_schema = self.connection.schema_for_object(enum)
-
- return self.connection.dialect.has_type(
- self.connection, enum.name, schema=effective_schema
- )
-
- def visit_enum(self, enum):
- if not self._can_drop_enum(enum):
- return
-
- self.connection.execute(DropEnumType(enum))
-
- def get_dbapi_type(self, dbapi):
- """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
- a different type"""
-
- return None
-
- def _check_for_name_in_memos(self, checkfirst, kw):
- """Look in the 'ddl runner' for 'memos', then
- note our name in that collection.
-
- This to ensure a particular named enum is operated
- upon only once within any kind of create/drop
- sequence without relying upon "checkfirst".
-
- """
- if not self.create_type:
- return True
- if "_ddl_runner" in kw:
- ddl_runner = kw["_ddl_runner"]
- if "_pg_enums" in ddl_runner.memo:
- pg_enums = ddl_runner.memo["_pg_enums"]
- else:
- pg_enums = ddl_runner.memo["_pg_enums"] = set()
- present = (self.schema, self.name) in pg_enums
- pg_enums.add((self.schema, self.name))
- return present
- else:
- return False
-
- def _on_table_create(self, target, bind, checkfirst=False, **kw):
- if (
- checkfirst
- or (
- not self.metadata
- and not kw.get("_is_metadata_operation", False)
- )
- ) and not self._check_for_name_in_memos(checkfirst, kw):
- self.create(bind=bind, checkfirst=checkfirst)
-
- def _on_table_drop(self, target, bind, checkfirst=False, **kw):
- if (
- not self.metadata
- and not kw.get("_is_metadata_operation", False)
- and not self._check_for_name_in_memos(checkfirst, kw)
- ):
- self.drop(bind=bind, checkfirst=checkfirst)
-
- def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
- if not self._check_for_name_in_memos(checkfirst, kw):
- self.create(bind=bind, checkfirst=checkfirst)
-
- def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
- if not self._check_for_name_in_memos(checkfirst, kw):
- self.drop(bind=bind, checkfirst=checkfirst)
-
-
colspecs = {
sqltypes.ARRAY: _array.ARRAY,
sqltypes.Interval: INTERVAL,
class PGInspector(reflection.Inspector):
+ dialect: PGDialect
+
def get_table_oid(self, table_name, schema=None):
- """Return the OID for the given table name."""
+ """Return the OID for the given table name.
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
with self._operation_context() as conn:
return self.dialect.get_table_oid(
.. versionadded:: 1.0.0
"""
- schema = schema or self.default_schema_name
with self._operation_context() as conn:
- return self.dialect._load_enums(conn, schema)
+ return self.dialect._load_enums(
+ conn, schema, info_cache=self.info_cache
+ )
def get_foreign_table_names(self, schema=None):
"""Return a list of FOREIGN TABLE names.
.. versionadded:: 1.0.0
"""
- schema = schema or self.default_schema_name
with self._operation_context() as conn:
- return self.dialect._get_foreign_table_names(conn, schema)
-
- def get_view_names(self, schema=None, include=("plain", "materialized")):
- """Return all view names in `schema`.
+ return self.dialect._get_foreign_table_names(
+ conn, schema, info_cache=self.info_cache
+ )
- :param schema: Optional, retrieve names from a non-default schema.
- For special quoting, use :class:`.quoted_name`.
+ def has_type(self, type_name, schema=None, **kw):
+ """Return if the database has the specified type in the provided
+ schema.
- :param include: specify which types of views to return. Passed
- as a string value (for a single type) or a tuple (for any number
- of types). Defaults to ``('plain', 'materialized')``.
+ :param type_name: the type to check.
+ :param schema: schema name. If None, the default schema
+ (typically 'public') is used. May also be set to '*' to
+ check in all schemas.
- .. versionadded:: 1.1
+ .. versionadded:: 2.0
"""
-
with self._operation_context() as conn:
- return self.dialect.get_view_names(
- conn, schema, info_cache=self.info_cache, include=include
+ return self.dialect.has_type(
+ conn, type_name, schema, info_cache=self.info_cache
)
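+ # Illustrative sketch (not part of this change): how the new method
+ # might be called from an Inspector, assuming an engine bound to
+ # PostgreSQL and an existing enum type named "myenum":
+ #
+ #     from sqlalchemy import create_engine, inspect
+ #
+ #     engine = create_engine("postgresql+psycopg2://scott:tiger@host/db")
+ #     insp = inspect(engine)
+ #     insp.has_type("myenum", schema="public")  # True or False
+ #     insp.has_type("myenum", schema="*")       # search all schemas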
-class CreateEnumType(schema._CreateDropBase):
- __visit_name__ = "create_enum_type"
-
-
-class DropEnumType(schema._CreateDropBase):
- __visit_name__ = "drop_enum_type"
-
-
class PGExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
return self._execute_scalar(
def initialize(self, connection):
super(PGDialect, self).initialize(connection)
- if self.server_version_info <= (8, 2):
- self.delete_returning = (
- self.update_returning
- ) = self.insert_returning = False
-
- self.supports_native_enum = self.server_version_info >= (8, 3)
- if not self.supports_native_enum:
- self.colspecs = self.colspecs.copy()
- # pop base Enum type
- self.colspecs.pop(sqltypes.Enum, None)
- # psycopg2, others may have placed ENUM here as well
- self.colspecs.pop(ENUM, None)
-
# https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689
self.supports_smallserial = self.server_version_info >= (9, 2)
- if self.server_version_info < (8, 2):
- self._backslash_escapes = False
- else:
- # ensure this query is not emitted on server version < 8.2
- # as it will fail
- std_string = connection.exec_driver_sql(
- "show standard_conforming_strings"
- ).scalar()
- self._backslash_escapes = std_string == "off"
-
- self._supports_create_index_concurrently = (
- self.server_version_info >= (8, 2)
- )
+ std_string = connection.exec_driver_sql(
+ "show standard_conforming_strings"
+ ).scalar()
+ self._backslash_escapes = std_string == "off"
+
self._supports_drop_index_concurrently = self.server_version_info >= (
9,
2,
self.do_commit(connection.connection)
def do_recover_twophase(self, connection):
- resultset = connection.execute(
+ return connection.scalars(
sql.text("SELECT gid FROM pg_prepared_xacts")
- )
- return [row[0] for row in resultset]
+ ).all()
def _get_default_schema_name(self, connection):
return connection.exec_driver_sql("select current_schema()").scalar()
- def has_schema(self, connection, schema):
- query = (
- "select nspname from pg_namespace " "where lower(nspname)=:schema"
- )
- cursor = connection.execute(
- sql.text(query).bindparams(
- sql.bindparam(
- "schema",
- str(schema.lower()),
- type_=sqltypes.Unicode,
- )
- )
+ @reflection.cache
+ def has_schema(self, connection, schema, **kw):
+ query = select(pg_catalog.pg_namespace.c.nspname).where(
+ pg_catalog.pg_namespace.c.nspname == schema
)
+ return bool(connection.scalar(query))
- return bool(cursor.first())
-
- def has_table(self, connection, table_name, schema=None):
- self._ensure_has_table_connection(connection)
- # seems like case gets folded in pg_class...
+ def _pg_class_filter_scope_schema(
+ self, query, schema, scope, pg_class_table=None
+ ):
+ if pg_class_table is None:
+ pg_class_table = pg_catalog.pg_class
+ query = query.join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace,
+ )
+ if scope is ObjectScope.DEFAULT:
+ query = query.where(pg_class_table.c.relpersistence != "t")
+ elif scope is ObjectScope.TEMPORARY:
+ query = query.where(pg_class_table.c.relpersistence == "t")
if schema is None:
- cursor = connection.execute(
- sql.text(
- "select relname from pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where "
- "pg_catalog.pg_table_is_visible(c.oid) "
- "and relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(table_name),
- type_=sqltypes.Unicode,
- )
- )
+ query = query.where(
+ pg_catalog.pg_table_is_visible(pg_class_table.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
)
else:
- cursor = connection.execute(
- sql.text(
- "select relname from pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where n.nspname=:schema and "
- "relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(table_name),
- type_=sqltypes.Unicode,
- ),
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
- )
- return bool(cursor.first())
-
- def has_sequence(self, connection, sequence_name, schema=None):
- if schema is None:
- schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT relname FROM pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where relkind='S' and "
- "n.nspname=:schema and relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(sequence_name),
- type_=sqltypes.Unicode,
- ),
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+ return query
+
+ def _pg_class_relkind_condition(self, relkinds, pg_class_table=None):
+ if pg_class_table is None:
+ pg_class_table = pg_catalog.pg_class
+ # uses the ANY form instead of IN, as otherwise PostgreSQL complains
+ # that 'IN could not convert type character to "char"'
+ return pg_class_table.c.relkind == sql.any_(_array.array(relkinds))
+
+ @lru_cache()
+ def _has_table_query(self, schema):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ pg_catalog.pg_class.c.relname == bindparam("table_name"),
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ ),
+ )
+ return self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
)
- return bool(cursor.first())
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
+ self._ensure_has_table_connection(connection)
+ query = self._has_table_query(schema)
+ return bool(connection.scalar(query, {"table_name": table_name}))
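+ # Illustrative sketch of the caching pattern above (simplified, and
+ # omitting the schema/scope filtering): the SELECT is constructed once
+ # via lru_cache() and parameterized with bindparam(), so repeated
+ # has_table() calls only swap in a new parameter value:
+ #
+ #     from functools import lru_cache
+ #     from sqlalchemy import bindparam, column, select, table
+ #
+ #     pg_class = table("pg_class", column("relname"))
+ #
+ #     @lru_cache()
+ #     def _query():
+ #         return select(pg_class.c.relname).where(
+ #             pg_class.c.relname == bindparam("table_name")
+ #         )
+ #
+ #     # bool(connection.scalar(_query(), {"table_name": "some_table"}))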
- def has_type(self, connection, type_name, schema=None):
- if schema is not None:
- query = """
- SELECT EXISTS (
- SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n
- WHERE t.typnamespace = n.oid
- AND t.typname = :typname
- AND n.nspname = :nspname
- )
- """
- query = sql.text(query)
- else:
- query = """
- SELECT EXISTS (
- SELECT * FROM pg_catalog.pg_type t
- WHERE t.typname = :typname
- AND pg_type_is_visible(t.oid)
- )
- """
- query = sql.text(query)
- query = query.bindparams(
- sql.bindparam("typname", str(type_name), type_=sqltypes.Unicode)
+ @reflection.cache
+ def has_sequence(self, connection, sequence_name, schema=None, **kw):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ pg_catalog.pg_class.c.relkind == "S",
+ pg_catalog.pg_class.c.relname == sequence_name,
)
- if schema is not None:
- query = query.bindparams(
- sql.bindparam("nspname", str(schema), type_=sqltypes.Unicode)
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ return bool(connection.scalar(query))
+
+ @reflection.cache
+ def has_type(self, connection, type_name, schema=None, **kw):
+ query = (
+ select(pg_catalog.pg_type.c.typname)
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
)
- cursor = connection.execute(query)
- return bool(cursor.scalar())
+ .where(pg_catalog.pg_type.c.typname == type_name)
+ )
+ if schema is None:
+ query = query.where(
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+ )
+ elif schema != "*":
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+
+ return bool(connection.scalar(query))
def _get_server_version_info(self, connection):
v = connection.exec_driver_sql("select pg_catalog.version()").scalar()
@reflection.cache
def get_table_oid(self, connection, table_name, schema=None, **kw):
- """Fetch the oid for schema.table_name.
-
- Several reflection methods require the table oid. The idea for using
- this method is that it can be fetched one time and cached for
- subsequent calls.
-
- """
- table_oid = None
- if schema is not None:
- schema_where_clause = "n.nspname = :schema"
- else:
- schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
- query = (
- """
- SELECT c.oid
- FROM pg_catalog.pg_class c
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
- WHERE (%s)
- AND c.relname = :table_name AND c.relkind in
- ('r', 'v', 'm', 'f', 'p')
- """
- % schema_where_clause
+ """Fetch the oid for schema.table_name."""
+ query = select(pg_catalog.pg_class.c.oid).where(
+ pg_catalog.pg_class.c.relname == table_name,
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ ),
)
- # Since we're binding to unicode, table_name and schema_name must be
- # unicode.
- table_name = str(table_name)
- if schema is not None:
- schema = str(schema)
- s = sql.text(query).bindparams(table_name=sqltypes.Unicode)
- s = s.columns(oid=sqltypes.Integer)
- if schema:
- s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode))
- c = connection.execute(s, dict(table_name=table_name, schema=schema))
- table_oid = c.scalar()
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ table_oid = connection.scalar(query)
if table_oid is None:
- raise exc.NoSuchTableError(table_name)
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
return table_oid
@reflection.cache
def get_schema_names(self, connection, **kw):
- result = connection.execute(
- sql.text(
- "SELECT nspname FROM pg_namespace "
- "WHERE nspname NOT LIKE 'pg_%' "
- "ORDER BY nspname"
- ).columns(nspname=sqltypes.Unicode)
+ query = (
+ select(pg_catalog.pg_namespace.c.nspname)
+ .where(pg_catalog.pg_namespace.c.nspname.not_like("pg_%"))
+ .order_by(pg_catalog.pg_namespace.c.nspname)
+ )
+ return connection.scalars(query).all()
+
+ def _get_relnames_for_relkinds(self, connection, schema, relkinds, scope):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ self._pg_class_relkind_condition(relkinds)
)
- return [name for name, in result]
+ query = self._pg_class_filter_scope_schema(query, schema, scope=scope)
+ return connection.scalars(query).all()
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')"
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_TABLE_NO_FOREIGN,
+ scope=ObjectScope.DEFAULT,
+ )
+
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema=None,
+ relkinds=pg_catalog.RELKINDS_TABLE_NO_FOREIGN,
+ scope=ObjectScope.TEMPORARY,
)
- return [name for name, in result]
@reflection.cache
def _get_foreign_table_names(self, connection, schema=None, **kw):
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind = 'f'"
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ return self._get_relnames_for_relkinds(
+ connection, schema, relkinds=("f",), scope=ObjectScope.ANY
)
- return [name for name, in result]
@reflection.cache
- def get_view_names(
- self, connection, schema=None, include=("plain", "materialized"), **kw
- ):
+ def get_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_VIEW,
+ scope=ObjectScope.DEFAULT,
+ )
- include_kind = {"plain": "v", "materialized": "m"}
- try:
- kinds = [include_kind[i] for i in util.to_list(include)]
- except KeyError:
- raise ValueError(
- "include %r unknown, needs to be a sequence containing "
- "one or both of 'plain' and 'materialized'" % (include,)
- )
- if not kinds:
- raise ValueError(
- "empty include, needs to be a sequence containing "
- "one or both of 'plain' and 'materialized'"
- )
+ @reflection.cache
+ def get_materialized_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_MAT_VIEW,
+ scope=ObjectScope.DEFAULT,
+ )
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind IN (%s)"
- % (", ".join("'%s'" % elem for elem in kinds))
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ @reflection.cache
+ def get_temp_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ # NOTE: do not include temp materialized views (they do not
+ # seem to be a thing, at least up to version 14)
+ pg_catalog.RELKINDS_VIEW,
+ scope=ObjectScope.TEMPORARY,
)
- return [name for name, in result]
@reflection.cache
def get_sequence_names(self, connection, schema=None, **kw):
- if not schema:
- schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT relname FROM pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where relkind='S' and "
- "n.nspname=:schema"
- ).bindparams(
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
+ return self._get_relnames_for_relkinds(
+ connection, schema, relkinds=("S",), scope=ObjectScope.ANY
)
- return [row[0] for row in cursor]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
- view_def = connection.scalar(
- sql.text(
- "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relname = :view_name "
- "AND c.relkind IN ('v', 'm')"
- ).columns(view_def=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name,
- view_name=view_name,
- ),
+ query = (
+ select(pg_catalog.pg_get_viewdef(pg_catalog.pg_class.c.oid))
+ .select_from(pg_catalog.pg_class)
+ .where(
+ pg_catalog.pg_class.c.relname == view_name,
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_VIEW + pg_catalog.RELKINDS_MAT_VIEW
+ ),
+ )
)
- return view_def
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ res = connection.scalar(query)
+ if res is None:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
+ else:
+ return res
+
+ def _value_or_raise(self, data, table, schema):
+ try:
+ return dict(data)[(schema, table)]
+ except KeyError:
+ raise exc.NoSuchTableError(
+ f"{schema}.{table}" if schema else table
+ ) from None
+
+ def _prepare_filter_names(self, filter_names):
+ if filter_names:
+ return True, {"filter_names": filter_names}
+ else:
+ return False, {}
+
+ def _kind_to_relkinds(self, kind: ObjectKind) -> tuple[str, ...]:
+ if kind is ObjectKind.ANY:
+ return pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ relkinds = ()
+ if ObjectKind.TABLE in kind:
+ relkinds += pg_catalog.RELKINDS_TABLE
+ if ObjectKind.VIEW in kind:
+ relkinds += pg_catalog.RELKINDS_VIEW
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ relkinds += pg_catalog.RELKINDS_MAT_VIEW
+ return relkinds
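+ # e.g. a composite kind such as ObjectKind.TABLE | ObjectKind.VIEW
+ # yields RELKINDS_TABLE + RELKINDS_VIEW, while ObjectKind.ANY
+ # short-circuits to RELKINDS_ALL_TABLE_LIKE.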
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
-
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_columns(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
+ @lru_cache()
+ def _columns_query(self, schema, has_filter_names, scope, kind):
+ # NOTE: computing the default and identity options via scalar
+ # subqueries is faster than trying to use outer joins for them
generated = (
- "a.attgenerated as generated"
+ pg_catalog.pg_attribute.c.attgenerated.label("generated")
if self.server_version_info >= (12,)
- else "NULL as generated"
+ else sql.null().label("generated")
)
if self.server_version_info >= (10,):
- # a.attidentity != '' is required or it will reflect also
- # serial columns as identity.
- identity = """\
- (SELECT json_build_object(
- 'always', a.attidentity = 'a',
- 'start', s.seqstart,
- 'increment', s.seqincrement,
- 'minvalue', s.seqmin,
- 'maxvalue', s.seqmax,
- 'cache', s.seqcache,
- 'cycle', s.seqcycle)
- FROM pg_catalog.pg_sequence s
- JOIN pg_catalog.pg_class c on s.seqrelid = c."oid"
- WHERE c.relkind = 'S'
- AND a.attidentity != ''
- AND s.seqrelid = pg_catalog.pg_get_serial_sequence(
- a.attrelid::regclass::text, a.attname
- )::regclass::oid
- ) as identity_options\
- """
+ # join lateral performs worse (~2x slower) than a scalar_subquery
+ identity = (
+ select(
+ sql.func.json_build_object(
+ "always",
+ pg_catalog.pg_attribute.c.attidentity == "a",
+ "start",
+ pg_catalog.pg_sequence.c.seqstart,
+ "increment",
+ pg_catalog.pg_sequence.c.seqincrement,
+ "minvalue",
+ pg_catalog.pg_sequence.c.seqmin,
+ "maxvalue",
+ pg_catalog.pg_sequence.c.seqmax,
+ "cache",
+ pg_catalog.pg_sequence.c.seqcache,
+ "cycle",
+ pg_catalog.pg_sequence.c.seqcycle,
+ )
+ )
+ .select_from(pg_catalog.pg_sequence)
+ .where(
+ # attidentity != '' is required or it will also reflect
+ # serial columns as identity.
+ pg_catalog.pg_attribute.c.attidentity != "",
+ pg_catalog.pg_sequence.c.seqrelid
+ == sql.cast(
+ sql.cast(
+ pg_catalog.pg_get_serial_sequence(
+ sql.cast(
+ sql.cast(
+ pg_catalog.pg_attribute.c.attrelid,
+ REGCLASS,
+ ),
+ TEXT,
+ ),
+ pg_catalog.pg_attribute.c.attname,
+ ),
+ REGCLASS,
+ ),
+ OID,
+ ),
+ )
+ .correlate(pg_catalog.pg_attribute)
+ .scalar_subquery()
+ .label("identity_options")
+ )
else:
- identity = "NULL as identity_options"
-
- SQL_COLS = """
- SELECT a.attname,
- pg_catalog.format_type(a.atttypid, a.atttypmod),
- (
- SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid)
- FROM pg_catalog.pg_attrdef d
- WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum
- AND a.atthasdef
- ) AS DEFAULT,
- a.attnotnull,
- a.attrelid as table_oid,
- pgd.description as comment,
- %s,
- %s
- FROM pg_catalog.pg_attribute a
- LEFT JOIN pg_catalog.pg_description pgd ON (
- pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum)
- WHERE a.attrelid = :table_oid
- AND a.attnum > 0 AND NOT a.attisdropped
- ORDER BY a.attnum
- """ % (
- generated,
- identity,
+ identity = sql.null().label("identity_options")
+
+ # join lateral performs the same as scalar_subquery here
+ default = (
+ select(
+ pg_catalog.pg_get_expr(
+ pg_catalog.pg_attrdef.c.adbin,
+ pg_catalog.pg_attrdef.c.adrelid,
+ )
+ )
+ .select_from(pg_catalog.pg_attrdef)
+ .where(
+ pg_catalog.pg_attrdef.c.adrelid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_attrdef.c.adnum
+ == pg_catalog.pg_attribute.c.attnum,
+ pg_catalog.pg_attribute.c.atthasdef,
+ )
+ .correlate(pg_catalog.pg_attribute)
+ .scalar_subquery()
+ .label("default")
)
- s = (
- sql.text(SQL_COLS)
- .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer))
- .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode)
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_attribute.c.attname.label("name"),
+ pg_catalog.format_type(
+ pg_catalog.pg_attribute.c.atttypid,
+ pg_catalog.pg_attribute.c.atttypmod,
+ ).label("format_type"),
+ default,
+ pg_catalog.pg_attribute.c.attnotnull.label("not_null"),
+ pg_catalog.pg_class.c.relname.label("table_name"),
+ pg_catalog.pg_description.c.description.label("comment"),
+ generated,
+ identity,
+ )
+ .select_from(pg_catalog.pg_class)
+ # NOTE: PostgreSQL supports tables with no user columns, meaning
+ # there is no row with pg_attribute.attnum > 0. Use a left outer
+ # join to avoid filtering out these tables.
+ .outerjoin(
+ pg_catalog.pg_attribute,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_attribute.c.attnum > 0,
+ ~pg_catalog.pg_attribute.c.attisdropped,
+ ),
+ )
+ .outerjoin(
+ pg_catalog.pg_description,
+ sql.and_(
+ pg_catalog.pg_description.c.objoid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_description.c.objsubid
+ == pg_catalog.pg_attribute.c.attnum,
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ .order_by(
+ pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum
+ )
)
- c = connection.execute(s, dict(table_oid=table_oid))
- rows = c.fetchall()
+ query = self._pg_class_filter_scope_schema(query, schema, scope=scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
+
+ def get_multi_columns(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._columns_query(schema, has_filter_names, scope, kind)
+ rows = connection.execute(query, params).mappings()
# dictionary with (name, ) if default search path or (schema, name)
# as keys
- domains = self._load_domains(connection)
+ domains = self._load_domains(
+ connection, info_cache=kw.get("info_cache")
+ )
# dictionary with (name, ) if default search path or (schema, name)
# as keys
((rec["name"],), rec)
if rec["visible"]
else ((rec["schema"], rec["name"]), rec)
- for rec in self._load_enums(connection, schema="*")
+ for rec in self._load_enums(
+ connection, schema="*", info_cache=kw.get("info_cache")
+ )
)
- # format columns
- columns = []
-
- for (
- name,
- format_type,
- default_,
- notnull,
- table_oid,
- comment,
- generated,
- identity,
- ) in rows:
- column_info = self._get_column_info(
- name,
- format_type,
- default_,
- notnull,
- domains,
- enums,
- schema,
- comment,
- generated,
- identity,
- )
- columns.append(column_info)
- return columns
+ columns = self._get_columns_info(rows, domains, enums, schema)
+
+ return columns.items()
+
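+ # Note (descriptive, not part of this change): the method above returns
+ # an items() view of {(schema, table_name): [<column dict>, ...]}, i.e.
+ # an iterable of ((schema, table_name), columns) pairs covering every
+ # table matched by the filters in a single batched query.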
+ def _get_columns_info(self, rows, domains, enums, schema):
+ array_type_pattern = re.compile(r"\[\]$")
+ attype_pattern = re.compile(r"\(.*\)")
+ charlen_pattern = re.compile(r"\(([\d,]+)\)")
+ args_pattern = re.compile(r"\((.*)\)")
+ args_split_pattern = re.compile(r"\s*,\s*")
- def _get_column_info(
- self,
- name,
- format_type,
- default,
- notnull,
- domains,
- enums,
- schema,
- comment,
- generated,
- identity,
- ):
def _handle_array_type(attype):
return (
# strip '[]' from integer[], etc.
- re.sub(r"\[\]$", "", attype),
+ array_type_pattern.sub("", attype),
attype.endswith("[]"),
)
- # strip (*) from character varying(5), timestamp(5)
- # with time zone, geometry(POLYGON), etc.
- attype = re.sub(r"\(.*\)", "", format_type)
+ columns = defaultdict(list)
+ for row_dict in rows:
+ # ensure that each table has an entry, even if it has no columns
+ if row_dict["name"] is None:
+ columns[
+ (schema, row_dict["table_name"])
+ ] = ReflectionDefaults.columns()
+ continue
+ table_cols = columns[(schema, row_dict["table_name"])]
- # strip '[]' from integer[], etc. and check if an array
- attype, is_array = _handle_array_type(attype)
+ format_type = row_dict["format_type"]
+ default = row_dict["default"]
+ name = row_dict["name"]
+ generated = row_dict["generated"]
+ identity = row_dict["identity_options"]
- # strip quotes from case sensitive enum or domain names
- enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+ # strip (*) from character varying(5), timestamp(5)
+ # with time zone, geometry(POLYGON), etc.
+ attype = attype_pattern.sub("", format_type)
- nullable = not notnull
+ # strip '[]' from integer[], etc. and check if an array
+ attype, is_array = _handle_array_type(attype)
- charlen = re.search(r"\(([\d,]+)\)", format_type)
- if charlen:
- charlen = charlen.group(1)
- args = re.search(r"\((.*)\)", format_type)
- if args and args.group(1):
- args = tuple(re.split(r"\s*,\s*", args.group(1)))
- else:
- args = ()
- kwargs = {}
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+
+ nullable = not row_dict["not_null"]
- if attype == "numeric":
+ charlen = charlen_pattern.search(format_type)
if charlen:
- prec, scale = charlen.split(",")
- args = (int(prec), int(scale))
+ charlen = charlen.group(1)
+ args = args_pattern.search(format_type)
+ if args and args.group(1):
+ args = tuple(args_split_pattern.split(args.group(1)))
else:
args = ()
- elif attype == "double precision":
- args = (53,)
- elif attype == "integer":
- args = ()
- elif attype in ("timestamp with time zone", "time with time zone"):
- kwargs["timezone"] = True
- if charlen:
- kwargs["precision"] = int(charlen)
- args = ()
- elif attype in (
- "timestamp without time zone",
- "time without time zone",
- "time",
- ):
- kwargs["timezone"] = False
- if charlen:
- kwargs["precision"] = int(charlen)
- args = ()
- elif attype == "bit varying":
- kwargs["varying"] = True
- if charlen:
+ kwargs = {}
+
+ if attype == "numeric":
+ if charlen:
+ prec, scale = charlen.split(",")
+ args = (int(prec), int(scale))
+ else:
+ args = ()
+ elif attype == "double precision":
+ args = (53,)
+ elif attype == "integer":
+ args = ()
+ elif attype in ("timestamp with time zone", "time with time zone"):
+ kwargs["timezone"] = True
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype in (
+ "timestamp without time zone",
+ "time without time zone",
+ "time",
+ ):
+ kwargs["timezone"] = False
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype == "bit varying":
+ kwargs["varying"] = True
+ if charlen:
+ args = (int(charlen),)
+ else:
+ args = ()
+ elif attype.startswith("interval"):
+ field_match = re.match(r"interval (.+)", attype, re.I)
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ if field_match:
+ kwargs["fields"] = field_match.group(1)
+ attype = "interval"
+ args = ()
+ elif charlen:
args = (int(charlen),)
+
+ while True:
+ # looping here to suit nested domains
+ if attype in self.ischema_names:
+ coltype = self.ischema_names[attype]
+ break
+ elif enum_or_domain_key in enums:
+ enum = enums[enum_or_domain_key]
+ coltype = ENUM
+ kwargs["name"] = enum["name"]
+ if not enum["visible"]:
+ kwargs["schema"] = enum["schema"]
+ args = tuple(enum["labels"])
+ break
+ elif enum_or_domain_key in domains:
+ domain = domains[enum_or_domain_key]
+ attype = domain["attype"]
+ attype, is_array = _handle_array_type(attype)
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(
+ util.quoted_token_parser(attype)
+ )
+ # A table can't override a not null on the domain,
+ # but can override nullable
+ nullable = nullable and domain["nullable"]
+ if domain["default"] and not default:
+ # It can, however, override the default
+ # value, but can't set it to null.
+ default = domain["default"]
+ continue
+ else:
+ coltype = None
+ break
+
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = self.ischema_names["_array"](coltype)
else:
- args = ()
- elif attype.startswith("interval"):
- field_match = re.match(r"interval (.+)", attype, re.I)
- if charlen:
- kwargs["precision"] = int(charlen)
- if field_match:
- kwargs["fields"] = field_match.group(1)
- attype = "interval"
- args = ()
- elif charlen:
- args = (int(charlen),)
-
- while True:
- # looping here to suit nested domains
- if attype in self.ischema_names:
- coltype = self.ischema_names[attype]
- break
- elif enum_or_domain_key in enums:
- enum = enums[enum_or_domain_key]
- coltype = ENUM
- kwargs["name"] = enum["name"]
- if not enum["visible"]:
- kwargs["schema"] = enum["schema"]
- args = tuple(enum["labels"])
- break
- elif enum_or_domain_key in domains:
- domain = domains[enum_or_domain_key]
- attype = domain["attype"]
- attype, is_array = _handle_array_type(attype)
- # strip quotes from case sensitive enum or domain names
- enum_or_domain_key = tuple(util.quoted_token_parser(attype))
- # A table can't override a not null on the domain,
- # but can override nullable
- nullable = nullable and domain["nullable"]
- if domain["default"] and not default:
- # It can, however, override the default
- # value, but can't set it to null.
- default = domain["default"]
- continue
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (attype, name)
+ )
+ coltype = sqltypes.NULLTYPE
+
+ # A zero byte or blank string (which one depends on the driver; the
+ # value is absent entirely for older PG versions) means this is not a
+ # generated column. Otherwise, "s" means stored. (Other values might
+ # be added in the future.)
+ if generated not in (None, "", b"\x00"):
+ computed = dict(
+ sqltext=default, persisted=generated in ("s", b"s")
+ )
+ default = None
else:
- coltype = None
- break
+ computed = None
- if coltype:
- coltype = coltype(*args, **kwargs)
- if is_array:
- coltype = self.ischema_names["_array"](coltype)
- else:
- util.warn(
- "Did not recognize type '%s' of column '%s'" % (attype, name)
+ # adjust the default value
+ autoincrement = False
+ if default is not None:
+ match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
+ if match is not None:
+ if issubclass(coltype._type_affinity, sqltypes.Integer):
+ autoincrement = True
+ # the default is related to a Sequence
+ if "." not in match.group(2) and schema is not None:
+ # unconditionally quote the schema name. this could
+ # later be enhanced to obey quoting rules /
+ # "quote schema"
+ default = (
+ match.group(1)
+ + ('"%s"' % schema)
+ + "."
+ + match.group(2)
+ + match.group(3)
+ )
+
+ column_info = {
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": autoincrement or identity is not None,
+ "comment": row_dict["comment"],
+ }
+ if computed is not None:
+ column_info["computed"] = computed
+ if identity is not None:
+ column_info["identity"] = identity
+
+ table_cols.append(column_info)
+
+ return columns
+
+ @lru_cache()
+ def _table_oids_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
+ oid_q = select(
+ pg_catalog.pg_class.c.oid, pg_catalog.pg_class.c.relname
+ ).where(self._pg_class_relkind_condition(relkinds))
+ oid_q = self._pg_class_filter_scope_schema(oid_q, schema, scope=scope)
+
+ if has_filter_names:
+ oid_q = oid_q.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
)
- coltype = sqltypes.NULLTYPE
-
- # If a zero byte or blank string depending on driver (is also absent
- # for older PG versions), then not a generated column. Otherwise, s =
- # stored. (Other values might be added in the future.)
- if generated not in (None, "", b"\x00"):
- computed = dict(
- sqltext=default, persisted=generated in ("s", b"s")
+ return oid_q
+
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("filter_names", InternalTraversal.dp_string_list),
+ ("kind", InternalTraversal.dp_plain_obj),
+ ("scope", InternalTraversal.dp_plain_obj),
+ )
+ def _get_table_oids(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ oid_q = self._table_oids_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(oid_q, params)
+ return result.all()
+
+ @util.memoized_property
+ def _constraint_query(self):
+ con_sq = (
+ select(
+ pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.conname,
+ sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label(
+ "attnum"
+ ),
+ sql.func.generate_subscripts(
+ pg_catalog.pg_constraint.c.conkey, 1
+ ).label("ord"),
)
- default = None
- else:
- computed = None
-
- # adjust the default value
- autoincrement = False
- if default is not None:
- match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
- if match is not None:
- if issubclass(coltype._type_affinity, sqltypes.Integer):
- autoincrement = True
- # the default is related to a Sequence
- sch = schema
- if "." not in match.group(2) and sch is not None:
- # unconditionally quote the schema name. this could
- # later be enhanced to obey quoting rules /
- # "quote schema"
- default = (
- match.group(1)
- + ('"%s"' % sch)
- + "."
- + match.group(2)
- + match.group(3)
- )
+ .where(
+ pg_catalog.pg_constraint.c.contype == bindparam("contype"),
+ pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")),
+ )
+ .subquery("con")
+ )
- column_info = dict(
- name=name,
- type=coltype,
- nullable=nullable,
- default=default,
- autoincrement=autoincrement or identity is not None,
- comment=comment,
+ attr_sq = (
+ select(
+ con_sq.c.conrelid,
+ con_sq.c.conname,
+ pg_catalog.pg_attribute.c.attname,
+ )
+ .select_from(pg_catalog.pg_attribute)
+ .join(
+ con_sq,
+ sql.and_(
+ pg_catalog.pg_attribute.c.attnum == con_sq.c.attnum,
+ pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid,
+ ),
+ )
+ .order_by(con_sq.c.conname, con_sq.c.ord)
+ .subquery("attr")
)
- if computed is not None:
- column_info["computed"] = computed
- if identity is not None:
- column_info["identity"] = identity
- return column_info
- @reflection.cache
- def get_pk_constraint(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ return (
+ select(
+ attr_sq.c.conrelid,
+ sql.func.array_agg(attr_sq.c.attname).label("cols"),
+ attr_sq.c.conname,
+ )
+ .group_by(attr_sq.c.conrelid, attr_sq.c.conname)
+ .order_by(attr_sq.c.conrelid, attr_sq.c.conname)
)
- if self.server_version_info < (8, 4):
- PK_SQL = """
- SELECT a.attname
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_attribute a
- on t.oid=a.attrelid AND %s
- WHERE
- t.oid = :table_oid and ix.indisprimary = 't'
- ORDER BY a.attnum
- """ % self._pg_index_any(
- "a.attnum", "ix.indkey"
+ def _reflect_constraint(
+ self, connection, contype, schema, filter_names, scope, kind, **kw
+ ):
+ table_oids = self._get_table_oids(
+ connection, schema, filter_names, scope, kind, **kw
+ )
+ batches = list(table_oids)
+
+ while batches:
+ batch = batches[0:3000]
+ batches[0:3000] = []
+
+ result = connection.execute(
+ self._constraint_query,
+ {"oids": [r[0] for r in batch], "contype": contype},
)
- else:
- # unnest() and generate_subscripts() both introduced in
- # version 8.4
- PK_SQL = """
- SELECT a.attname
- FROM pg_attribute a JOIN (
- SELECT unnest(ix.indkey) attnum,
- generate_subscripts(ix.indkey, 1) ord
- FROM pg_index ix
- WHERE ix.indrelid = :table_oid AND ix.indisprimary
- ) k ON a.attnum=k.attnum
- WHERE a.attrelid = :table_oid
- ORDER BY k.ord
- """
- t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
- cols = [r[0] for r in c.fetchall()]
-
- PK_CONS_SQL = """
- SELECT conname
- FROM pg_catalog.pg_constraint r
- WHERE r.conrelid = :table_oid AND r.contype = 'p'
- ORDER BY 1
- """
- t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
- name = c.scalar()
+ result_by_oid = defaultdict(list)
+ for oid, cols, constraint_name in result:
+ result_by_oid[oid].append((cols, constraint_name))
+
+ for oid, tablename in batch:
+ for_oid = result_by_oid.get(oid, ())
+ if for_oid:
+ for cols, constraint in for_oid:
+ yield tablename, cols, constraint
+ else:
+ yield tablename, None, None
+
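+ # Illustrative sketch (not part of this change) of the chunking idiom
+ # used above: consume the list of table oids in fixed-size batches
+ # (3000 here) so each IN clause stays bounded:
+ #
+ #     def batched(seq, size=3000):
+ #         for i in range(0, len(seq), size):
+ #             yield seq[i : i + size]
+ #
+ #     # for batch in batched(table_oids): ...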
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ data = self.get_multi_pk_constraint(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
+ )
+ return self._value_or_raise(data, table_name, schema)
+
+ def get_multi_pk_constraint(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ result = self._reflect_constraint(
+ connection, "p", schema, filter_names, scope, kind, **kw
+ )
- return {"constrained_columns": cols, "name": name}
+ # only a single pk can be present for each table. Return an entry
+ # even if a table has no primary key
+ default = ReflectionDefaults.pk_constraint
+ return (
+ (
+ (schema, table_name),
+ {
+ "constrained_columns": [] if cols is None else cols,
+ "name": pk_name,
+ }
+ if pk_name is not None
+ else default(),
+ )
+ for (table_name, cols, pk_name) in result
+ )
@reflection.cache
def get_foreign_keys(
postgresql_ignore_search_path=False,
**kw,
):
- preparer = self.identifier_preparer
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_foreign_keys(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ postgresql_ignore_search_path=postgresql_ignore_search_path,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- FK_SQL = """
- SELECT r.conname,
- pg_catalog.pg_get_constraintdef(r.oid, true) as condef,
- n.nspname as conschema
- FROM pg_catalog.pg_constraint r,
- pg_namespace n,
- pg_class c
-
- WHERE r.conrelid = :table AND
- r.contype = 'f' AND
- c.oid = confrelid AND
- n.oid = c.relnamespace
- ORDER BY 1
- """
- # https://www.postgresql.org/docs/9.0/static/sql-createtable.html
- FK_REGEX = re.compile(
+ @lru_cache()
+ def _foreing_key_query(self, schema, has_filter_names, scope, kind):
+ pg_class_ref = pg_catalog.pg_class.alias("cls_ref")
+ pg_namespace_ref = pg_catalog.pg_namespace.alias("nsp_ref")
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ sql.case(
+ (
+ pg_catalog.pg_constraint.c.oid.is_not(None),
+ pg_catalog.pg_get_constraintdef(
+ pg_catalog.pg_constraint.c.oid, True
+ ),
+ ),
+ else_=None,
+ ),
+ pg_namespace_ref.c.nspname,
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.contype == "f",
+ ),
+ )
+ .outerjoin(
+ pg_class_ref,
+ pg_class_ref.c.oid == pg_catalog.pg_constraint.c.confrelid,
+ )
+ .outerjoin(
+ pg_namespace_ref,
+ pg_class_ref.c.relnamespace == pg_namespace_ref.c.oid,
+ )
+ .order_by(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
+
+ @util.memoized_property
+ def _fk_regex_pattern(self):
+ # https://www.postgresql.org/docs/14.0/static/sql-createtable.html
+ return re.compile(
r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
r"[\s]?(ON UPDATE "
r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?"
)
- t = sql.text(FK_SQL).columns(
- conname=sqltypes.Unicode, condef=sqltypes.Unicode
- )
- c = connection.execute(t, dict(table=table_oid))
- fkeys = []
- for conname, condef, conschema in c.fetchall():
+ def get_multi_foreign_keys(
+ self,
+ connection,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ postgresql_ignore_search_path=False,
+ **kw,
+ ):
+ preparer = self.identifier_preparer
+
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._foreing_key_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(query, params)
+
+ FK_REGEX = self._fk_regex_pattern
+
+ fkeys = defaultdict(list)
+ default = ReflectionDefaults.foreign_keys
+ for table_name, conname, condef, conschema in result:
+ # ensure that each table has an entry, even if it has
+ # no foreign keys
+ if conname is None:
+ fkeys[(schema, table_name)] = default()
+ continue
+ table_fks = fkeys[(schema, table_name)]
m = re.search(FK_REGEX, condef).groups()
(
"referred_columns": referred_columns,
"options": options,
}
- fkeys.append(fkey_d)
- return fkeys
-
- def _pg_index_any(self, col, compare_to):
- if self.server_version_info < (8, 1):
- # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us
- # "In CVS tip you could replace this with "attnum = ANY (indkey)".
- # Unfortunately, most array support doesn't work on int2vector in
- # pre-8.1 releases, so I think you're kinda stuck with the above
- # for now.
- # regards, tom lane"
- return "(%s)" % " OR ".join(
- "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10)
- )
- else:
- return "%s = ANY(%s)" % (col, compare_to)
+ table_fks.append(fkey_d)
+ return fkeys.items()
@reflection.cache
- def get_indexes(self, connection, table_name, schema, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ data = self.get_multi_indexes(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- # cast indkey as varchar since it's an int2vector,
- # returned as a list by some drivers such as pypostgresql
-
- if self.server_version_info < (8, 5):
- IDX_SQL = """
- SELECT
- i.relname as relname,
- ix.indisunique, ix.indexprs, ix.indpred,
- a.attname, a.attnum, NULL, ix.indkey%s,
- %s, %s, am.amname,
- NULL as indnkeyatts
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_class i on i.oid = ix.indexrelid
- left outer join
- pg_attribute a
- on t.oid = a.attrelid and %s
- left outer join
- pg_am am
- on i.relam = am.oid
- WHERE
- t.relkind IN ('r', 'v', 'f', 'm')
- and t.oid = :table_oid
- and ix.indisprimary = 'f'
- ORDER BY
- t.relname,
- i.relname
- """ % (
- # version 8.3 here was based on observing the
- # cast does not work in PG 8.2.4, does work in 8.3.0.
- # nothing in PG changelogs regarding this.
- "::varchar" if self.server_version_info >= (8, 3) else "",
- "ix.indoption::varchar"
- if self.server_version_info >= (8, 3)
- else "NULL",
- "i.reloptions"
- if self.server_version_info >= (8, 2)
- else "NULL",
- self._pg_index_any("a.attnum", "ix.indkey"),
+ @util.memoized_property
+ def _index_query(self):
+ pg_class_index = pg_catalog.pg_class.alias("cls_idx")
+ # NOTE: repeating the oids clause in each subquery improves
+ # query performance
+
+ # subquery to get the columns
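+ # unnest() expands indkey into one row per indexed column, while
+ # generate_subscripts() records each column's position so the
+ # aggregated column names can be kept in index-definition order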
+ idx_sq = (
+ select(
+ pg_catalog.pg_index.c.indexrelid,
+ pg_catalog.pg_index.c.indrelid,
+ sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"),
+ sql.func.generate_subscripts(
+ pg_catalog.pg_index.c.indkey, 1
+ ).label("ord"),
)
- else:
- IDX_SQL = """
- SELECT
- i.relname as relname,
- ix.indisunique, ix.indexprs,
- a.attname, a.attnum, c.conrelid, ix.indkey::varchar,
- ix.indoption::varchar, i.reloptions, am.amname,
- pg_get_expr(ix.indpred, ix.indrelid),
- %s as indnkeyatts
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_class i on i.oid = ix.indexrelid
- left outer join
- pg_attribute a
- on t.oid = a.attrelid and a.attnum = ANY(ix.indkey)
- left outer join
- pg_constraint c
- on (ix.indrelid = c.conrelid and
- ix.indexrelid = c.conindid and
- c.contype in ('p', 'u', 'x'))
- left outer join
- pg_am am
- on i.relam = am.oid
- WHERE
- t.relkind IN ('r', 'v', 'f', 'm', 'p')
- and t.oid = :table_oid
- and ix.indisprimary = 'f'
- ORDER BY
- t.relname,
- i.relname
- """ % (
- "ix.indnkeyatts"
- if self.server_version_info >= (11, 0)
- else "NULL",
+ .where(
+ ~pg_catalog.pg_index.c.indisprimary,
+ pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")),
)
+ .subquery("idx")
+ )
- t = sql.text(IDX_SQL).columns(
- relname=sqltypes.Unicode, attname=sqltypes.Unicode
+ attr_sq = (
+ select(
+ idx_sq.c.indexrelid,
+ idx_sq.c.indrelid,
+ pg_catalog.pg_attribute.c.attname,
+ )
+ .select_from(pg_catalog.pg_attribute)
+ .join(
+ idx_sq,
+ sql.and_(
+ pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum,
+ pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid,
+ ),
+ )
+ .where(idx_sq.c.indrelid.in_(bindparam("oids")))
+ .order_by(idx_sq.c.indexrelid, idx_sq.c.ord)
+ .subquery("idx_attr")
)
- c = connection.execute(t, dict(table_oid=table_oid))
- indexes = defaultdict(lambda: defaultdict(dict))
+ cols_sq = (
+ select(
+ attr_sq.c.indexrelid,
+ attr_sq.c.indrelid,
+ sql.func.array_agg(attr_sq.c.attname).label("cols"),
+ )
+ .group_by(attr_sq.c.indexrelid, attr_sq.c.indrelid)
+ .subquery("idx_cols")
+ )
- sv_idx_name = None
- for row in c.fetchall():
- (
- idx_name,
- unique,
- expr,
- col,
- col_num,
- conrelid,
- idx_key,
- idx_option,
- options,
- amname,
- filter_definition,
- indnkeyatts,
- ) = row
+ if self.server_version_info >= (11, 0):
+ indnkeyatts = pg_catalog.pg_index.c.indnkeyatts
+ else:
+ indnkeyatts = sql.null().label("indnkeyatts")
- if expr:
- if idx_name != sv_idx_name:
- util.warn(
- "Skipped unsupported reflection of "
- "expression-based index %s" % idx_name
- )
- sv_idx_name = idx_name
- continue
+ query = (
+ select(
+ pg_catalog.pg_index.c.indrelid,
+ pg_class_index.c.relname.label("relname_index"),
+ pg_catalog.pg_index.c.indisunique,
+ pg_catalog.pg_index.c.indexprs,
+ pg_catalog.pg_constraint.c.conrelid.is_not(None).label(
+ "has_constraint"
+ ),
+ pg_catalog.pg_index.c.indoption,
+ pg_class_index.c.reloptions,
+ pg_catalog.pg_am.c.amname,
+ pg_catalog.pg_get_expr(
+ pg_catalog.pg_index.c.indpred,
+ pg_catalog.pg_index.c.indrelid,
+ ).label("filter_definition"),
+ indnkeyatts,
+ cols_sq.c.cols.label("index_cols"),
+ )
+ .select_from(pg_catalog.pg_index)
+ .where(
+ pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")),
+ ~pg_catalog.pg_index.c.indisprimary,
+ )
+ .join(
+ pg_class_index,
+ pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid,
+ )
+ .join(
+ pg_catalog.pg_am,
+ pg_class_index.c.relam == pg_catalog.pg_am.c.oid,
+ )
+ .outerjoin(
+ cols_sq,
+ pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid,
+ )
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_index.c.indrelid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_index.c.indexrelid
+ == pg_catalog.pg_constraint.c.conindid,
+ pg_catalog.pg_constraint.c.contype
+ == sql.any_(_array.array(("p", "u", "x"))),
+ ),
+ )
+ .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname)
+ )
+ return query
- has_idx = idx_name in indexes
- index = indexes[idx_name]
- if col is not None:
- index["cols"][col_num] = col
- if not has_idx:
- idx_keys = idx_key.split()
- # "The number of key columns in the index, not counting any
- # included columns, which are merely stored and do not
- # participate in the index semantics"
- if indnkeyatts and idx_keys[indnkeyatts:]:
- # this is a "covering index" which has INCLUDE columns
- # as well as regular index columns
- inc_keys = idx_keys[indnkeyatts:]
- idx_keys = idx_keys[:indnkeyatts]
- else:
- inc_keys = []
+ def get_multi_indexes(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
- index["key"] = [int(k.strip()) for k in idx_keys]
- index["inc"] = [int(k.strip()) for k in inc_keys]
+ table_oids = self._get_table_oids(
+ connection, schema, filter_names, scope, kind, **kw
+ )
- # (new in pg 8.3)
- # "pg_index.indoption" is list of ints, one per column/expr.
- # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST
- sorting = {}
- for col_idx, col_flags in enumerate(
- (idx_option or "").split()
- ):
- col_flags = int(col_flags.strip())
- col_sorting = ()
- # try to set flags only if they differ from PG defaults...
- if col_flags & 0x01:
- col_sorting += ("desc",)
- if not (col_flags & 0x02):
- col_sorting += ("nulls_last",)
+ indexes = defaultdict(list)
+ default = ReflectionDefaults.indexes
+
+ batches = list(table_oids)
+
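+ # run the index query in batches of 3000 table oids at a time so the
+ # expanded "oids" bind parameter list stays bounded on large schemas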
+ while batches:
+ batch = batches[0:3000]
+ batches[0:3000] = []
+
+ result = connection.execute(
+ self._index_query, {"oids": [r[0] for r in batch]}
+ ).mappings()
+
+ result_by_oid = defaultdict(list)
+ for row_dict in result:
+ result_by_oid[row_dict["indrelid"]].append(row_dict)
+
+ for oid, table_name in batch:
+ if oid not in result_by_oid:
+ # ensure that each table has an entry, even if it has no
+ # indexes other than the primary key
+ indexes[(schema, table_name)] = default()
+ continue
+
+ for row in result_by_oid[oid]:
+ index_name = row["relname_index"]
+
+ table_indexes = indexes[(schema, table_name)]
+
+ if row["indexprs"]:
+ tn = (
+ table_name
+ if schema is None
+ else f"{schema}.{table_name}"
+ )
+ util.warn(
+ "Skipped unsupported reflection of "
+ f"expression-based index {index_name} of "
+ f"table {tn}"
+ )
+ continue
+
+ all_cols = row["index_cols"]
+ indnkeyatts = row["indnkeyatts"]
+ # "The number of key columns in the index, not counting any
+ # included columns, which are merely stored and do not
+ # participate in the index semantics"
+ if indnkeyatts and all_cols[indnkeyatts:]:
+ # this is a "covering index" which has INCLUDE columns
+ # as well as regular index columns
+ inc_cols = all_cols[indnkeyatts:]
+ idx_cols = all_cols[:indnkeyatts]
else:
- if col_flags & 0x02:
- col_sorting += ("nulls_first",)
- if col_sorting:
- sorting[col_idx] = col_sorting
- if sorting:
- index["sorting"] = sorting
-
- index["unique"] = unique
- if conrelid is not None:
- index["duplicates_constraint"] = idx_name
- if options:
- index["options"] = dict(
- [option.split("=") for option in options]
- )
-
- # it *might* be nice to include that this is 'btree' in the
- # reflection info. But we don't want an Index object
- # to have a ``postgresql_using`` in it that is just the
- # default, so for the moment leaving this out.
- if amname and amname != "btree":
- index["amname"] = amname
-
- if filter_definition:
- index["postgresql_where"] = filter_definition
+ idx_cols = all_cols
+ inc_cols = []
+
+ index = {
+ "name": index_name,
+ "unique": row["indisunique"],
+ "column_names": idx_cols,
+ }
+
+ sorting = {}
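+ # each indoption entry is a per-column bitmask:
+ # 0x01 = DESC, 0x02 = NULLS FIRST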
+ for col_index, col_flags in enumerate(row["indoption"]):
+ col_sorting = ()
+ # try to set flags only if they differ from PG
+ # defaults...
+ if col_flags & 0x01:
+ col_sorting += ("desc",)
+ if not (col_flags & 0x02):
+ col_sorting += ("nulls_last",)
+ else:
+ if col_flags & 0x02:
+ col_sorting += ("nulls_first",)
+ if col_sorting:
+ sorting[idx_cols[col_index]] = col_sorting
+ if sorting:
+ index["column_sorting"] = sorting
+ if row["has_constraint"]:
+ index["duplicates_constraint"] = index_name
+
+ dialect_options = {}
+ if row["reloptions"]:
+ dialect_options["postgresql_with"] = dict(
+ [option.split("=") for option in row["reloptions"]]
+ )
+ # it *might* be nice to include that this is 'btree' in the
+ # reflection info. But we don't want an Index object
+ # to have a ``postgresql_using`` in it that is just the
+ # default, so for the moment leaving this out.
+ amname = row["amname"]
+ if amname != "btree":
+ dialect_options["postgresql_using"] = row["amname"]
+ if row["filter_definition"]:
+ dialect_options["postgresql_where"] = row[
+ "filter_definition"
+ ]
+ if self.server_version_info >= (11, 0):
+ # NOTE: this is legacy, this is part of
+ # dialect_options now as of #7382
+ index["include_columns"] = inc_cols
+ dialect_options["postgresql_include"] = inc_cols
+ if dialect_options:
+ index["dialect_options"] = dialect_options
- result = []
- for name, idx in indexes.items():
- entry = {
- "name": name,
- "unique": idx["unique"],
- "column_names": [idx["cols"][i] for i in idx["key"]],
- }
- if self.server_version_info >= (11, 0):
- # NOTE: this is legacy, this is part of dialect_options now
- # as of #7382
- entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]]
- if "duplicates_constraint" in idx:
- entry["duplicates_constraint"] = idx["duplicates_constraint"]
- if "sorting" in idx:
- entry["column_sorting"] = dict(
- (idx["cols"][idx["key"][i]], value)
- for i, value in idx["sorting"].items()
- )
- if "include_columns" in entry:
- entry.setdefault("dialect_options", {})[
- "postgresql_include"
- ] = entry["include_columns"]
- if "options" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_with"
- ] = idx["options"]
- if "amname" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_using"
- ] = idx["amname"]
- if "postgresql_where" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_where"
- ] = idx["postgresql_where"]
- result.append(entry)
- return result
+ table_indexes.append(index)
+ return indexes.items()
@reflection.cache
def get_unique_constraints(
self, connection, table_name, schema=None, **kw
):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_unique_constraints(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- UNIQUE_SQL = """
- SELECT
- cons.conname as name,
- cons.conkey as key,
- a.attnum as col_num,
- a.attname as col_name
- FROM
- pg_catalog.pg_constraint cons
- join pg_attribute a
- on cons.conrelid = a.attrelid AND
- a.attnum = ANY(cons.conkey)
- WHERE
- cons.conrelid = :table_oid AND
- cons.contype = 'u'
- """
-
- t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
+ def get_multi_unique_constraints(
+ self,
+ connection,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ **kw,
+ ):
+ result = self._reflect_constraint(
+ connection, "u", schema, filter_names, scope, kind, **kw
+ )
- uniques = defaultdict(lambda: defaultdict(dict))
- for row in c.fetchall():
- uc = uniques[row.name]
- uc["key"] = row.key
- uc["cols"][row.col_num] = row.col_name
+ # each table can have multiple unique constraints
+ uniques = defaultdict(list)
+ default = ReflectionDefaults.unique_constraints
+ for (table_name, cols, con_name) in result:
+ # ensure a list is created for each table; leave it empty if
+ # the table has no unique constraints
+ if con_name is None:
+ uniques[(schema, table_name)] = default()
+ continue
- return [
- {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]}
- for name, uc in uniques.items()
- ]
+ uniques[(schema, table_name)].append(
+ {
+ "column_names": cols,
+ "name": con_name,
+ }
+ )
+ return uniques.items()
@reflection.cache
def get_table_comment(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_table_comment(
+ connection,
+ schema,
+ [table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- COMMENT_SQL = """
- SELECT
- pgd.description as table_comment
- FROM
- pg_catalog.pg_description pgd
- WHERE
- pgd.objsubid = 0 AND
- pgd.objoid = :table_oid
- """
+ @lru_cache()
+ def _comment_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_description.c.description,
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_description,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_description.c.objoid,
+ pg_catalog.pg_description.c.objsubid == 0,
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
- c = connection.execute(
- sql.text(COMMENT_SQL), dict(table_oid=table_oid)
+ def get_multi_table_comment(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._comment_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(query, params)
+
+ default = ReflectionDefaults.table_comment
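+ # yield ((schema, table), comment) pairs, substituting the default
+ # reflection value for tables that have no comment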
+ return (
+ (
+ (schema, table),
+ {"text": comment} if comment is not None else default(),
+ )
+ for table, comment in result
)
- return {"text": c.scalar()}
@reflection.cache
def get_check_constraints(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_check_constraints(
+ connection,
+ schema,
+ [table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- CHECK_SQL = """
- SELECT
- cons.conname as name,
- pg_get_constraintdef(cons.oid) as src
- FROM
- pg_catalog.pg_constraint cons
- WHERE
- cons.conrelid = :table_oid AND
- cons.contype = 'c'
- """
-
- c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid))
+ @lru_cache()
+ def _check_constraint_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ sql.case(
+ (
+ pg_catalog.pg_constraint.c.oid.is_not(None),
+ pg_catalog.pg_get_constraintdef(
+ pg_catalog.pg_constraint.c.oid
+ ),
+ ),
+ else_=None,
+ ),
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.contype == "c",
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
- ret = []
- for name, src in c:
+ def get_multi_check_constraints(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._check_constraint_query(
+ schema, has_filter_names, scope, kind
+ )
+ result = connection.execute(query, params)
+
+ check_constraints = defaultdict(list)
+ default = ReflectionDefaults.check_constraints
+ for table_name, check_name, src in result:
+ # only two cases for check_name and src: both null or both defined
+ if check_name is None and src is None:
+ check_constraints[(schema, table_name)] = default()
+ continue
# samples:
# "CHECK (((a > 1) AND (a < 5)))"
# "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))"
sqltext = re.compile(
r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL
).sub(r"\1", m.group(1))
- entry = {"name": name, "sqltext": sqltext}
+ entry = {"name": check_name, "sqltext": sqltext}
if m and m.group(2):
entry["dialect_options"] = {"not_valid": True}
- ret.append(entry)
- return ret
-
- def _load_enums(self, connection, schema=None):
- schema = schema or self.default_schema_name
- if not self.supports_native_enum:
- return {}
-
- # Load data types for enums:
- SQL_ENUMS = """
- SELECT t.typname as "name",
- -- no enum defaults in 8.4 at least
- -- t.typdefault as "default",
- pg_catalog.pg_type_is_visible(t.oid) as "visible",
- n.nspname as "schema",
- e.enumlabel as "label"
- FROM pg_catalog.pg_type t
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
- LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid
- WHERE t.typtype = 'e'
- """
+ check_constraints[(schema, table_name)].append(entry)
+ return check_constraints.items()
- if schema != "*":
- SQL_ENUMS += "AND n.nspname = :schema "
+ @lru_cache()
+ def _enum_query(self, schema):
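+ # collect the labels of each enum ordered by enumsortorder, so the
+ # reflected "labels" list matches the declared order of the type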
+ lbl_sq = (
+ select(
+ pg_catalog.pg_enum.c.enumtypid, pg_catalog.pg_enum.c.enumlabel
+ )
+ .order_by(
+ pg_catalog.pg_enum.c.enumtypid,
+ pg_catalog.pg_enum.c.enumsortorder,
+ )
+ .subquery("lbl")
+ )
- # e.oid gives us label order within an enum
- SQL_ENUMS += 'ORDER BY "schema", "name", e.oid'
+ lbl_agg_sq = (
+ select(
+ lbl_sq.c.enumtypid,
+ sql.func.array_agg(lbl_sq.c.enumlabel).label("labels"),
+ )
+ .group_by(lbl_sq.c.enumtypid)
+ .subquery("lbl_agg")
+ )
- s = sql.text(SQL_ENUMS).columns(
- attname=sqltypes.Unicode, label=sqltypes.Unicode
+ query = (
+ select(
+ pg_catalog.pg_type.c.typname.label("name"),
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label(
+ "visible"
+ ),
+ pg_catalog.pg_namespace.c.nspname.label("schema"),
+ lbl_agg_sq.c.labels.label("labels"),
+ )
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
+ )
+ .outerjoin(
+ lbl_agg_sq, pg_catalog.pg_type.c.oid == lbl_agg_sq.c.enumtypid
+ )
+ .where(pg_catalog.pg_type.c.typtype == "e")
+ .order_by(
+ pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname
+ )
)
- if schema != "*":
- s = s.bindparams(schema=schema)
+ if schema is None:
+ query = query.where(
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+ )
+ elif schema != "*":
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+ return query
+
+ @reflection.cache
+ def _load_enums(self, connection, schema=None, **kw):
+ if not self.supports_native_enum:
+ return []
- c = connection.execute(s)
+ result = connection.execute(self._enum_query(schema))
enums = []
- enum_by_name = {}
- for enum in c.fetchall():
- key = (enum.schema, enum.name)
- if key in enum_by_name:
- enum_by_name[key]["labels"].append(enum.label)
- else:
- enum_by_name[key] = enum_rec = {
- "name": enum.name,
- "schema": enum.schema,
- "visible": enum.visible,
- "labels": [],
+ for name, visible, schema, labels in result:
+ enums.append(
+ {
+ "name": name,
+ "schema": schema,
+ "visible": visible,
+ "labels": [] if labels is None else labels,
}
- if enum.label is not None:
- enum_rec["labels"].append(enum.label)
- enums.append(enum_rec)
+ )
return enums
- def _load_domains(self, connection):
- # Load data types for domains:
- SQL_DOMAINS = """
- SELECT t.typname as "name",
- pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
- not t.typnotnull as "nullable",
- t.typdefault as "default",
- pg_catalog.pg_type_is_visible(t.oid) as "visible",
- n.nspname as "schema"
- FROM pg_catalog.pg_type t
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
- WHERE t.typtype = 'd'
- """
+ @util.memoized_property
+ def _domain_query(self):
+ return (
+ select(
+ pg_catalog.pg_type.c.typname.label("name"),
+ pg_catalog.format_type(
+ pg_catalog.pg_type.c.typbasetype,
+ pg_catalog.pg_type.c.typtypmod,
+ ).label("attype"),
+ (~pg_catalog.pg_type.c.typnotnull).label("nullable"),
+ pg_catalog.pg_type.c.typdefault.label("default"),
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label(
+ "visible"
+ ),
+ pg_catalog.pg_namespace.c.nspname.label("schema"),
+ )
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
+ )
+ .where(pg_catalog.pg_type.c.typtype == "d")
+ )
- s = sql.text(SQL_DOMAINS)
- c = connection.execution_options(future_result=True).execute(s)
+ @reflection.cache
+ def _load_domains(self, connection, **kw):
+ # Load data types for domains:
+ result = connection.execute(self._domain_query)
domains = {}
- for domain in c.mappings():
+ for domain in result.mappings():
domain = domain
# strip (30) from character varying(30)
attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
from .json import JSON
from .json import JSONB
from .json import JSONPathType
+from .pg_catalog import _SpaceVector
+from .pg_catalog import OIDVECTOR
from ... import exc
from ... import util
from ...engine import processors
render_bind_cast = True
+class _PGOIDVECTOR(_SpaceVector, OIDVECTOR):
+ pass
+
+
_server_side_id = util.counter()
sqltypes.BigInteger: _PGBigInteger,
sqltypes.Enum: _PGEnum,
sqltypes.ARRAY: _PGARRAY,
+ OIDVECTOR: _PGOIDVECTOR,
},
)
--- /dev/null
+# postgresql/pg_catalog.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from .array import ARRAY
+from .types import OID
+from .types import REGCLASS
+from ... import Column
+from ... import func
+from ... import MetaData
+from ... import Table
+from ...types import BigInteger
+from ...types import Boolean
+from ...types import CHAR
+from ...types import Float
+from ...types import Integer
+from ...types import SmallInteger
+from ...types import String
+from ...types import Text
+from ...types import TypeDecorator
+
+
+# types
+class NAME(TypeDecorator):
+ impl = String(64, collation="C")
+ cache_ok = True
+
+
+class PG_NODE_TREE(TypeDecorator):
+ impl = Text(collation="C")
+ cache_ok = True
+
+
+class INT2VECTOR(TypeDecorator):
+ impl = ARRAY(SmallInteger)
+ cache_ok = True
+
+
+class OIDVECTOR(TypeDecorator):
+ impl = ARRAY(OID)
+ cache_ok = True
+
+
+class _SpaceVector:
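+ # int2vector / oidvector values may be returned by the driver as a
+ # space-separated string such as "1 2 3"; convert to a list of ints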
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if value is None:
+ return value
+ return [int(p) for p in value.split(" ")]
+
+ return process
+
+
+REGPROC = REGCLASS  # treated here as an alias of REGCLASS
+
+# functions
+_pg_cat = func.pg_catalog
+quote_ident = _pg_cat.quote_ident
+pg_table_is_visible = _pg_cat.pg_table_is_visible
+pg_type_is_visible = _pg_cat.pg_type_is_visible
+pg_get_viewdef = _pg_cat.pg_get_viewdef
+pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence
+format_type = _pg_cat.format_type
+pg_get_expr = _pg_cat.pg_get_expr
+pg_get_constraintdef = _pg_cat.pg_get_constraintdef
+
+# constants
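+# pg_class.relkind codes: r = ordinary table, p = partitioned table,
+# f = foreign table, v = view, m = materialized view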
+RELKINDS_TABLE_NO_FOREIGN = ("r", "p")
+RELKINDS_TABLE = RELKINDS_TABLE_NO_FOREIGN + ("f",)
+RELKINDS_VIEW = ("v",)
+RELKINDS_MAT_VIEW = ("m",)
+RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW
+
+# tables
+pg_catalog_meta = MetaData()
+
+pg_namespace = Table(
+ "pg_namespace",
+ pg_catalog_meta,
+ Column("oid", OID),
+ Column("nspname", NAME),
+ Column("nspowner", OID),
+ schema="pg_catalog",
+)
+
+pg_class = Table(
+ "pg_class",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("relname", NAME),
+ Column("relnamespace", OID),
+ Column("reltype", OID),
+ Column("reloftype", OID),
+ Column("relowner", OID),
+ Column("relam", OID),
+ Column("relfilenode", OID),
+ Column("reltablespace", OID),
+ Column("relpages", Integer),
+ Column("reltuples", Float),
+ Column("relallvisible", Integer, info={"server_version": (9, 2)}),
+ Column("reltoastrelid", OID),
+ Column("relhasindex", Boolean),
+ Column("relisshared", Boolean),
+ Column("relpersistence", CHAR, info={"server_version": (9, 1)}),
+ Column("relkind", CHAR),
+ Column("relnatts", SmallInteger),
+ Column("relchecks", SmallInteger),
+ Column("relhasrules", Boolean),
+ Column("relhastriggers", Boolean),
+ Column("relhassubclass", Boolean),
+ Column("relrowsecurity", Boolean),
+ Column("relforcerowsecurity", Boolean, info={"server_version": (9, 5)}),
+ Column("relispopulated", Boolean, info={"server_version": (9, 3)}),
+ Column("relreplident", CHAR, info={"server_version": (9, 4)}),
+ Column("relispartition", Boolean, info={"server_version": (10,)}),
+ Column("relrewrite", OID, info={"server_version": (11,)}),
+ Column("reloptions", ARRAY(Text)),
+ schema="pg_catalog",
+)
+
+pg_type = Table(
+ "pg_type",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("typname", NAME),
+ Column("typnamespace", OID),
+ Column("typowner", OID),
+ Column("typlen", SmallInteger),
+ Column("typbyval", Boolean),
+ Column("typtype", CHAR),
+ Column("typcategory", CHAR),
+ Column("typispreferred", Boolean),
+ Column("typisdefined", Boolean),
+ Column("typdelim", CHAR),
+ Column("typrelid", OID),
+ Column("typelem", OID),
+ Column("typarray", OID),
+ Column("typinput", REGPROC),
+ Column("typoutput", REGPROC),
+ Column("typreceive", REGPROC),
+ Column("typsend", REGPROC),
+ Column("typmodin", REGPROC),
+ Column("typmodout", REGPROC),
+ Column("typanalyze", REGPROC),
+ Column("typalign", CHAR),
+ Column("typstorage", CHAR),
+ Column("typnotnull", Boolean),
+ Column("typbasetype", OID),
+ Column("typtypmod", Integer),
+ Column("typndims", Integer),
+ Column("typcollation", OID, info={"server_version": (9, 1)}),
+ Column("typdefault", Text),
+ schema="pg_catalog",
+)
+
+pg_index = Table(
+ "pg_index",
+ pg_catalog_meta,
+ Column("indexrelid", OID),
+ Column("indrelid", OID),
+ Column("indnatts", SmallInteger),
+ Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}),
+ Column("indisunique", Boolean),
+ Column("indisprimary", Boolean),
+ Column("indisexclusion", Boolean, info={"server_version": (9, 1)}),
+ Column("indimmediate", Boolean),
+ Column("indisclustered", Boolean),
+ Column("indisvalid", Boolean),
+ Column("indcheckxmin", Boolean),
+ Column("indisready", Boolean),
+ Column("indislive", Boolean, info={"server_version": (9, 3)}), # 9.3
+ Column("indisreplident", Boolean),
+ Column("indkey", INT2VECTOR),
+ Column("indcollation", OIDVECTOR, info={"server_version": (9, 1)}), # 9.1
+ Column("indclass", OIDVECTOR),
+ Column("indoption", INT2VECTOR),
+ Column("indexprs", PG_NODE_TREE),
+ Column("indpred", PG_NODE_TREE),
+ schema="pg_catalog",
+)
+
+pg_attribute = Table(
+ "pg_attribute",
+ pg_catalog_meta,
+ Column("attrelid", OID),
+ Column("attname", NAME),
+ Column("atttypid", OID),
+ Column("attstattarget", Integer),
+ Column("attlen", SmallInteger),
+ Column("attnum", SmallInteger),
+ Column("attndims", Integer),
+ Column("attcacheoff", Integer),
+ Column("atttypmod", Integer),
+ Column("attbyval", Boolean),
+ Column("attstorage", CHAR),
+ Column("attalign", CHAR),
+ Column("attnotnull", Boolean),
+ Column("atthasdef", Boolean),
+ Column("atthasmissing", Boolean, info={"server_version": (11,)}),
+ Column("attidentity", CHAR, info={"server_version": (10,)}),
+ Column("attgenerated", CHAR, info={"server_version": (12,)}),
+ Column("attisdropped", Boolean),
+ Column("attislocal", Boolean),
+ Column("attinhcount", Integer),
+ Column("attcollation", OID, info={"server_version": (9, 1)}),
+ schema="pg_catalog",
+)
+
+pg_constraint = Table(
+ "pg_constraint",
+ pg_catalog_meta,
+ Column("oid", OID), # 9.3
+ Column("conname", NAME),
+ Column("connamespace", OID),
+ Column("contype", CHAR),
+ Column("condeferrable", Boolean),
+ Column("condeferred", Boolean),
+ Column("convalidated", Boolean, info={"server_version": (9, 1)}),
+ Column("conrelid", OID),
+ Column("contypid", OID),
+ Column("conindid", OID),
+ Column("conparentid", OID, info={"server_version": (11,)}),
+ Column("confrelid", OID),
+ Column("confupdtype", CHAR),
+ Column("confdeltype", CHAR),
+ Column("confmatchtype", CHAR),
+ Column("conislocal", Boolean),
+ Column("coninhcount", Integer),
+ Column("connoinherit", Boolean, info={"server_version": (9, 2)}),
+ Column("conkey", ARRAY(SmallInteger)),
+ Column("confkey", ARRAY(SmallInteger)),
+ schema="pg_catalog",
+)
+
+pg_sequence = Table(
+ "pg_sequence",
+ pg_catalog_meta,
+ Column("seqrelid", OID),
+ Column("seqtypid", OID),
+ Column("seqstart", BigInteger),
+ Column("seqincrement", BigInteger),
+ Column("seqmax", BigInteger),
+ Column("seqmin", BigInteger),
+ Column("seqcache", BigInteger),
+ Column("seqcycle", Boolean),
+ schema="pg_catalog",
+ info={"server_version": (10,)},
+)
+
+pg_attrdef = Table(
+ "pg_attrdef",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("adrelid", OID),
+ Column("adnum", SmallInteger),
+ Column("adbin", PG_NODE_TREE),
+ schema="pg_catalog",
+)
+
+pg_description = Table(
+ "pg_description",
+ pg_catalog_meta,
+ Column("objoid", OID),
+ Column("classoid", OID),
+ Column("objsubid", Integer),
+ Column("description", Text(collation="C")),
+ schema="pg_catalog",
+)
+
+pg_enum = Table(
+ "pg_enum",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("enumtypid", OID),
+ Column("enumsortorder", Float(), info={"server_version": (9, 1)}),
+ Column("enumlabel", NAME),
+ schema="pg_catalog",
+)
+
+pg_am = Table(
+ "pg_am",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("amname", NAME),
+ Column("amhandler", REGPROC, info={"server_version": (9, 6)}),
+ Column("amtype", CHAR, info={"server_version": (9, 6)}),
+ schema="pg_catalog",
+)
--- /dev/null
+# Copyright (C) 2013-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+import datetime as dt
+from typing import Any
+
+from ... import schema
+from ... import util
+from ...sql import sqltypes
+from ...sql.ddl import InvokeDDLBase
+
+
+_DECIMAL_TYPES = (1231, 1700)
+_FLOAT_TYPES = (700, 701, 1021, 1022)
+_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
+
+
+class PGUuid(sqltypes.UUID):
+ render_bind_cast = True
+ render_literal_cast = True
+
+
+class BYTEA(sqltypes.LargeBinary[bytes]):
+ __visit_name__ = "BYTEA"
+
+
+class INET(sqltypes.TypeEngine[str]):
+ __visit_name__ = "INET"
+
+
+PGInet = INET
+
+
+class CIDR(sqltypes.TypeEngine[str]):
+ __visit_name__ = "CIDR"
+
+
+PGCidr = CIDR
+
+
+class MACADDR(sqltypes.TypeEngine[str]):
+ __visit_name__ = "MACADDR"
+
+
+PGMacAddr = MACADDR
+
+
+class MONEY(sqltypes.TypeEngine[str]):
+
+ r"""Provide the PostgreSQL MONEY type.
+
+ Depending on the driver, result rows using this type may return a
+ string value which includes currency symbols.
+
+ For this reason, it may be preferable to provide conversion to a
+ numerically-based currency datatype using :class:`_types.TypeDecorator`::
+
+ import re
+ import decimal
+ from typing import Any
+ from sqlalchemy import TypeDecorator
+
+ class NumericMoney(TypeDecorator):
+ impl = MONEY
+
+ def process_result_value(self, value: Any, dialect: Any) -> Any:
+ if value is not None:
+ # adjust this for the currency and numeric
+ m = re.match(r"\$([\d.]+)", value)
+ if m:
+ value = decimal.Decimal(m.group(1))
+ return value
+
+ Alternatively, the conversion may be applied as a CAST using
+ the :meth:`_types.TypeDecorator.column_expression` method as follows::
+
+ import decimal
+ from typing import Any
+ from sqlalchemy import cast
+ from sqlalchemy import Numeric
+ from sqlalchemy import TypeDecorator
+
+ class NumericMoney(TypeDecorator):
+ impl = MONEY
+
+ def column_expression(self, column: Any):
+ return cast(column, Numeric())
+
+ .. versionadded:: 1.2
+
+ """
+
+ __visit_name__ = "MONEY"
+
+
+class OID(sqltypes.TypeEngine[int]):
+
+ """Provide the PostgreSQL OID type.
+
+ .. versionadded:: 0.9.5
+
+ """
+
+ __visit_name__ = "OID"
+
+
+class REGCLASS(sqltypes.TypeEngine[str]):
+
+ """Provide the PostgreSQL REGCLASS type.
+
+ .. versionadded:: 1.2.7
+
+ """
+
+ __visit_name__ = "REGCLASS"
+
+
+class TIMESTAMP(sqltypes.TIMESTAMP):
+ def __init__(self, timezone=False, precision=None):
+ super(TIMESTAMP, self).__init__(timezone=timezone)
+ self.precision = precision
+
+
+class TIME(sqltypes.TIME):
+ def __init__(self, timezone=False, precision=None):
+ super(TIME, self).__init__(timezone=timezone)
+ self.precision = precision
+
+
+class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+
+ """PostgreSQL INTERVAL type."""
+
+ __visit_name__ = "INTERVAL"
+ native = True
+
+ def __init__(self, precision=None, fields=None):
+ """Construct an INTERVAL.
+
+ :param precision: optional integer precision value
+ :param fields: string fields specifier. allows storage of fields
+ to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
+ etc.
+
+ .. versionadded:: 1.2
+
+ """
+ self.precision = precision
+ self.fields = fields
+
+ @classmethod
+ def adapt_emulated_to_native(cls, interval, **kw):
+ return INTERVAL(precision=interval.second_precision)
+
+ @property
+ def _type_affinity(self):
+ return sqltypes.Interval
+
+ def as_generic(self, allow_nulltype=False):
+ return sqltypes.Interval(native=True, second_precision=self.precision)
+
+ @property
+ def python_type(self):
+ return dt.timedelta
+
+
+PGInterval = INTERVAL
+
+
+class BIT(sqltypes.TypeEngine[int]):
+ __visit_name__ = "BIT"
+
+ def __init__(self, length=None, varying=False):
+ if not varying:
+ # BIT without VARYING defaults to length 1
+ self.length = length or 1
+ else:
+ # but BIT VARYING can be unlimited-length, so no default
+ self.length = length
+ self.varying = varying
+
+
+PGBit = BIT
+
+
+class TSVECTOR(sqltypes.TypeEngine[Any]):
+
+ """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
+ text search type TSVECTOR.
+
+ It can be used to do full text queries on natural language
+ documents.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`postgresql_match`
+
+ """
+
+ __visit_name__ = "TSVECTOR"
+
+
+class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
+
+ """PostgreSQL ENUM type.
+
+ This is a subclass of :class:`_types.Enum` which includes
+ support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
+
+ When the builtin type :class:`_types.Enum` is used and the
+ :paramref:`.Enum.native_enum` flag is left at its default of
+ True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
+ type as the implementation, so the special create/drop rules
+ will be used.
+
+ The create/drop behavior of ENUM is necessarily intricate, due to the
+ awkward relationship the ENUM type has in relationship to the
+ parent table, in that it may be "owned" by just a single table, or
+ may be shared among many tables.
+
+ When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
+ in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
+ corresponding to when the :meth:`_schema.Table.create` and
+ :meth:`_schema.Table.drop`
+ methods are called::
+
+ table = Table('sometable', metadata,
+ Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
+ )
+
+ table.create(engine) # will emit CREATE ENUM and CREATE TABLE
+ table.drop(engine) # will emit DROP TABLE and DROP ENUM
+
+ To use a common enumerated type between multiple tables, the best
+ practice is to declare the :class:`_types.Enum` or
+ :class:`_postgresql.ENUM` independently, and associate it with the
+ :class:`_schema.MetaData` object itself::
+
+ my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
+
+ t1 = Table('sometable_one', metadata,
+ Column('some_enum', myenum)
+ )
+
+ t2 = Table('sometable_two', metadata,
+ Column('some_enum', myenum)
+ )
+
+ When this pattern is used, care must still be taken at the level
+ of individual table creates. Emitting CREATE TABLE without also
+ specifying ``checkfirst=True`` will still cause issues::
+
+ t1.create(engine) # will fail: no such type 'myenum'
+
+ If we specify ``checkfirst=True``, the individual table-level create
+ operation will check for the ``ENUM`` and create if not exists::
+
+ # will check if enum exists, and emit CREATE TYPE if not
+ t1.create(engine, checkfirst=True)
+
+ When using a metadata-level ENUM type, the type will always be created
+ and dropped if either the metadata-wide create/drop is called::
+
+ metadata.create_all(engine) # will emit CREATE TYPE
+ metadata.drop_all(engine) # will emit DROP TYPE
+
+ The type can also be created and dropped directly::
+
+ my_enum.create(engine)
+ my_enum.drop(engine)
+
+ .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
+ now behaves more strictly with regards to CREATE/DROP. A metadata-level
+ ENUM type will only be created and dropped at the metadata level,
+ not the table level, with the exception of
+ ``table.create(checkfirst=True)``.
+ The ``table.drop()`` call will now emit a DROP TYPE for a table-level
+ enumerated type.
+
+ """
+
+ native_enum = True
+
+ def __init__(self, *enums, **kw):
+ """Construct an :class:`_postgresql.ENUM`.
+
+ Arguments are the same as that of
+ :class:`_types.Enum`, but also including
+ the following parameters.
+
+ :param create_type: Defaults to True.
+ Indicates that ``CREATE TYPE`` should be
+ emitted, after optionally checking for the
+ presence of the type, when the parent
+ table is being created; and additionally
+ that ``DROP TYPE`` is called when the table
+ is dropped. When ``False``, no check
+ will be performed and no ``CREATE TYPE``
+ or ``DROP TYPE`` is emitted, unless
+ :meth:`~.postgresql.ENUM.create`
+ or :meth:`~.postgresql.ENUM.drop`
+ are called directly.
+ Setting to ``False`` is helpful
+ when invoking a creation scheme to a SQL file
+ without access to the actual database -
+ the :meth:`~.postgresql.ENUM.create` and
+ :meth:`~.postgresql.ENUM.drop` methods can
+ be used to emit SQL to a target bind.
+
+ """
+ native_enum = kw.pop("native_enum", None)
+ if native_enum is False:
+ util.warn(
+ "the native_enum flag does not apply to the "
+ "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
+ "always refers to ENUM. Use sqlalchemy.types.Enum for "
+ "non-native enum."
+ )
+ self.create_type = kw.pop("create_type", True)
+ super(ENUM, self).__init__(*enums, **kw)
+
+ @classmethod
+ def adapt_emulated_to_native(cls, impl, **kw):
+ """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
+ :class:`.Enum`.
+
+ """
+ kw.setdefault("validate_strings", impl.validate_strings)
+ kw.setdefault("name", impl.name)
+ kw.setdefault("schema", impl.schema)
+ kw.setdefault("inherit_schema", impl.inherit_schema)
+ kw.setdefault("metadata", impl.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("values_callable", impl.values_callable)
+ kw.setdefault("omit_aliases", impl._omit_aliases)
+ return cls(**kw)
+
+ def create(self, bind=None, checkfirst=True):
+ """Emit ``CREATE TYPE`` for this
+ :class:`_postgresql.ENUM`.
+
+ If the underlying dialect does not support
+ PostgreSQL CREATE TYPE, no action is taken.
+
+ :param bind: a connectable :class:`_engine.Engine`,
+ :class:`_engine.Connection`, or similar object to emit
+ SQL.
+ :param checkfirst: if ``True``, a query against
+ the PG catalog will be first performed to see
+ if the type does not exist already before
+ creating.
+
+ """
+ if not bind.dialect.supports_native_enum:
+ return
+
+ bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=True):
+ """Emit ``DROP TYPE`` for this
+ :class:`_postgresql.ENUM`.
+
+ If the underlying dialect does not support
+ PostgreSQL DROP TYPE, no action is taken.
+
+ :param bind: a connectable :class:`_engine.Engine`,
+ :class:`_engine.Connection`, or similar object to emit
+ SQL.
+ :param checkfirst: if ``True``, a query against
+ the PG catalog will be first performed to see
+ if the type actually exists before dropping.
+
+ """
+ if not bind.dialect.supports_native_enum:
+ return
+
+ bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
+
+ class EnumGenerator(InvokeDDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_create_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return not self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_create_enum(enum):
+ return
+
+ self.connection.execute(CreateEnumType(enum))
+
+ class EnumDropper(InvokeDDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_drop_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_drop_enum(enum):
+ return
+
+ self.connection.execute(DropEnumType(enum))
+
+ def get_dbapi_type(self, dbapi):
+ """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
+ a different type"""
+
+ return None
+
+ def _check_for_name_in_memos(self, checkfirst, kw):
+ """Look in the 'ddl runner' for 'memos', then
+ note our name in that collection.
+
+ This is to ensure a particular named enum is operated
+ upon only once within any kind of create/drop
+ sequence without relying upon "checkfirst".
+
+ """
+ if not self.create_type:
+ return True
+ if "_ddl_runner" in kw:
+ ddl_runner = kw["_ddl_runner"]
+ if "_pg_enums" in ddl_runner.memo:
+ pg_enums = ddl_runner.memo["_pg_enums"]
+ else:
+ pg_enums = ddl_runner.memo["_pg_enums"] = set()
+ present = (self.schema, self.name) in pg_enums
+ pg_enums.add((self.schema, self.name))
+ return present
+ else:
+ return False
+
+ def _on_table_create(self, target, bind, checkfirst=False, **kw):
+ if (
+ checkfirst
+ or (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ )
+ ) and not self._check_for_name_in_memos(checkfirst, kw):
+ self.create(bind=bind, checkfirst=checkfirst)
+
+ def _on_table_drop(self, target, bind, checkfirst=False, **kw):
+ if (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ and not self._check_for_name_in_memos(checkfirst, kw)
+ ):
+ self.drop(bind=bind, checkfirst=checkfirst)
+
+ def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
+ if not self._check_for_name_in_memos(checkfirst, kw):
+ self.create(bind=bind, checkfirst=checkfirst)
+
+ def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
+ if not self._check_for_name_in_memos(checkfirst, kw):
+ self.drop(bind=bind, checkfirst=checkfirst)
+
+
+class CreateEnumType(schema._CreateDropBase):
+ __visit_name__ = "create_enum_type"
+
+
+class DropEnumType(schema._CreateDropBase):
+ __visit_name__ = "drop_enum_type"
from ...engine import default
from ...engine import processors
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
from ...sql import coercions
from ...sql import ColumnElement
from ...sql import compiler
return [db[1] for db in dl if db[1] != "temp"]
- @reflection.cache
- def get_table_names(self, connection, schema=None, **kw):
+ def _format_schema(self, schema, table_name):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
+ name = f"{qschema}.{table_name}"
else:
- master = "sqlite_master"
- s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % (
- master,
- )
- rs = connection.exec_driver_sql(s)
- return [row[0] for row in rs]
+ name = table_name
+ return name
@reflection.cache
- def get_temp_table_names(self, connection, **kw):
- s = (
- "SELECT name FROM sqlite_temp_master "
- "WHERE type='table' ORDER BY name "
- )
- rs = connection.exec_driver_sql(s)
+ def get_table_names(self, connection, schema=None, **kw):
+ main = self._format_schema(schema, "sqlite_master")
+ s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
- return [row[0] for row in rs]
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ main = "sqlite_temp_master"
+ s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
@reflection.cache
def get_temp_view_names(self, connection, **kw):
"SELECT name FROM sqlite_temp_master "
"WHERE type='view' ORDER BY name "
)
- rs = connection.exec_driver_sql(s)
-
- return [row[0] for row in rs]
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
- def has_table(self, connection, table_name, schema=None):
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
self._ensure_has_table_connection(connection)
info = self._get_table_pragma(
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
- if schema is not None:
- qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
- else:
- master = "sqlite_master"
- s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % (
- master,
- )
- rs = connection.exec_driver_sql(s)
-
- return [row[0] for row in rs]
+ main = self._format_schema(schema, "sqlite_master")
+ s = f"SELECT name FROM {main} WHERE type='view' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
+ master = f"{qschema}.sqlite_master"
s = ("SELECT sql FROM %s WHERE name = ? AND type='view'") % (
master,
)
result = rs.fetchall()
if result:
return result[0].sql
+ else:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
tablesql,
)
)
- return columns
+ if columns:
+ return columns
+ elif not self.has_table(connection, table_name, schema):
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
+ else:
+ return ReflectionDefaults.columns()
def _get_column_info(
self,
"type": coltype,
"nullable": nullable,
"default": default,
- "autoincrement": "auto",
"primary_key": primary_key,
}
if generated:
constraint_name = result.group(1) if result else None
cols = self.get_columns(connection, table_name, schema, **kw)
+ # consider only pk columns; filtering into a new list also avoids
+ # sorting the cached value returned by get_columns in place
+ cols = [col for col in cols if col.get("primary_key", 0) > 0]
cols.sort(key=lambda col: col.get("primary_key"))
- pkeys = []
- for col in cols:
- if col["primary_key"]:
- pkeys.append(col["name"])
+ pkeys = [col["name"] for col in cols]
- return {"constrained_columns": pkeys, "name": constraint_name}
+ if pkeys:
+ return {"constrained_columns": pkeys, "name": constraint_name}
+ else:
+ return ReflectionDefaults.pk_constraint()
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# original DDL. The referred columns of the foreign key
# constraint are therefore the primary key of the referred
# table.
- referred_pk = self.get_pk_constraint(
- connection, rtbl, schema=schema, **kw
- )
- # note that if table doesn't exist, we still get back a record,
- # just it has no columns in it
- referred_columns = referred_pk["constrained_columns"]
+ try:
+ referred_pk = self.get_pk_constraint(
+ connection, rtbl, schema=schema, **kw
+ )
+ referred_columns = referred_pk["constrained_columns"]
+ except exc.NoSuchTableError:
+ # ignore parent tables that do not exist
+ referred_columns = []
else:
# note we use this list only if this is the first column
# in the constraint. for subsequent columns we ignore the
)
table_data = self._get_table_sql(connection, table_name, schema=schema)
- if table_data is None:
- # system tables, etc.
- return []
def parse_fks():
+ if table_data is None:
+ # system tables, etc.
+ return
FK_PATTERN = (
r"(?:CONSTRAINT (\w+) +)?"
r"FOREIGN KEY *\( *(.+?) *\) +"
# use them as is as it's extremely difficult to parse inline
# constraints
fkeys.extend(keys_by_signature.values())
- return fkeys
+ if fkeys:
+ return fkeys
+ else:
+ return ReflectionDefaults.foreign_keys()
def _find_cols_in_sig(self, sig):
for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I):
table_data = self._get_table_sql(
connection, table_name, schema=schema, **kw
)
- if not table_data:
- return []
-
unique_constraints = []
def parse_uqs():
+ if table_data is None:
+ return
UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)'
INLINE_UNIQUE_PATTERN = (
r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) '
unique_constraints.append(parsed_constraint)
# NOTE: auto_index_by_sig might not be empty here,
# the PRIMARY KEY may have an entry.
- return unique_constraints
+ if unique_constraints:
+ return unique_constraints
+ else:
+ return ReflectionDefaults.unique_constraints()
@reflection.cache
def get_check_constraints(self, connection, table_name, schema=None, **kw):
table_data = self._get_table_sql(
connection, table_name, schema=schema, **kw
)
- if not table_data:
- return []
CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *"
check_constraints = []
# necessarily makes assumptions as to how the CREATE TABLE
# was emitted.
- for match in re.finditer(CHECK_PATTERN, table_data, re.I):
+ for match in re.finditer(CHECK_PATTERN, table_data or "", re.I):
name = match.group(1)
if name:
check_constraints.append({"sqltext": match.group(2), "name": name})
- return check_constraints
+ if check_constraints:
+ return check_constraints
+ else:
+ return ReflectionDefaults.check_constraints()
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
# loop thru unique indexes to get the column names.
for idx in list(indexes):
pragma_index = self._get_table_pragma(
- connection, "index_info", idx["name"]
+ connection, "index_info", idx["name"], schema=schema
)
for row in pragma_index:
break
else:
idx["column_names"].append(row[2])
- return indexes
+ indexes.sort(key=lambda d: d["name"] or "~") # sort None names last
+ if indexes:
+ return indexes
+ elif not self.has_table(connection, table_name, schema):
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
+ else:
+ return ReflectionDefaults.indexes()
+
+ def _is_sys_table(self, table_name):
+ return table_name in {
+ "sqlite_schema",
+ "sqlite_master",
+ "sqlite_temp_schema",
+ "sqlite_temp_master",
+ }
@reflection.cache
def _get_table_sql(self, connection, table_name, schema=None, **kw):
" (SELECT * FROM %(schema)ssqlite_master UNION ALL "
" SELECT * FROM %(schema)ssqlite_temp_master) "
"WHERE name = ? "
- "AND type = 'table'" % {"schema": schema_expr}
+ "AND type in ('table', 'view')" % {"schema": schema_expr}
)
rs = connection.exec_driver_sql(s, (table_name,))
except exc.DBAPIError:
s = (
"SELECT sql FROM %(schema)ssqlite_master "
"WHERE name = ? "
- "AND type = 'table'" % {"schema": schema_expr}
+ "AND type in ('table', 'view')" % {"schema": schema_expr}
)
rs = connection.exec_driver_sql(s, (table_name,))
- return rs.scalar()
+ value = rs.scalar()
+ if value is None and not self._is_sys_table(table_name):
+ raise exc.NoSuchTableError(f"{schema_expr}{table_name}")
+ return value
def _get_table_pragma(self, connection, pragma, table_name, schema=None):
quote = self.identifier_preparer.quote_identifier
if schema is not None:
- statements = ["PRAGMA %s." % quote(schema)]
+ statements = [f"PRAGMA {quote(schema)}."]
else:
# because PRAGMA looks in all attached databases if no schema
# given, need to specify "main" schema, however since we want
qtable = quote(table_name)
for statement in statements:
- statement = "%s%s(%s)" % (statement, pragma, qtable)
+ statement = f"{statement}{pragma}({qtable})"
cursor = connection.exec_driver_sql(statement)
if not cursor._soft_closed:
# work around SQLite issue whereby cursor.description
from .interfaces import TypeCompiler as TypeCompiler
from .mock import create_mock_engine as create_mock_engine
from .reflection import Inspector as Inspector
+from .reflection import ObjectKind as ObjectKind
+from .reflection import ObjectScope as ObjectScope
from .result import ChunkedIteratorResult as ChunkedIteratorResult
from .result import FrozenResult as FrozenResult
from .result import IteratorResult as IteratorResult
from .interfaces import DBAPICursor
from .interfaces import Dialect
from .interfaces import ExecutionContext
+from .reflection import ObjectKind
+from .reflection import ObjectScope
from .. import event
from .. import exc
from .. import pool
"""
return type_api.adapt_type(typeobj, self.colspecs)
- def has_index(self, connection, table_name, index_name, schema=None):
- if not self.has_table(connection, table_name, schema=schema):
+ def has_index(self, connection, table_name, index_name, schema=None, **kw):
+ if not self.has_table(connection, table_name, schema=schema, **kw):
return False
- for idx in self.get_indexes(connection, table_name, schema=schema):
+ for idx in self.get_indexes(
+ connection, table_name, schema=schema, **kw
+ ):
if idx["name"] == index_name:
return True
else:
return False
+ def has_schema(
+ self, connection: Connection, schema_name: str, **kw: Any
+ ) -> bool:
+ return schema_name in self.get_schema_names(connection, **kw)
+
def validate_identifier(self, ident):
if len(ident) > self.max_identifier_length:
raise exc.IdentifierError(
def get_driver_connection(self, connection):
return connection
+ def _overrides_default(self, method):
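+ # return True if this dialect class overrides the named method
+ # relative to DefaultDialect, compared by code object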
+ return (
+ getattr(type(self), method).__code__
+ is not getattr(DefaultDialect, method).__code__
+ )
+
+ def _default_multi_reflect(
+ self,
+ single_tbl_method,
+ connection,
+ kind,
+ schema,
+ filter_names,
+ scope,
+ **kw,
+ ):
+
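+ # generic implementation backing the get_multi_* methods below:
+ # collect the candidate object names for the requested kind and
+ # scope, then invoke the given single-table reflection method for
+ # each name, yielding ((schema, name), result) pairs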
+ names_fns = []
+ temp_names_fns = []
+ if ObjectKind.TABLE in kind:
+ names_fns.append(self.get_table_names)
+ temp_names_fns.append(self.get_temp_table_names)
+ if ObjectKind.VIEW in kind:
+ names_fns.append(self.get_view_names)
+ temp_names_fns.append(self.get_temp_view_names)
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ names_fns.append(self.get_materialized_view_names)
+ # no temp materialized view at the moment
+ # temp_names_fns.append(self.get_temp_materialized_view_names)
+
+ unreflectable = kw.pop("unreflectable", {})
+
+ if (
+ filter_names
+ and scope is ObjectScope.ANY
+ and kind is ObjectKind.ANY
+ ):
+ # if names are given and no qualification on type of table
+ # (i.e. the Table(..., autoload) case), take the names as given,
+ # don't run names queries. If a table does not exist,
+ # NoSuchTableError is raised and it's skipped
+
+ # this also suits the case for mssql where we can reflect
+ # individual temp tables but there's no temp_names_fn
+ names = filter_names
+ else:
+ names = []
+ name_kw = {"schema": schema, **kw}
+ fns = []
+ if ObjectScope.DEFAULT in scope:
+ fns.extend(names_fns)
+ if ObjectScope.TEMPORARY in scope:
+ fns.extend(temp_names_fns)
+
+ for fn in fns:
+ try:
+ names.extend(fn(connection, **name_kw))
+ except NotImplementedError:
+ pass
+
+ if filter_names:
+ filter_names = set(filter_names)
+
+ # iterate over all the tables/views and call the single table method
+ for table in names:
+ if not filter_names or table in filter_names:
+ key = (schema, table)
+ try:
+ yield (
+ key,
+ single_tbl_method(
+ connection, table, schema=schema, **kw
+ ),
+ )
+ except exc.UnreflectableTableError as err:
+ if key not in unreflectable:
+ unreflectable[key] = err
+ except exc.NoSuchTableError:
+ pass
+
+ def get_multi_table_options(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_table_options, connection, **kw
+ )
+
+ def get_multi_columns(self, connection, **kw):
+ return self._default_multi_reflect(self.get_columns, connection, **kw)
+
+ def get_multi_pk_constraint(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_pk_constraint, connection, **kw
+ )
+
+ def get_multi_foreign_keys(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_foreign_keys, connection, **kw
+ )
+
+ def get_multi_indexes(self, connection, **kw):
+ return self._default_multi_reflect(self.get_indexes, connection, **kw)
+
+ def get_multi_unique_constraints(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_unique_constraints, connection, **kw
+ )
+
+ def get_multi_check_constraints(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_check_constraints, connection, **kw
+ )
+
+ def get_multi_table_comment(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_table_comment, connection, **kw
+ )
+
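
The default ``get_multi_*`` implementations above simply fan out to the existing
single-table methods. A dialect that can query its system catalogs in bulk can
override any of these hooks directly; the following is a minimal sketch, not part
of this patch, of what such an override could look like. The ``AcmeDialect`` name,
the ``acme_catalog.columns`` table and the ``_resolve_type()`` helper are all
hypothetical, and a complete implementation would also honour the ``kind`` and
``scope`` keyword arguments described in the interface.

from collections import defaultdict

from sqlalchemy import text
from sqlalchemy.engine.default import DefaultDialect


class AcmeDialect(DefaultDialect):
    def get_multi_columns(
        self, connection, schema=None, filter_names=None, **kw
    ):
        # one catalog query for the whole schema instead of one per table
        rows = connection.execute(
            text(
                "SELECT table_name, column_name, type_name, is_nullable "
                "FROM acme_catalog.columns WHERE schema_name = :schema"
            ),
            {"schema": schema or self.default_schema_name},
        )
        by_table = defaultdict(list)
        for table_name, column_name, type_name, is_nullable in rows:
            if filter_names is not None and table_name not in filter_names:
                continue
            by_table[(schema, table_name)].append(
                {
                    "name": column_name,
                    "type": self._resolve_type(type_name),  # hypothetical helper
                    "nullable": is_nullable,
                    "default": None,
                }
            )
        # same contract as the default implementation: an iterable of
        # ((schema, table_name), list-of-column-dicts) pairs
        return by_table.items()
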
class StrCompileDialect(DefaultDialect):
from typing import Awaitable
from typing import Callable
from typing import ClassVar
+from typing import Collection
from typing import Dict
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import MutableMapping
nullable: bool
"""column nullability"""
- default: str
+ default: Optional[str]
"""column default expression as a SQL string"""
autoincrement: NotRequired[bool]
comment: NotRequired[Optional[str]]
"""comment for the column, if present"""
- computed: NotRequired[Optional[ReflectedComputed]]
+ computed: NotRequired[ReflectedComputed]
"""indicates this column is computed at insert (possibly update) time by
the database."""
- identity: NotRequired[Optional[ReflectedIdentity]]
+ identity: NotRequired[ReflectedIdentity]
"""indicates this column is an IDENTITY column"""
dialect_options: NotRequired[Dict[str, Any]]
column_names: List[str]
"""column names which comprise the constraint"""
+ duplicates_index: NotRequired[Optional[str]]
+ "Indicates if this unique constraint duplicates an index with this name"
+
dialect_options: NotRequired[Dict[str, Any]]
"""Additional dialect-specific options detected for this reflected
object"""
referred_columns: List[str]
"""referenced column names"""
- dialect_options: NotRequired[Dict[str, Any]]
+ options: NotRequired[Dict[str, Any]]
"""Additional dialect-specific options detected for this reflected
object"""
unique: bool
"""whether or not the index has a unique flag"""
- duplicates_constraint: NotRequired[bool]
- """boolean indicating this index mirrors a unique constraint of the same
- name"""
+ duplicates_constraint: NotRequired[Optional[str]]
+ "Indicates if this index mirrors a unique constraint with this name"
include_columns: NotRequired[List[str]]
"""columns to include in the INCLUDE clause for supporting databases.
.. deprecated:: 2.0
Legacy value, will be replaced with
- ``d["dialect_options"][<dialect name>]["include"]``
+ ``d["dialect_options"]["<dialect name>_include"]``
"""
"""
- text: str
+ text: Optional[str]
"""text of the comment"""
VersionInfoType = Tuple[Union[int, str], ...]
+TableKey = Tuple[Optional[str], str]
class Dialect(EventTarget):
raise NotImplementedError()
- def initialize(self, connection: "Connection") -> None:
+ def initialize(self, connection: Connection) -> None:
"""Called during strategized creation of the dialect with a
connection.
pass
+ if TYPE_CHECKING:
+
+ def _overrides_default(self, method_name: str) -> bool:
+ ...
+
def get_columns(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
information as a list of dictionaries
corresponding to the :class:`.ReflectedColumn` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_columns`.
+ """
+
+ def get_multi_columns(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, List[ReflectedColumn]]]:
+ """Return information about columns in all tables in the
+ given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_multi_columns`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_pk_constraint(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
key information as a dictionary corresponding to the
:class:`.ReflectedPrimaryKeyConstraint` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_pk_constraint`.
+
+ """
+ raise NotImplementedError()
+
+ def get_multi_pk_constraint(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, ReflectedPrimaryKeyConstraint]]:
+ """Return information about primary key constraints in
+ all tables in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_multi_pk_constraint`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
"""
raise NotImplementedError()
def get_foreign_keys(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
key information as a list of dicts corresponding to the
:class:`.ReflectedForeignKeyConstraint` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_foreign_keys`.
+ """
+
+ raise NotImplementedError()
+
+ def get_multi_foreign_keys(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, List[ReflectedForeignKeyConstraint]]]:
+ """Return information about foreign_keys in all tables
+ in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_multi_foreign_keys`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_table_names(
- self, connection: "Connection", schema: Optional[str] = None, **kw: Any
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
- """Return a list of table names for ``schema``."""
+ """Return a list of table names for ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_table_names`.
+
+ """
raise NotImplementedError()
def get_temp_table_names(
- self, connection: "Connection", schema: Optional[str] = None, **kw: Any
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
"""Return a list of temporary table names on the given connection,
if supported by the underlying backend.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_temp_table_names`.
+
"""
raise NotImplementedError()
def get_view_names(
- self, connection: "Connection", schema: Optional[str] = None, **kw: Any
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
+ ) -> List[str]:
+ """Return a list of all non-materialized view names available in the
+ database.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_view_names`.
+
+ :param schema: schema name to query, if not the default schema.
+
+ """
+
+ raise NotImplementedError()
+
+ def get_materialized_view_names(
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
- """Return a list of all view names available in the database.
+ """Return a list of all materialized view names available in the
+ database.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_materialized_view_names`.
:param schema: schema name to query, if not the default schema.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_sequence_names(
- self, connection: "Connection", schema: Optional[str] = None, **kw: Any
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
"""Return a list of all sequence names available in the database.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_sequence_names`.
+
:param schema: schema name to query, if not the default schema.
.. versionadded:: 1.4
raise NotImplementedError()
def get_temp_view_names(
- self, connection: "Connection", schema: Optional[str] = None, **kw: Any
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
"""Return a list of temporary view names on the given connection,
if supported by the underlying backend.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_temp_view_names`.
+
"""
raise NotImplementedError()
+ def get_schema_names(self, connection: Connection, **kw: Any) -> List[str]:
+ """Return a list of all schema names available in the database.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_schema_names`.
+ """
+ raise NotImplementedError()
+
def get_view_definition(
self,
- connection: "Connection",
+ connection: Connection,
view_name: str,
schema: Optional[str] = None,
**kw: Any,
) -> str:
- """Return view definition.
+ """Return plain or materialized view definition.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_view_definition`.
Given a :class:`_engine.Connection`, a string
- `view_name`, and an optional string ``schema``, return the view
+ ``view_name``, and an optional string ``schema``, return the view
definition.
"""
def get_indexes(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
information as a list of dictionaries corresponding to the
:class:`.ReflectedIndex` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_indexes`.
+ """
+
+ raise NotImplementedError()
+
+ def get_multi_indexes(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, List[ReflectedIndex]]]:
+ """Return information about indexes in in all tables
+ in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_multi_indexes`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_unique_constraints(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
unique constraint information as a list of dicts corresponding
to the :class:`.ReflectedUniqueConstraint` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_unique_constraints`.
+ """
+
+ raise NotImplementedError()
+
+ def get_multi_unique_constraints(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, List[ReflectedUniqueConstraint]]]:
+ """Return information about unique constraints in all tables
+ in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_multi_unique_constraints`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_check_constraints(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
check constraint information as a list of dicts corresponding
to the :class:`.ReflectedCheckConstraint` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_check_constraints`.
+
+ .. versionadded:: 1.1.0
+
+ """
+
+ raise NotImplementedError()
+
+ def get_multi_check_constraints(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, List[ReflectedCheckConstraint]]]:
+ """Return information about check constraints in all tables
+ in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_multi_check_constraints`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_table_options(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
- ) -> Optional[Dict[str, Any]]:
- r"""Return the "options" for the table identified by ``table_name``
- as a dictionary.
+ ) -> Dict[str, Any]:
+ """Return a dictionary of options specified when ``table_name``
+ was created.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_table_options`.
"""
- return None
+ raise NotImplementedError()
+
+ def get_multi_table_options(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, Dict[str, Any]]]:
+ """Return a dictionary of options specified when the tables in the
+ given schema were created.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_multi_table_options`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
+ """
+ raise NotImplementedError()
def get_table_comment(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
table comment information as a dictionary corresponding to the
:class:`.ReflectedTableComment` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_table_comment`.
:raise: ``NotImplementedError`` for dialects that don't support
comments.
raise NotImplementedError()
+ def get_multi_table_comment(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, ReflectedTableComment]]:
+ """Return information about the table comment in all tables
+ in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_multi_table_comment`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
+ """
+
+ raise NotImplementedError()
+
def normalize_name(self, name: str) -> str:
"""convert the given name to lowercase if it is detected as
case insensitive.
def has_table(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
def has_index(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
index_name: str,
schema: Optional[str] = None,
+ **kw: Any,
) -> bool:
"""Check the existence of a particular index name in the database.
Given a :class:`_engine.Connection` object, a string
- ``table_name`` and string index name, return True if an index of the
- given name on the given table exists, false otherwise.
+ ``table_name`` and string index name, return ``True`` if an index of
+ the given name on the given table exists, ``False`` otherwise.
The :class:`.DefaultDialect` implements this in terms of the
:meth:`.Dialect.has_table` and :meth:`.Dialect.get_indexes` methods,
however dialects can implement a more performant version.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.has_index`.
.. versionadded:: 1.4
def has_sequence(
self,
- connection: "Connection",
+ connection: Connection,
sequence_name: str,
schema: Optional[str] = None,
**kw: Any,
"""Check the existence of a particular sequence in the database.
Given a :class:`_engine.Connection` object and a string
- `sequence_name`, return True if the given sequence exists in
- the database, False otherwise.
+ `sequence_name`, return ``True`` if the given sequence exists in
+ the database, ``False`` otherwise.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.has_sequence`.
+ """
+
+ raise NotImplementedError()
+
+ def has_schema(
+ self, connection: Connection, schema_name: str, **kw: Any
+ ) -> bool:
+ """Check the existence of a particular schema name in the database.
+
+ Given a :class:`_engine.Connection` object, a string
+ ``schema_name``, return ``True`` if a schema of the
+ given name exists, ``False`` otherwise.
+
+ The :class:`.DefaultDialect` implements this by checking
+ the presence of ``schema_name`` among the schemas returned by
+ :meth:`.Dialect.get_schema_names`,
+ however dialects can implement a more performant version.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.has_schema`.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
- def _get_server_version_info(self, connection: "Connection") -> Any:
+ def _get_server_version_info(self, connection: Connection) -> Any:
"""Retrieve the server version info from the given connection.
This is used by the default implementation to populate the
raise NotImplementedError()
- def _get_default_schema_name(self, connection: "Connection") -> str:
+ def _get_default_schema_name(self, connection: Connection) -> str:
"""Return the string name of the currently selected schema from
the given connection.
raise NotImplementedError()
- def do_savepoint(self, connection: "Connection", name: str) -> None:
+ def do_savepoint(self, connection: Connection, name: str) -> None:
"""Create a savepoint with the given name.
:param connection: a :class:`_engine.Connection`.
raise NotImplementedError()
def do_rollback_to_savepoint(
- self, connection: "Connection", name: str
+ self, connection: Connection, name: str
) -> None:
"""Rollback a connection to the named savepoint.
raise NotImplementedError()
- def do_release_savepoint(
- self, connection: "Connection", name: str
- ) -> None:
+ def do_release_savepoint(self, connection: Connection, name: str) -> None:
"""Release the named savepoint on a connection.
:param connection: a :class:`_engine.Connection`.
raise NotImplementedError()
- def do_begin_twophase(self, connection: "Connection", xid: Any) -> None:
+ def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
"""Begin a two phase transaction on the given connection.
:param connection: a :class:`_engine.Connection`.
raise NotImplementedError()
- def do_prepare_twophase(self, connection: "Connection", xid: Any) -> None:
+ def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
"""Prepare a two phase transaction on the given connection.
:param connection: a :class:`_engine.Connection`.
def do_rollback_twophase(
self,
- connection: "Connection",
+ connection: Connection,
xid: Any,
is_prepared: bool = True,
recover: bool = False,
def do_commit_twophase(
self,
- connection: "Connection",
+ connection: Connection,
xid: Any,
is_prepared: bool = True,
recover: bool = False,
raise NotImplementedError()
- def do_recover_twophase(self, connection: "Connection") -> List[Any]:
+ def do_recover_twophase(self, connection: Connection) -> List[Any]:
"""Recover list of uncommitted prepared two phase transaction
identifiers on the given connection.
from __future__ import annotations
import contextlib
+from dataclasses import dataclass
+from enum import auto
+from enum import Flag
+from enum import unique
+from typing import Any
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import Generator
+from typing import Iterable
from typing import List
from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from .base import Connection
from .base import Engine
-from .interfaces import ReflectedColumn
from .. import exc
from .. import inspection
from .. import sql
from .. import util
from ..sql import operators
from ..sql import schema as sa_schema
+from ..sql.cache_key import _ad_hoc_cache_key_from_args
+from ..sql.elements import TextClause
from ..sql.type_api import TypeEngine
+from ..sql.visitors import InternalTraversal
from ..util import topological
+from ..util.typing import final
+
+if TYPE_CHECKING:
+ from .interfaces import Dialect
+ from .interfaces import ReflectedCheckConstraint
+ from .interfaces import ReflectedColumn
+ from .interfaces import ReflectedForeignKeyConstraint
+ from .interfaces import ReflectedIndex
+ from .interfaces import ReflectedPrimaryKeyConstraint
+ from .interfaces import ReflectedTableComment
+ from .interfaces import ReflectedUniqueConstraint
+ from .interfaces import TableKey
+
+_R = TypeVar("_R")
@util.decorator
-def cache(fn, self, con, *args, **kw):
+def cache(
+ fn: Callable[..., _R],
+ self: Dialect,
+ con: Connection,
+ *args: Any,
+ **kw: Any,
+) -> _R:
info_cache = kw.get("info_cache", None)
if info_cache is None:
return fn(self, con, *args, **kw)
+ exclude = {"info_cache", "unreflectable"}
key = (
fn.__name__,
tuple(a for a in args if isinstance(a, str)),
- tuple((k, v) for k, v in kw.items() if k != "info_cache"),
+ tuple((k, v) for k, v in kw.items() if k not in exclude),
)
- ret = info_cache.get(key)
+ ret: _R = info_cache.get(key)
if ret is None:
ret = fn(self, con, *args, **kw)
info_cache[key] = ret
return ret
+def flexi_cache(
+ *traverse_args: Tuple[str, InternalTraversal]
+) -> Callable[[Callable[..., _R]], Callable[..., _R]]:
+ @util.decorator
+ def go(
+ fn: Callable[..., _R],
+ self: Dialect,
+ con: Connection,
+ *args: Any,
+ **kw: Any,
+ ) -> _R:
+ info_cache = kw.get("info_cache", None)
+ if info_cache is None:
+ return fn(self, con, *args, **kw)
+ key = _ad_hoc_cache_key_from_args((fn.__name__,), traverse_args, args)
+ ret: _R = info_cache.get(key)
+ if ret is None:
+ ret = fn(self, con, *args, **kw)
+ info_cache[key] = ret
+ return ret
+
+ return go
+
+
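The ``cache`` decorator above only caches when the caller passes an ``info_cache``
dictionary (as the :class:`.Inspector` does), keying on the method name, string
positional arguments, and keyword arguments other than ``info_cache`` and
``unreflectable``. A toy illustration of that behaviour, not taken from the patch
(``CountingDialect`` is a stand-in, not a real dialect):

from sqlalchemy.engine import reflection


class CountingDialect:
    def __init__(self):
        self.calls = 0

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        self.calls += 1
        return ["a", "b"]


d = CountingDialect()
info_cache = {}
d.get_table_names(None, schema=None, info_cache=info_cache)
d.get_table_names(None, schema=None, info_cache=info_cache)
assert d.calls == 1              # second call was served from info_cache
d.get_table_names(None, schema=None)
assert d.calls == 2              # without info_cache nothing is cached
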
+@unique
+class ObjectKind(Flag):
+ """Enumerator that indicates which kind of object to return when calling
+ the ``get_multi`` methods.
+
+ This is a Flag enum, so custom combinations can be passed. For example,
+ to reflect tables and plain views ``ObjectKind.TABLE | ObjectKind.VIEW``
+ may be used.
+
+ .. note::
+ Not all dialects may support every kind of object. If a dialect does
+ not support a particular object, an empty dict is returned.
+ If a dialect supports an object, but the requested method
+ is not applicable for the specified kind, the default value
+ will be returned for each reflected object. For example, reflecting
+ check constraints of views returns a dict with all the views, with
+ empty lists as values.
+ """
+
+ TABLE = auto()
+ "Reflect table objects"
+ VIEW = auto()
+ "Reflect plain view objects"
+ MATERIALIZED_VIEW = auto()
+ "Reflect materialized view object"
+
+ ANY_VIEW = VIEW | MATERIALIZED_VIEW
+ "Reflect any kind of view objects"
+ ANY = TABLE | VIEW | MATERIALIZED_VIEW
+ "Reflect all type of objects"
+
+
+@unique
+class ObjectScope(Flag):
+ """Enumerator that indicates which scope to use when calling
+ the ``get_multi`` methods.
+ """
+
+ DEFAULT = auto()
+ "Include default scope"
+ TEMPORARY = auto()
+ "Include only temp scope"
+ ANY = DEFAULT | TEMPORARY
+ "Include both default and temp scope"
+
+
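As a usage sketch (the connection URL, schema and object names below are
placeholders), these flags combine with the ``Inspector.get_multi_*`` methods
defined further down; on dialects that implement the batched queries, a single
call reflects many objects at once:

from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.reflection import ObjectKind, ObjectScope

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
insp = inspect(engine)

# reflect columns of plain tables *and* plain views in one pass
multi_cols = insp.get_multi_columns(
    schema="public",
    kind=ObjectKind.TABLE | ObjectKind.VIEW,
    scope=ObjectScope.DEFAULT,
)
# keys are (schema, name) two-tuples
for (schema, name), cols in multi_cols.items():
    print(schema, name, [col["name"] for col in cols])
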
@inspection._self_inspects
class Inspector(inspection.Inspectable["Inspector"]):
"""Performs database schema inspection.
"""
+ bind: Union[Engine, Connection]
+ engine: Engine
+ _op_context_requires_connect: bool
+ dialect: Dialect
+ info_cache: Dict[Any, Any]
+
@util.deprecated(
"1.4",
"The __init__() method on :class:`_reflection.Inspector` "
"in order to "
"acquire an :class:`_reflection.Inspector`.",
)
- def __init__(self, bind):
+ def __init__(self, bind: Union[Engine, Connection]):
"""Initialize a new :class:`_reflection.Inspector`.
:param bind: a :class:`~sqlalchemy.engine.Connection`,
:meth:`_reflection.Inspector.from_engine`
"""
- return self._init_legacy(bind)
+ self._init_legacy(bind)
@classmethod
- def _construct(cls, init, bind):
+ def _construct(
+ cls, init: Callable[..., Any], bind: Union[Engine, Connection]
+ ) -> Inspector:
if hasattr(bind.dialect, "inspector"):
- cls = bind.dialect.inspector
+ cls = bind.dialect.inspector # type: ignore[attr-defined]
self = cls.__new__(cls)
init(self, bind)
return self
- def _init_legacy(self, bind):
+ def _init_legacy(self, bind: Union[Engine, Connection]) -> None:
if hasattr(bind, "exec_driver_sql"):
- self._init_connection(bind)
+ self._init_connection(bind) # type: ignore[arg-type]
else:
- self._init_engine(bind)
+ self._init_engine(bind) # type: ignore[arg-type]
- def _init_engine(self, engine):
+ def _init_engine(self, engine: Engine) -> None:
self.bind = self.engine = engine
engine.connect().close()
self._op_context_requires_connect = True
self.dialect = self.engine.dialect
self.info_cache = {}
- def _init_connection(self, connection):
+ def _init_connection(self, connection: Connection) -> None:
self.bind = connection
self.engine = connection.engine
self._op_context_requires_connect = False
self.dialect = self.engine.dialect
self.info_cache = {}
+ def clear_cache(self) -> None:
+ """reset the cache for this :class:`.Inspector`.
+
+ Inspection methods that have data cached will emit SQL queries
+ when next called to get new data.
+
+ .. versionadded:: 2.0
+
+ """
+ self.info_cache.clear()
+
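For example, under the caching behaviour described above (reusing the ``engine``
from the earlier sketch; the table name is a placeholder):

insp = inspect(engine)
insp.has_table("user_account")   # emits a query; the result is cached
insp.has_table("user_account")   # answered from insp.info_cache

# after running DDL, reset the cache (or build a new Inspector)
insp.clear_cache()
insp.has_table("user_account")   # queries the database again
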
@classmethod
@util.deprecated(
"1.4",
"in order to "
"acquire an :class:`_reflection.Inspector`.",
)
- def from_engine(cls, bind):
+ def from_engine(cls, bind: Engine) -> Inspector:
"""Construct a new dialect-specific Inspector object from the given
engine or connection.
return cls._construct(cls._init_legacy, bind)
@inspection._inspects(Engine)
- def _engine_insp(bind):
+ def _engine_insp(bind: Engine) -> Inspector: # type: ignore[misc]
return Inspector._construct(Inspector._init_engine, bind)
@inspection._inspects(Connection)
- def _connection_insp(bind):
+ def _connection_insp(bind: Connection) -> Inspector: # type: ignore[misc]
return Inspector._construct(Inspector._init_connection, bind)
@contextlib.contextmanager
- def _operation_context(self):
+ def _operation_context(self) -> Generator[Connection, None, None]:
"""Return a context that optimizes for multiple operations on a single
transaction.
:class:`_engine.Connection`.
"""
+ conn: Connection
if self._op_context_requires_connect:
- conn = self.bind.connect()
+ conn = self.bind.connect() # type: ignore[union-attr]
else:
- conn = self.bind
+ conn = self.bind # type: ignore[assignment]
try:
yield conn
finally:
conn.close()
@contextlib.contextmanager
- def _inspection_context(self):
+ def _inspection_context(self) -> Generator[Inspector, None, None]:
"""Return an :class:`_reflection.Inspector`
from this one that will run all
operations on a single connection.
yield sub_insp
@property
- def default_schema_name(self):
+ def default_schema_name(self) -> Optional[str]:
"""Return the default schema name presented by the dialect
for the current engine's database user.
"""
return self.dialect.default_schema_name
- def get_schema_names(self):
- """Return all schema names."""
+ def get_schema_names(self, **kw: Any) -> List[str]:
+ r"""Return all schema names.
- if hasattr(self.dialect, "get_schema_names"):
- with self._operation_context() as conn:
- return self.dialect.get_schema_names(
- conn, info_cache=self.info_cache
- )
- return []
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+ """
- def get_table_names(self, schema=None):
- """Return all table names in referred to within a particular schema.
+ with self._operation_context() as conn:
+ return self.dialect.get_schema_names(
+ conn, info_cache=self.info_cache, **kw
+ )
+
+ def get_table_names(
+ self, schema: Optional[str] = None, **kw: Any
+ ) -> List[str]:
+ r"""Return all table names within a particular schema.
The names are expected to be real tables only, not views.
Views are instead returned using the
- :meth:`_reflection.Inspector.get_view_names`
- method.
-
+ :meth:`_reflection.Inspector.get_view_names` and/or
+ :meth:`_reflection.Inspector.get_materialized_view_names`
+ methods.
:param schema: Schema name. If ``schema`` is left at ``None``, the
database's default schema is
used, else the named schema is searched. If the database does not
support named schemas, behavior is undefined if ``schema`` is not
passed as ``None``. For special quoting, use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
.. seealso::
with self._operation_context() as conn:
return self.dialect.get_table_names(
- conn, schema, info_cache=self.info_cache
+ conn, schema, info_cache=self.info_cache, **kw
)
- def has_table(self, table_name, schema=None):
- """Return True if the backend has a table or view of the given name.
+ def has_table(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> bool:
+ r"""Return True if the backend has a table or view of the given name.
:param table_name: name of the table to check
:param schema: schema name to query, if not the default schema.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
.. versionadded:: 1.4 - the :meth:`.Inspector.has_table` method
replaces the :meth:`_engine.Engine.has_table` method.
- .. versionchanged:: 2.0:: The method checks also for views.
+ .. versionchanged:: 2.0 The method also checks for any type of
+ views (plain or materialized).
In previous versions this behaviour was dialect specific. New
dialect suite tests were added to ensure all dialects conform with
this behaviour.
"""
- # TODO: info_cache?
with self._operation_context() as conn:
- return self.dialect.has_table(conn, table_name, schema)
+ return self.dialect.has_table(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
- def has_sequence(self, sequence_name, schema=None):
- """Return True if the backend has a table of the given name.
+ def has_sequence(
+ self, sequence_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> bool:
+ r"""Return True if the backend has a sequence with the given name.
- :param sequence_name: name of the table to check
+ :param sequence_name: name of the sequence to check
:param schema: schema name to query, if not the default schema.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
.. versionadded:: 1.4
"""
- # TODO: info_cache?
with self._operation_context() as conn:
- return self.dialect.has_sequence(conn, sequence_name, schema)
+ return self.dialect.has_sequence(
+ conn, sequence_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def has_index(
+ self,
+ table_name: str,
+ index_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> bool:
+ r"""Check the existence of a particular index name in the database.
+
+ :param table_name: the name of the table the index belongs to
+ :param index_name: the name of the index to check
+ :param schema: schema name to query, if not the default schema.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ .. versionadded:: 2.0
+
+ """
+ with self._operation_context() as conn:
+ return self.dialect.has_index(
+ conn,
+ table_name,
+ index_name,
+ schema,
+ info_cache=self.info_cache,
+ **kw,
+ )
+
+ def has_schema(self, schema_name: str, **kw: Any) -> bool:
+ r"""Return True if the backend has a schema with the given name.
+
+ :param schema_name: name of the schema to check
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ .. versionadded:: 2.0
+
+ """
+ with self._operation_context() as conn:
+ return self.dialect.has_schema(
+ conn, schema_name, info_cache=self.info_cache, **kw
+ )
- def get_sorted_table_and_fkc_names(self, schema=None):
- """Return dependency-sorted table and foreign key constraint names in
+ def get_sorted_table_and_fkc_names(
+ self,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> List[Tuple[Optional[str], List[Tuple[str, Optional[str]]]]]:
+ r"""Return dependency-sorted table and foreign key constraint names in
referred to within a particular schema.
This will yield 2-tuples of
.. versionadded:: 1.0.0
+ :param schema: schema name to query, if not the default schema.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
.. seealso::
:meth:`_reflection.Inspector.get_table_names`
with an already-given :class:`_schema.MetaData`.
"""
- with self._operation_context() as conn:
- tnames = self.dialect.get_table_names(
- conn, schema, info_cache=self.info_cache
+
+ return [
+ (
+ table_key[1] if table_key else None,
+ [(tname, fks) for (_, tname), fks in fk_collection],
)
+ for (
+ table_key,
+ fk_collection,
+ ) in self.sort_tables_on_foreign_key_dependency(
+ consider_schemas=(schema,)
+ )
+ ]
- tuples = set()
- remaining_fkcs = set()
+ def sort_tables_on_foreign_key_dependency(
+ self,
+ consider_schemas: Collection[Optional[str]] = (None,),
+ **kw: Any,
+ ) -> List[
+ Tuple[
+ Optional[Tuple[Optional[str], str]],
+ List[Tuple[Tuple[Optional[str], str], Optional[str]]],
+ ]
+ ]:
+ r"""Return dependency-sorted table and foreign key constraint names
+ referred to within multiple schemas.
+
+ This method may be compared to
+ :meth:`.Inspector.get_sorted_table_and_fkc_names`, which
+ works on one schema at a time; here, the method is a generalization
+ that will consider multiple schemas at once, including resolving
+ cross-schema foreign keys.
+
+ .. versionadded:: 2.0
- fknames_for_table = {}
- for tname in tnames:
- fkeys = self.get_foreign_keys(tname, schema)
- fknames_for_table[tname] = set([fk["name"] for fk in fkeys])
- for fkey in fkeys:
- if tname != fkey["referred_table"]:
- tuples.add((fkey["referred_table"], tname))
+ """
+ SchemaTab = Tuple[Optional[str], str]
+
+ tuples: Set[Tuple[SchemaTab, SchemaTab]] = set()
+ remaining_fkcs: Set[Tuple[SchemaTab, Optional[str]]] = set()
+ fknames_for_table: Dict[SchemaTab, Set[Optional[str]]] = {}
+ tnames: List[SchemaTab] = []
+
+ for schname in consider_schemas:
+ schema_fkeys = self.get_multi_foreign_keys(schname, **kw)
+ tnames.extend(schema_fkeys)
+ for (_, tname), fkeys in schema_fkeys.items():
+ fknames_for_table[(schname, tname)] = set(
+ [fk["name"] for fk in fkeys]
+ )
+ for fkey in fkeys:
+ if (
+ tname != fkey["referred_table"]
+ or schname != fkey["referred_schema"]
+ ):
+ tuples.add(
+ (
+ (
+ fkey["referred_schema"],
+ fkey["referred_table"],
+ ),
+ (schname, tname),
+ )
+ )
try:
candidate_sort = list(topological.sort(tuples, tnames))
except exc.CircularDependencyError as err:
+ edge: Tuple[SchemaTab, SchemaTab]
for edge in err.edges:
tuples.remove(edge)
remaining_fkcs.update(
)
candidate_sort = list(topological.sort(tuples, tnames))
- return [
- (tname, fknames_for_table[tname].difference(remaining_fkcs))
- for tname in candidate_sort
- ] + [(None, list(remaining_fkcs))]
+ ret: List[
+ Tuple[Optional[SchemaTab], List[Tuple[SchemaTab, Optional[str]]]]
+ ]
+ ret = [
+ (
+ (schname, tname),
+ [
+ ((schname, tname), fk)
+ for fk in fknames_for_table[(schname, tname)].difference(
+ name for _, name in remaining_fkcs
+ )
+ ],
+ )
+ for (schname, tname) in candidate_sort
+ ]
+ return ret + [(None, list(remaining_fkcs))]
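
A hedged example of consuming the multi-schema sort (schema names are
placeholders; ``insp`` is the Inspector from the earlier sketch). Each entry
pairs a ``(schema, table)`` key with the foreign key constraint names that can
be created together with that table; the final ``(None, ...)`` entry holds
constraints that participate in a dependency cycle and must be added separately:

for table_key, fkcs in insp.sort_tables_on_foreign_key_dependency(
    consider_schemas=("public", "archive")
):
    if table_key is None:
        print("add separately:", [name for _, name in fkcs])
    else:
        schema, table = table_key
        print("create", schema, table, "with", [name for _, name in fkcs])
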
- def get_temp_table_names(self):
- """Return a list of temporary table names for the current bind.
+ def get_temp_table_names(self, **kw: Any) -> List[str]:
+ r"""Return a list of temporary table names for the current bind.
This method is unsupported by most dialects; currently
- only SQLite implements it.
+ only Oracle, PostgreSQL and SQLite implement it.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
.. versionadded:: 1.0.0
with self._operation_context() as conn:
return self.dialect.get_temp_table_names(
- conn, info_cache=self.info_cache
+ conn, info_cache=self.info_cache, **kw
)
- def get_temp_view_names(self):
- """Return a list of temporary view names for the current bind.
+ def get_temp_view_names(self, **kw: Any) -> List[str]:
+ r"""Return a list of temporary view names for the current bind.
This method is unsupported by most dialects; currently
- only SQLite implements it.
+ only PostgreSQL and SQLite implement it.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
.. versionadded:: 1.0.0
"""
with self._operation_context() as conn:
return self.dialect.get_temp_view_names(
- conn, info_cache=self.info_cache
+ conn, info_cache=self.info_cache, **kw
)
- def get_table_options(self, table_name, schema=None, **kw):
- """Return a dictionary of options specified when the table of the
+ def get_table_options(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> Dict[str, Any]:
+ r"""Return a dictionary of options specified when the table of the
given name was created.
- This currently includes some options that apply to MySQL tables.
+ This currently includes some options that apply to MySQL and Oracle
+ tables.
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dict with the table options. The returned keys depend on the
+ dialect in use. Each one is prefixed with the dialect name.
+
"""
- if hasattr(self.dialect, "get_table_options"):
- with self._operation_context() as conn:
- return self.dialect.get_table_options(
- conn, table_name, schema, info_cache=self.info_cache, **kw
- )
- return {}
+ with self._operation_context() as conn:
+ return self.dialect.get_table_options(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_multi_table_options(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, Dict[str, Any]]:
+ r"""Return a dictionary of options specified when the tables in the
+ given schema were created.
+
+ The tables can be filtered by passing the names to use to
+ ``filter_names``.
+
+ This currently includes some options that apply to MySQL and Oracle
+ tables.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if options of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuples ``(schema, table_name)``
+ and the values are dictionaries with the table options.
+ The returned keys in each dict depend on the
+ dialect in use. Each one is prefixed with the dialect name.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+ with self._operation_context() as conn:
+ res = self.dialect.get_multi_table_options(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ return dict(res)
- def get_view_names(self, schema=None):
- """Return all view names in `schema`.
+ def get_view_names(
+ self, schema: Optional[str] = None, **kw: Any
+ ) -> List[str]:
+ r"""Return all non-materialized view names in `schema`.
:param schema: Optional, retrieve names from a non-default schema.
For special quoting, use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+
+ .. versionchanged:: 2.0 For those dialects that previously included
+ the names of materialized views in this list (currently PostgreSQL),
+ this method no longer returns the names of materialized views.
+ The :meth:`.Inspector.get_materialized_view_names` method should
+ be used instead.
+
+ .. seealso::
+
+ :meth:`.Inspector.get_materialized_view_names`
"""
with self._operation_context() as conn:
return self.dialect.get_view_names(
- conn, schema, info_cache=self.info_cache
+ conn, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_materialized_view_names(
+ self, schema: Optional[str] = None, **kw: Any
+ ) -> List[str]:
+ r"""Return all materialized view names in `schema`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ For special quoting, use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.Inspector.get_view_names`
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_materialized_view_names(
+ conn, schema, info_cache=self.info_cache, **kw
)
- def get_sequence_names(self, schema=None):
- """Return all sequence names in `schema`.
+ def get_sequence_names(
+ self, schema: Optional[str] = None, **kw: Any
+ ) -> List[str]:
+ r"""Return all sequence names in `schema`.
:param schema: Optional, retrieve names from a non-default schema.
For special quoting, use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
"""
with self._operation_context() as conn:
return self.dialect.get_sequence_names(
- conn, schema, info_cache=self.info_cache
+ conn, schema, info_cache=self.info_cache, **kw
)
- def get_view_definition(self, view_name, schema=None):
- """Return definition for `view_name`.
+ def get_view_definition(
+ self, view_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> str:
+ r"""Return definition for the plain or materialized view called
+ ``view_name``.
+ :param view_name: Name of the view.
:param schema: Optional, retrieve names from a non-default schema.
For special quoting, use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
"""
with self._operation_context() as conn:
return self.dialect.get_view_definition(
- conn, view_name, schema, info_cache=self.info_cache
+ conn, view_name, schema, info_cache=self.info_cache, **kw
)
def get_columns(
- self, table_name: str, schema: Optional[str] = None, **kw
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
) -> List[ReflectedColumn]:
- """Return information about columns in `table_name`.
+ r"""Return information about columns in ``table_name``.
- Given a string `table_name` and an optional string `schema`, return
- column information as a list of dicts with these keys:
+ Given a string ``table_name`` and an optional string ``schema``,
+ return column information as a list of dicts with these keys:
* ``name`` - the column's name
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
:return: list of dictionaries, each representing the definition of
a database column.
col_defs = self.dialect.get_columns(
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- for col_def in col_defs:
- # make this easy and only return instances for coltype
- coltype = col_def["type"]
- if not isinstance(coltype, TypeEngine):
- col_def["type"] = coltype()
+ if col_defs:
+ self._instantiate_types([col_defs])
return col_defs
- def get_pk_constraint(self, table_name, schema=None, **kw):
- """Return information about primary key constraint on `table_name`.
+ def _instantiate_types(
+ self, data: Iterable[List[ReflectedColumn]]
+ ) -> None:
+ # make this easy and only return instances for coltype
+ for col_defs in data:
+ for col_def in col_defs:
+ coltype = col_def["type"]
+ if not isinstance(coltype, TypeEngine):
+ col_def["type"] = coltype()
+
+ def get_multi_columns(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, List[ReflectedColumn]]:
+ r"""Return information about columns in all objects in the given schema.
+
+ The objects can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The column information is as described in
+ :meth:`Inspector.get_columns`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if columns of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuples ``(schema, table_name)``
+ and the values are lists of dictionaries, each representing the
+ definition of a database column.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ table_col_defs = dict(
+ self.dialect.get_multi_columns(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+ self._instantiate_types(table_col_defs.values())
+ return table_col_defs
+
+ def get_pk_constraint(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> ReflectedPrimaryKeyConstraint:
+ r"""Return information about primary key constraint in ``table_name``.
- Given a string `table_name`, and an optional string `schema`, return
+ Given a string ``table_name``, and an optional string `schema`, return
primary key information as a dictionary with these keys:
* ``constrained_columns`` -
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary representing the definition of
+ a primary key constraint.
+
"""
with self._operation_context() as conn:
return self.dialect.get_pk_constraint(
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- def get_foreign_keys(self, table_name, schema=None, **kw):
- """Return information about foreign_keys in `table_name`.
+ def get_multi_pk_constraint(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, ReflectedPrimaryKeyConstraint]:
+ r"""Return information about primary key constraints in
+ all tables in the given schema.
+
+ The tables can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The primary key information is as described in
+ :meth:`Inspector.get_pk_constraint`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if primary keys of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuples ``(schema, table_name)``
+ and the values are dictionaries, each representing the
+ definition of a primary key constraint.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_pk_constraint(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+
+ def get_foreign_keys(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> List[ReflectedForeignKeyConstraint]:
+ r"""Return information about foreign_keys in ``table_name``.
- Given a string `table_name`, and an optional string `schema`, return
+ Given a string ``table_name``, and an optional string `schema`, return
foreign key information as a list of dicts with these keys:
* ``constrained_columns`` -
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a list of dictionaries, each representing
+ a foreign key definition.
+
"""
with self._operation_context() as conn:
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- def get_indexes(self, table_name, schema=None, **kw):
- """Return information about indexes in `table_name`.
+ def get_multi_foreign_keys(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, List[ReflectedForeignKeyConstraint]]:
+ r"""Return information about foreign_keys in all tables
+ in the given schema.
+
+ The tables can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The foreign key information is as described in
+ :meth:`Inspector.get_foreign_keys`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if foreign keys of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuples ``(schema, table_name)``
+ and the values are lists of dictionaries, each representing
+ a foreign key definition.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_foreign_keys(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
- Given a string `table_name` and an optional string `schema`, return
+ def get_indexes(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> List[ReflectedIndex]:
+ r"""Return information about indexes in ``table_name``.
+
+ Given a string ``table_name`` and an optional string ``schema``, return
index information as a list of dicts with these keys:
* ``name`` -
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword arguments to pass to the dialect-specific
+ implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a list of dictionaries, each representing the
+ definition of an index.
+
"""
with self._operation_context() as conn:
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- def get_unique_constraints(self, table_name, schema=None, **kw):
- """Return information about unique constraints in `table_name`.
+ def get_multi_indexes(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, List[ReflectedIndex]]:
+ r"""Return information about indexes in in all objects
+ in the given schema.
+
+ The objects can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The index information is as described in
+ :meth:`Inspector.get_indexes`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if indexes of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword arguments to pass to the dialect-specific
+ implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuples
+ ``(schema, table_name)`` and the values are lists of dictionaries,
+ each representing the definition of an index.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_indexes(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+
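A hedged sketch of the ``kind`` and ``scope`` parameters, using the ``ObjectKind`` / ``ObjectScope`` flags introduced by this patch (``insp`` is the inspector from the previous sketch)::

    from sqlalchemy.engine import ObjectKind, ObjectScope

    # indexes on tables, views and materialized views alike
    all_indexes = insp.get_multi_indexes(kind=ObjectKind.ANY)

    # indexes on temporary tables only
    tmp_indexes = insp.get_multi_indexes(scope=ObjectScope.TEMPORARY)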
+ def get_unique_constraints(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> List[ReflectedUniqueConstraint]:
+ r"""Return information about unique constraints in ``table_name``.
- Given a string `table_name` and an optional string `schema`, return
+ Given a string ``table_name`` and an optional string ``schema``, return
unique constraint information as a list of dicts with these keys:
* ``name`` -
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword arguments to pass to the dialect-specific
+ implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a list of dictionaries, each representing the
+ definition of a unique constraint.
+
"""
with self._operation_context() as conn:
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- def get_table_comment(self, table_name, schema=None, **kw):
- """Return information about the table comment for ``table_name``.
+ def get_multi_unique_constraints(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, List[ReflectedUniqueConstraint]]:
+ r"""Return information about unique constraints in all tables
+ in the given schema.
+
+ The tables can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The unique constraint information is as described in
+ :meth:`Inspector.get_unique_constraints`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if constraints of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword arguments to pass to the dialect-specific
+ implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuples
+ ``(schema, table_name)`` and the values are lists of dictionaries,
+ each representing the definition of a unique constraint.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_unique_constraints(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+
+ def get_table_comment(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> ReflectedTableComment:
+ r"""Return information about the table comment for ``table_name``.
Given a string ``table_name`` and an optional string ``schema``,
return table comment information as a dictionary with these keys:
Raises ``NotImplementedError`` for a dialect that does not support
comments.
- .. versionadded:: 1.2
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param \**kw: Additional keyword arguments to pass to the dialect-specific
+ implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary with the table comment.
+
+ .. versionadded:: 1.2
"""
with self._operation_context() as conn:
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- def get_check_constraints(self, table_name, schema=None, **kw):
- """Return information about check constraints in `table_name`.
+ def get_multi_table_comment(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, ReflectedTableComment]:
+ r"""Return information about the table comment in all objects
+ in the given schema.
+
+ The objects can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The comment information is as described in
+ :meth:`Inspector.get_table_comment`.
+
+ Raises ``NotImplementedError`` for a dialect that does not support
+ comments.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if comments of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword arguments to pass to the dialect-specific
+ implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuples
+ ``(schema, table_name)`` and the values are dictionaries
+ representing the table comments.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_table_comment(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+
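Comment reflection remains optional, so a defensive call against an arbitrary dialect might be wrapped as in this sketch::

    try:
        comments = insp.get_multi_table_comment()
    except NotImplementedError:
        # dialect does not support table comments
        comments = {}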
+ def get_check_constraints(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> List[ReflectedCheckConstraint]:
+ r"""Return information about check constraints in ``table_name``.
- Given a string `table_name` and an optional string `schema`, return
+ Given a string ``table_name`` and an optional string ``schema``, return
check constraint information as a list of dicts with these keys:
* ``name`` -
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword arguments to pass to the dialect-specific
+ implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a list of dictionaries, each representing the
+ definition of a check constraint.
+
.. versionadded:: 1.1.0
"""
conn, table_name, schema, info_cache=self.info_cache, **kw
)
+ def get_multi_check_constraints(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, List[ReflectedCheckConstraint]]:
+ r"""Return information about check constraints in all tables
+ in the given schema.
+
+ The tables can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The check constraint information is as described in
+ :meth:`Inspector.get_check_constraints`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if constraints of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword arguments to pass to the dialect-specific
+ implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuples
+ ``(schema, table_name)`` and the values are lists of dictionaries,
+ each representing the definition of a check constraint.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_check_constraints(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+
def reflect_table(
self,
- table,
- include_columns,
- exclude_columns=(),
- resolve_fks=True,
- _extend_on=None,
- ):
+ table: sa_schema.Table,
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str] = (),
+ resolve_fks: bool = True,
+ _extend_on: Optional[Set[sa_schema.Table]] = None,
+ _reflect_info: Optional[_ReflectionInfo] = None,
+ ) -> None:
"""Given a :class:`_schema.Table` object, load its internal
constructs based on introspection.
if k in table.dialect_kwargs
)
+ table_key = (schema, table_name)
+ if _reflect_info is None or table_key not in _reflect_info.columns:
+ _reflect_info = self._get_reflection_info(
+ schema,
+ filter_names=[table_name],
+ kind=ObjectKind.ANY,
+ scope=ObjectScope.ANY,
+ _reflect_info=_reflect_info,
+ **table.dialect_kwargs,
+ )
+ if table_key in _reflect_info.unreflectable:
+ raise _reflect_info.unreflectable[table_key]
+
+ if table_key not in _reflect_info.columns:
+ raise exc.NoSuchTableError(table_name)
+
# reflect table options, like mysql_engine
- tbl_opts = self.get_table_options(
- table_name, schema, **table.dialect_kwargs
- )
- if tbl_opts:
- # add additional kwargs to the Table if the dialect
- # returned them
- table._validate_dialect_kwargs(tbl_opts)
+ if _reflect_info.table_options:
+ tbl_opts = _reflect_info.table_options.get(table_key)
+ if tbl_opts:
+ # add additional kwargs to the Table if the dialect
+ # returned them
+ table._validate_dialect_kwargs(tbl_opts)
found_table = False
- cols_by_orig_name = {}
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]] = {}
- for col_d in self.get_columns(
- table_name, schema, **table.dialect_kwargs
- ):
+ for col_d in _reflect_info.columns[table_key]:
found_table = True
self._reflect_column(
raise exc.NoSuchTableError(table_name)
self._reflect_pk(
- table_name, schema, table, cols_by_orig_name, exclude_columns
+ _reflect_info, table_key, table, cols_by_orig_name, exclude_columns
)
self._reflect_fk(
- table_name,
- schema,
+ _reflect_info,
+ table_key,
table,
cols_by_orig_name,
include_columns,
)
self._reflect_indexes(
- table_name,
- schema,
+ _reflect_info,
+ table_key,
table,
cols_by_orig_name,
include_columns,
)
self._reflect_unique_constraints(
- table_name,
- schema,
+ _reflect_info,
+ table_key,
table,
cols_by_orig_name,
include_columns,
)
self._reflect_check_constraints(
- table_name,
- schema,
+ _reflect_info,
+ table_key,
table,
cols_by_orig_name,
include_columns,
)
self._reflect_table_comment(
- table_name, schema, table, reflection_options
+ _reflect_info,
+ table_key,
+ table,
+ reflection_options,
)
def _reflect_column(
- self, table, col_d, include_columns, exclude_columns, cols_by_orig_name
- ):
+ self,
+ table: sa_schema.Table,
+ col_d: ReflectedColumn,
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str],
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ ) -> None:
orig_name = col_d["name"]
table.metadata.dispatch.column_reflect(self, table, col_d)
- table.dispatch.column_reflect(self, table, col_d)
+ table.dispatch.column_reflect( # type: ignore[attr-defined]
+ self, table, col_d
+ )
# fetch name again as column_reflect is allowed to
# change it
coltype = col_d["type"]
col_kw = dict(
- (k, col_d[k])
+ (k, col_d[k]) # type: ignore[literal-required]
for k in [
"nullable",
"autoincrement",
col_kw.update(col_d["dialect_options"])
colargs = []
+ default: Any
if col_d.get("default") is not None:
- default = col_d["default"]
- if isinstance(default, sql.elements.TextClause):
- default = sa_schema.DefaultClause(default, _reflected=True)
- elif not isinstance(default, sa_schema.FetchedValue):
+ default_text = col_d["default"]
+ assert default_text is not None
+ if isinstance(default_text, TextClause):
default = sa_schema.DefaultClause(
- sql.text(col_d["default"]), _reflected=True
+ default_text, _reflected=True
)
-
+ elif not isinstance(default_text, sa_schema.FetchedValue):
+ default = sa_schema.DefaultClause(
+ sql.text(default_text), _reflected=True
+ )
+ else:
+ default = default_text
colargs.append(default)
if "computed" in col_d:
colargs.append(computed)
if "identity" in col_d:
- computed = sa_schema.Identity(**col_d["identity"])
- colargs.append(computed)
-
- if "sequence" in col_d:
- self._reflect_col_sequence(col_d, colargs)
+ identity = sa_schema.Identity(**col_d["identity"])
+ colargs.append(identity)
cols_by_orig_name[orig_name] = col = sa_schema.Column(
name, coltype, *colargs, **col_kw
col.primary_key = True
table.append_column(col, replace_existing=True)
- def _reflect_col_sequence(self, col_d, colargs):
- if "sequence" in col_d:
- # TODO: mssql is using this.
- seq = col_d["sequence"]
- sequence = sa_schema.Sequence(seq["name"], 1, 1)
- if "start" in seq:
- sequence.start = seq["start"]
- if "increment" in seq:
- sequence.increment = seq["increment"]
- colargs.append(sequence)
-
def _reflect_pk(
- self, table_name, schema, table, cols_by_orig_name, exclude_columns
- ):
- pk_cons = self.get_pk_constraint(
- table_name, schema, **table.dialect_kwargs
- )
+ self,
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ exclude_columns: Collection[str],
+ ) -> None:
+ pk_cons = _reflect_info.pk_constraint.get(table_key)
if pk_cons:
pk_cols = [
cols_by_orig_name[pk]
def _reflect_fk(
self,
- table_name,
- schema,
- table,
- cols_by_orig_name,
- include_columns,
- exclude_columns,
- resolve_fks,
- _extend_on,
- reflection_options,
- ):
- fkeys = self.get_foreign_keys(
- table_name, schema, **table.dialect_kwargs
- )
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str],
+ resolve_fks: bool,
+ _extend_on: Optional[Set[sa_schema.Table]],
+ reflection_options: Dict[str, Any],
+ ) -> None:
+ fkeys = _reflect_info.foreign_keys.get(table_key, [])
for fkey_d in fkeys:
conname = fkey_d["name"]
# look for columns by orig name in cols_by_orig_name,
schema=referred_schema,
autoload_with=self.bind,
_extend_on=_extend_on,
+ _reflect_info=_reflect_info,
**reflection_options,
)
for column in referred_columns:
autoload_with=self.bind,
schema=sa_schema.BLANK_SCHEMA,
_extend_on=_extend_on,
+ _reflect_info=_reflect_info,
**reflection_options,
)
for column in referred_columns:
def _reflect_indexes(
self,
- table_name,
- schema,
- table,
- cols_by_orig_name,
- include_columns,
- exclude_columns,
- reflection_options,
- ):
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str],
+ reflection_options: Dict[str, Any],
+ ) -> None:
# Indexes
- indexes = self.get_indexes(table_name, schema)
+ indexes = _reflect_info.indexes.get(table_key, [])
for index_d in indexes:
name = index_d["name"]
columns = index_d["column_names"]
continue
# look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
+ idx_col: Any
idx_cols = []
for c in columns:
try:
except KeyError:
util.warn(
"%s key '%s' was not located in "
- "columns for table '%s'" % (flavor, c, table_name)
+ "columns for table '%s'" % (flavor, c, table.name)
)
continue
c_sorting = column_sorting.get(c, ())
def _reflect_unique_constraints(
self,
- table_name,
- schema,
- table,
- cols_by_orig_name,
- include_columns,
- exclude_columns,
- reflection_options,
- ):
-
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str],
+ reflection_options: Dict[str, Any],
+ ) -> None:
+ constraints = _reflect_info.unique_constraints.get(table_key, [])
# Unique Constraints
- try:
- constraints = self.get_unique_constraints(table_name, schema)
- except NotImplementedError:
- # optional dialect feature
- return
-
for const_d in constraints:
conname = const_d["name"]
columns = const_d["column_names"]
except KeyError:
util.warn(
"unique constraint key '%s' was not located in "
- "columns for table '%s'" % (c, table_name)
+ "columns for table '%s'" % (c, table.name)
)
else:
constrained_cols.append(constrained_col)
def _reflect_check_constraints(
self,
- table_name,
- schema,
- table,
- cols_by_orig_name,
- include_columns,
- exclude_columns,
- reflection_options,
- ):
- try:
- constraints = self.get_check_constraints(table_name, schema)
- except NotImplementedError:
- # optional dialect feature
- return
-
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str],
+ reflection_options: Dict[str, Any],
+ ) -> None:
+ constraints = _reflect_info.check_constraints.get(table_key, [])
for const_d in constraints:
table.append_constraint(sa_schema.CheckConstraint(**const_d))
def _reflect_table_comment(
- self, table_name, schema, table, reflection_options
- ):
- try:
- comment_dict = self.get_table_comment(table_name, schema)
- except NotImplementedError:
- return
+ self,
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ reflection_options: Dict[str, Any],
+ ) -> None:
+ comment_dict = _reflect_info.table_comment.get(table_key)
+ if comment_dict:
+ table.comment = comment_dict["text"]
+
+ def _get_reflection_info(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ available: Optional[Collection[str]] = None,
+ _reflect_info: Optional[_ReflectionInfo] = None,
+ **kw: Any,
+ ) -> _ReflectionInfo:
+ kw["schema"] = schema
+
+ if filter_names and available and len(filter_names) > 100:
+ fraction = len(filter_names) / len(available)
+ else:
+ fraction = None
+
+ unreflectable: Dict[TableKey, exc.UnreflectableTableError]
+ kw["unreflectable"] = unreflectable = {}
+
+ has_result: bool = True
+
+ def run(
+ meth: Any,
+ *,
+ optional: bool = False,
+ check_filter_names_from_meth: bool = False,
+ ) -> Any:
+ nonlocal has_result
+ # simple heuristic to improve reflection performance if a
+ # dialect implements multi_reflection:
+ # if more than 50% of the tables in the db are in filter_names
+ # load all the tables, since it's most likely faster to avoid
+ # a filter on that many tables.
+ if (
+ fraction is None
+ or fraction <= 0.5
+ or not self.dialect._overrides_default(meth.__name__)
+ ):
+ _fn = filter_names
+ else:
+ _fn = None
+ try:
+ if has_result:
+ res = meth(filter_names=_fn, **kw)
+ if check_filter_names_from_meth and not res:
+ # method returned no result data.
+ # skip any future call methods
+ has_result = False
+ else:
+ res = {}
+ except NotImplementedError:
+ if not optional:
+ raise
+ res = {}
+ return res
+
+ info = _ReflectionInfo(
+ columns=run(
+ self.get_multi_columns, check_filter_names_from_meth=True
+ ),
+ pk_constraint=run(self.get_multi_pk_constraint),
+ foreign_keys=run(self.get_multi_foreign_keys),
+ indexes=run(self.get_multi_indexes),
+ unique_constraints=run(
+ self.get_multi_unique_constraints, optional=True
+ ),
+ table_comment=run(self.get_multi_table_comment, optional=True),
+ check_constraints=run(
+ self.get_multi_check_constraints, optional=True
+ ),
+ table_options=run(self.get_multi_table_options, optional=True),
+ unreflectable=unreflectable,
+ )
+ if _reflect_info:
+ _reflect_info.update(info)
+ return _reflect_info
else:
- table.comment = comment_dict.get("text", None)
+ return info
+
+
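The filter heuristic above only engages once more than 100 names are requested and the dialect overrides the multi-reflection method; a standalone numeric sketch of the same decision (the helper name is illustrative only)::

    def use_filter_names(num_requested: int, num_available: int) -> bool:
        # small requests always filter in SQL; once more than half of the
        # database is requested, reflect everything and filter in Python
        if num_requested <= 100:
            return True
        return num_requested / num_available <= 0.5

    assert use_filter_names(40, 1000) is True    # 4% of the db: keep the filter
    assert use_filter_names(600, 1000) is False  # 60% of the db: load all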
+@final
+class ReflectionDefaults:
+ """provides blank default values for reflection methods."""
+
+ @classmethod
+ def columns(cls) -> List[ReflectedColumn]:
+ return []
+
+ @classmethod
+ def pk_constraint(cls) -> ReflectedPrimaryKeyConstraint:
+ return { # type: ignore # pep-655 not supported
+ "name": None,
+ "constrained_columns": [],
+ }
+
+ @classmethod
+ def foreign_keys(cls) -> List[ReflectedForeignKeyConstraint]:
+ return []
+
+ @classmethod
+ def indexes(cls) -> List[ReflectedIndex]:
+ return []
+
+ @classmethod
+ def unique_constraints(cls) -> List[ReflectedUniqueConstraint]:
+ return []
+
+ @classmethod
+ def check_constraints(cls) -> List[ReflectedCheckConstraint]:
+ return []
+
+ @classmethod
+ def table_options(cls) -> Dict[str, Any]:
+ return {}
+
+ @classmethod
+ def table_comment(cls) -> ReflectedTableComment:
+ return {"text": None}
+
+
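``ReflectionDefaults`` centralizes the "blank" return values; a hedged sketch of how a dialect-level method might fall back to it when a table has no primary key (the query helper is hypothetical)::

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        row = self._query_pk_row(connection, table_name, schema)  # hypothetical
        if row is None:
            # no primary key found: return the canonical blank structure
            return ReflectionDefaults.pk_constraint()
        return {"name": row.name, "constrained_columns": list(row.columns)}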
+@dataclass
+class _ReflectionInfo:
+ columns: Dict[TableKey, List[ReflectedColumn]]
+ pk_constraint: Dict[TableKey, Optional[ReflectedPrimaryKeyConstraint]]
+ foreign_keys: Dict[TableKey, List[ReflectedForeignKeyConstraint]]
+ indexes: Dict[TableKey, List[ReflectedIndex]]
+ # optionals
+ unique_constraints: Dict[TableKey, List[ReflectedUniqueConstraint]]
+ table_comment: Dict[TableKey, Optional[ReflectedTableComment]]
+ check_constraints: Dict[TableKey, List[ReflectedCheckConstraint]]
+ table_options: Dict[TableKey, Dict[str, Any]]
+ unreflectable: Dict[TableKey, exc.UnreflectableTableError]
+
+ def update(self, other: _ReflectionInfo) -> None:
+ for k, v in self.__dict__.items():
+ ov = getattr(other, k)
+ if ov is not None:
+ if v is None:
+ setattr(self, k, ov)
+ else:
+ v.update(ov)
util.portable_instancemethod(self._kw_reg_for_dialect_cls)
)
- def _validate_dialect_kwargs(self, kwargs: Any) -> None:
+ def _validate_dialect_kwargs(self, kwargs: Dict[str, Any]) -> None:
# validate remaining kwargs that they all specify DB prefixes
if not kwargs:
import typing
from typing import Any
from typing import Dict
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import MutableMapping
return target_element.params(translate)
+def _ad_hoc_cache_key_from_args(
+ tokens: Tuple[Any, ...],
+ traverse_args: Iterable[Tuple[str, InternalTraversal]],
+ args: Iterable[Any],
+) -> Tuple[Any, ...]:
+ """a quick cache key generator used by reflection.flexi_cache."""
+ bindparams: List[BindParameter[Any]] = []
+
+ _anon_map = anon_map()
+
+ tup = tokens
+
+ for (attrname, sym), arg in zip(traverse_args, args):
+ key = sym.name
+ visit_key = key.replace("dp_", "visit_")
+
+ if arg is None:
+ tup += (attrname, None)
+ continue
+
+ meth = getattr(_cache_key_traversal_visitor, visit_key)
+ if meth is CACHE_IN_PLACE:
+ tup += (attrname, arg)
+ elif meth in (
+ CALL_GEN_CACHE_KEY,
+ STATIC_CACHE_KEY,
+ ANON_NAME,
+ PROPAGATE_ATTRS,
+ ):
+ raise NotImplementedError(
+ f"Haven't implemented symbol {meth} for ad-hoc key from args"
+ )
+ else:
+ tup += meth(attrname, arg, None, _anon_map, bindparams)
+ return tup
+
+
class _CacheKeyTraversal(HasTraversalDispatch):
# very common elements are inlined into the main _get_cache_key() method
# to produce a dramatic savings in Python function call overhead
from typing import Any
from typing import Callable
from typing import cast
+from typing import Collection
from typing import Dict
from typing import Iterable
from typing import Iterator
from ..engine.interfaces import _ExecuteOptionsParameter
from ..engine.interfaces import ExecutionContext
from ..engine.mock import MockConnection
+ from ..engine.reflection import _ReflectionInfo
from ..sql.selectable import FromClause
_T = TypeVar("_T", bound="Any")
keep_existing: bool = False,
extend_existing: bool = False,
resolve_fks: bool = True,
- include_columns: Optional[Iterable[str]] = None,
+ include_columns: Optional[Collection[str]] = None,
implicit_returning: bool = True,
comment: Optional[str] = None,
info: Optional[Dict[Any, Any]] = None,
self.fullname = self.name
self.implicit_returning = implicit_returning
+ _reflect_info = kw.pop("_reflect_info", None)
self.comment = comment
autoload_with,
include_columns,
_extend_on=_extend_on,
+ _reflect_info=_reflect_info,
resolve_fks=resolve_fks,
)
self,
metadata: MetaData,
autoload_with: Union[Engine, Connection],
- include_columns: Optional[Iterable[str]],
- exclude_columns: Iterable[str] = (),
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str] = (),
resolve_fks: bool = True,
_extend_on: Optional[Set[Table]] = None,
+ _reflect_info: _ReflectionInfo | None = None,
) -> None:
insp = inspection.inspect(autoload_with)
with insp._inspection_context() as conn_insp:
exclude_columns,
resolve_fks,
_extend_on=_extend_on,
+ _reflect_info=_reflect_info,
)
@property
autoload_replace = kwargs.pop("autoload_replace", True)
schema = kwargs.pop("schema", None)
_extend_on = kwargs.pop("_extend_on", None)
+ _reflect_info = kwargs.pop("_reflect_info", None)
# these arguments are only used with _init()
kwargs.pop("extend_existing", False)
kwargs.pop("keep_existing", False)
exclude_columns,
resolve_fks,
_extend_on=_extend_on,
+ _reflect_info=_reflect_info,
)
self._extra_kwargs(**kwargs)
nominvalue: Optional[bool] = None,
nomaxvalue: Optional[bool] = None,
cycle: Optional[bool] = None,
- cache: Optional[bool] = None,
+ cache: Optional[int] = None,
order: Optional[bool] = None,
) -> None:
"""Construct a :class:`.IdentityOptions` object.
sorted(self.tables.values(), key=lambda t: t.key) # type: ignore
)
+ @util.preload_module("sqlalchemy.engine.reflection")
def reflect(
self,
bind: Union[Engine, Connection],
is used, if any.
:param views:
- If True, also reflect views.
+ If True, also reflect views (materialized and plain).
:param only:
Optional. Load only a sub-set of available named tables. May be
"""
with inspection.inspect(bind)._inspection_context() as insp:
- reflect_opts = {
+ reflect_opts: Any = {
"autoload_with": insp,
"extend_existing": extend_existing,
"autoload_replace": autoload_replace,
if schema is not None:
reflect_opts["schema"] = schema
+ kind = util.preloaded.engine_reflection.ObjectKind.TABLE
available: util.OrderedSet[str] = util.OrderedSet(
insp.get_table_names(schema)
)
if views:
+ kind = util.preloaded.engine_reflection.ObjectKind.ANY
available.update(insp.get_view_names(schema))
+ try:
+ available.update(insp.get_materialized_view_names(schema))
+ except NotImplementedError:
+ pass
if schema is not None:
available_w_schema: util.OrderedSet[str] = util.OrderedSet(
- ["%s.%s" % (schema, name) for name in available]
+ [f"{schema}.{name}" for name in available]
)
else:
available_w_schema = available
for name in only
if extend_existing or name not in current
]
+ # pass the available tables so the inspector can
+ # choose to ignore the filter_names
+ _reflect_info = insp._get_reflection_info(
+ schema=schema,
+ filter_names=load,
+ available=available,
+ kind=kind,
+ scope=util.preloaded.engine_reflection.ObjectScope.ANY,
+ **dialect_kwargs,
+ )
+ reflect_opts["_reflect_info"] = _reflect_info
for name in load:
try:
nominvalue: Optional[bool] = None,
nomaxvalue: Optional[bool] = None,
cycle: Optional[bool] = None,
- cache: Optional[bool] = None,
+ cache: Optional[int] = None,
order: Optional[bool] = None,
) -> None:
"""Construct a GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY DDL
class ComparesTables:
- def assert_tables_equal(self, table, reflected_table, strict_types=False):
+ def assert_tables_equal(
+ self,
+ table,
+ reflected_table,
+ strict_types=False,
+ strict_constraints=True,
+ ):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
eq_(c.name, reflected_c.name)
assert reflected_c is reflected_table.c[c.name]
- eq_(c.primary_key, reflected_c.primary_key)
- eq_(c.nullable, reflected_c.nullable)
+
+ if strict_constraints:
+ eq_(c.primary_key, reflected_c.primary_key)
+ eq_(c.nullable, reflected_c.nullable)
if strict_types:
msg = "Type '%s' doesn't correspond to type '%s'"
if isinstance(c.type, sqltypes.String):
eq_(c.type.length, reflected_c.type.length)
- eq_(
- {f.column.name for f in c.foreign_keys},
- {f.column.name for f in reflected_c.foreign_keys},
- )
+ if strict_constraints:
+ eq_(
+ {f.column.name for f in c.foreign_keys},
+ {f.column.name for f in reflected_c.foreign_keys},
+ )
if c.server_default:
assert isinstance(
reflected_c.server_default, schema.FetchedValue
)
- assert len(table.primary_key) == len(reflected_table.primary_key)
- for c in table.primary_key:
- assert reflected_table.primary_key.columns[c.name] is not None
+ if strict_constraints:
+ assert len(table.primary_key) == len(reflected_table.primary_key)
+ for c in table.primary_key:
+ assert reflected_table.primary_key.columns[c.name] is not None
def assert_types_base(self, c1, c2):
assert c1.type._compare_type_affinity(
fn._sa_parametrize.append((argnames, pytest_params))
return fn
else:
+ _fn_argnames = inspect.getfullargspec(fn).args[1:]
if argnames is None:
- _argnames = inspect.getfullargspec(fn).args[1:]
+ _argnames = _fn_argnames
else:
_argnames = re.split(r", *", argnames)
if has_exclusions:
- _argnames += ["_exclusions"]
+ existing_exl = sum(
+ 1 for n in _fn_argnames if n.startswith("_exclusions")
+ )
+ current_exclusion_name = f"_exclusions_{existing_exl}"
+ _argnames += [current_exclusion_name]
@_pytest_fn_decorator
def check_exclusions(fn, *args, **kw):
if _exclusions:
exlu = exclusions.compound().add(*_exclusions)
fn = exlu(fn)
- return fn(*args[0:-1], **kw)
-
- def process_metadata(spec):
- spec.args.append("_exclusions")
+ return fn(*args[:-1], **kw)
fn = check_exclusions(
- fn, add_positional_parameters=("_exclusions",)
+ fn, add_positional_parameters=(current_exclusion_name,)
)
return pytest.mark.parametrize(_argnames, pytest_params)(fn)
drop_all_schema_objects_pre_tables(cfg, eng)
+ drop_views(cfg, eng)
+
+ if config.requirements.materialized_views.enabled:
+ drop_materialized_views(cfg, eng)
+
inspector = inspect(eng)
+
+ consider_schemas = (None,)
+ if config.requirements.schemas.enabled_for_config(cfg):
+ consider_schemas += (cfg.test_schema, cfg.test_schema_2)
+ util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas)
+
+ drop_all_schema_objects_post_tables(cfg, eng)
+
+ if config.requirements.sequences.enabled_for_config(cfg):
+ with eng.begin() as conn:
+ for seq in inspector.get_sequence_names():
+ conn.execute(ddl.DropSequence(schema.Sequence(seq)))
+ if config.requirements.schemas.enabled_for_config(cfg):
+ for schema_name in [cfg.test_schema, cfg.test_schema_2]:
+ for seq in inspector.get_sequence_names(
+ schema=schema_name
+ ):
+ conn.execute(
+ ddl.DropSequence(
+ schema.Sequence(seq, schema=schema_name)
+ )
+ )
+
+
+def drop_views(cfg, eng):
+ inspector = inspect(eng)
+
try:
view_names = inspector.get_view_names()
except NotImplementedError:
if config.requirements.schemas.enabled_for_config(cfg):
try:
- view_names = inspector.get_view_names(schema="test_schema")
+ view_names = inspector.get_view_names(schema=cfg.test_schema)
except NotImplementedError:
pass
else:
schema.Table(
vname,
schema.MetaData(),
- schema="test_schema",
+ schema=cfg.test_schema,
)
)
)
- util.drop_all_tables(eng, inspector)
- if config.requirements.schemas.enabled_for_config(cfg):
- util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
- util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
- drop_all_schema_objects_post_tables(cfg, eng)
+def drop_materialized_views(cfg, eng):
+ inspector = inspect(eng)
- if config.requirements.sequences.enabled_for_config(cfg):
+ mview_names = inspector.get_materialized_view_names()
+
+ with eng.begin() as conn:
+ for vname in mview_names:
+ conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}")
+
+ if config.requirements.schemas.enabled_for_config(cfg):
+ mview_names = inspector.get_materialized_view_names(
+ schema=cfg.test_schema
+ )
with eng.begin() as conn:
- for seq in inspector.get_sequence_names():
- conn.execute(ddl.DropSequence(schema.Sequence(seq)))
- if config.requirements.schemas.enabled_for_config(cfg):
- for schema_name in [cfg.test_schema, cfg.test_schema_2]:
- for seq in inspector.get_sequence_names(
- schema=schema_name
- ):
- conn.execute(
- ddl.DropSequence(
- schema.Sequence(seq, schema=schema_name)
- )
- )
+ for vname in mview_names:
+ conn.exec_driver_sql(
+ f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}"
+ )
@register.init
return exclusions.open()
+ @property
+ def foreign_keys_reflect_as_index(self):
+ """Target database creates an index that's reflected for
+ foreign keys."""
+
+ return exclusions.closed()
+
+ @property
+ def unique_index_reflect_as_unique_constraints(self):
+ """Target database reflects unique indexes as unique constrains."""
+
+ return exclusions.closed()
+
+ @property
+ def unique_constraints_reflect_as_index(self):
+ """Target database reflects unique constraints as indexes."""
+
+ return exclusions.closed()
+
@property
def table_value_constructor(self):
"""Database / dialect supports a query like::
def schema_reflection(self):
return self.schemas
+ @property
+ def schema_create_delete(self):
+ """target database supports schema create and dropped with
+ 'CREATE SCHEMA' and 'DROP SCHEMA'"""
+ return exclusions.closed()
+
@property
def primary_key_constraint_reflection(self):
return exclusions.open()
"""target database supports CREATE INDEX with per-column ASC/DESC."""
return exclusions.open()
+ @property
+ def reflect_indexes_with_ascdesc(self):
+ """target database supports reflecting INDEX with per-column
+ ASC/DESC."""
+ return exclusions.open()
+
@property
def indexes_with_expressions(self):
"""target database supports CREATE INDEX against SQL expressions."""
def json_deserializer_binary(self):
"indicates if the json_deserializer function is called with bytes"
return exclusions.closed()
+
+ @property
+ def reflect_table_options(self):
+ """Target database must support reflecting table_options."""
+ return exclusions.closed()
+
+ @property
+ def materialized_views(self):
+ """Target database must support MATERIALIZED VIEWs."""
+ return exclusions.closed()
+
+ @property
+ def materialized_views_reflect_pk(self):
+ """Target database reflect MATERIALIZED VIEWs pks."""
+ return exclusions.closed()
table_options = {}
-def Table(*args, **kw):
+def Table(*args, **kw) -> schema.Table:
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
return self.target._type_affinity is not other._type_affinity
+class eq_compile_type:
+ """similar to eq_type_affinity but uses compile"""
+
+ def __init__(self, target):
+ self.target = target
+
+ def __eq__(self, other):
+ return self.target == other.compile()
+
+ def __ne__(self, other):
+ return self.target != other.compile()
+
+
class eq_clause_element:
"""Helper to compare SQL structures based on compare()"""
from .. import config
from .. import engines
from .. import eq_
+from .. import expect_raises
+from .. import expect_raises_message
from .. import expect_warnings
from .. import fixtures
from .. import is_
from ... import String
from ... import testing
from ... import types as sql_types
+from ...engine import Inspector
+from ...engine import ObjectKind
+from ...engine import ObjectScope
+from ...exc import NoSuchTableError
+from ...exc import UnreflectableTableError
from ...schema import DDL
from ...schema import Index
from ...sql.elements import quoted_name
from ...sql.schema import BLANK_SCHEMA
+from ...testing import ComparesTables
from ...testing import is_false
from ...testing import is_true
+from ...testing import mock
metadata, users = None, None
is_false(config.db.dialect.has_table(conn, "test_table_s"))
is_false(config.db.dialect.has_table(conn, "nonexistent_table"))
+ def test_has_table_cache(self, metadata):
+ insp = inspect(config.db)
+ is_true(insp.has_table("test_table"))
+ nt = Table("new_table", metadata, Column("col", Integer))
+ is_false(insp.has_table("new_table"))
+ nt.create(config.db)
+ try:
+ is_false(insp.has_table("new_table"))
+ insp.clear_cache()
+ is_true(insp.has_table("new_table"))
+ finally:
+ nt.drop(config.db)
+
@testing.requires.schemas
def test_has_table_schema(self):
with config.db.begin() as conn:
metadata,
Column("id", Integer, primary_key=True),
Column("data", String(50)),
+ Column("data2", String(50)),
)
Index("my_idx", tt.c.data)
)
Index("my_idx_s", tt.c.data)
- def test_has_index(self):
- with config.db.begin() as conn:
- assert config.db.dialect.has_index(conn, "test_table", "my_idx")
- assert not config.db.dialect.has_index(
- conn, "test_table", "my_idx_s"
- )
- assert not config.db.dialect.has_index(
- conn, "nonexistent_table", "my_idx"
- )
- assert not config.db.dialect.has_index(
- conn, "test_table", "nonexistent_idx"
- )
+ kind = testing.combinations("dialect", "inspector", argnames="kind")
+
+ def _has_index(self, kind, conn):
+ if kind == "dialect":
+ return lambda *a, **k: config.db.dialect.has_index(conn, *a, **k)
+ else:
+ return inspect(conn).has_index
+
+ @kind
+ def test_has_index(self, kind, connection, metadata):
+ meth = self._has_index(kind, connection)
+ assert meth("test_table", "my_idx")
+ assert not meth("test_table", "my_idx_s")
+ assert not meth("nonexistent_table", "my_idx")
+ assert not meth("test_table", "nonexistent_idx")
+
+ assert not meth("test_table", "my_idx_2")
+ assert not meth("test_table_2", "my_idx_3")
+ idx = Index("my_idx_2", self.tables.test_table.c.data2)
+ tbl = Table(
+ "test_table_2",
+ metadata,
+ Column("foo", Integer),
+ Index("my_idx_3", "foo"),
+ )
+ idx.create(connection)
+ tbl.create(connection)
+ try:
+ if kind == "inspector":
+ assert not meth("test_table", "my_idx_2")
+ assert not meth("test_table_2", "my_idx_3")
+ meth.__self__.clear_cache()
+ assert meth("test_table", "my_idx_2") is True
+ assert meth("test_table_2", "my_idx_3") is True
+ finally:
+ tbl.drop(connection)
+ idx.drop(connection)
@testing.requires.schemas
- def test_has_index_schema(self):
- with config.db.begin() as conn:
- assert config.db.dialect.has_index(
- conn, "test_table", "my_idx_s", schema=config.test_schema
- )
- assert not config.db.dialect.has_index(
- conn, "test_table", "my_idx", schema=config.test_schema
- )
- assert not config.db.dialect.has_index(
- conn,
- "nonexistent_table",
- "my_idx_s",
- schema=config.test_schema,
- )
- assert not config.db.dialect.has_index(
- conn,
- "test_table",
- "nonexistent_idx_s",
- schema=config.test_schema,
- )
+ @kind
+ def test_has_index_schema(self, kind, connection):
+ meth = self._has_index(kind, connection)
+ assert meth("test_table", "my_idx_s", schema=config.test_schema)
+ assert not meth("test_table", "my_idx", schema=config.test_schema)
+ assert not meth(
+ "nonexistent_table", "my_idx_s", schema=config.test_schema
+ )
+ assert not meth(
+ "test_table", "nonexistent_idx_s", schema=config.test_schema
+ )
class QuotedNameArgumentTest(fixtures.TablesTest):
def test_get_table_options(self, name):
insp = inspect(config.db)
- insp.get_table_options(name)
+ if testing.requires.reflect_table_options.enabled:
+ res = insp.get_table_options(name)
+ is_true(isinstance(res, dict))
+ else:
+ with expect_raises(NotImplementedError):
+ res = insp.get_table_options(name)
@quote_fixtures
@testing.requires.view_column_reflection
assert insp.get_check_constraints(name)
-class ComponentReflectionTest(fixtures.TablesTest):
+def _multi_combination(fn):
+ schema = testing.combinations(
+ None,
+ (
+ lambda: config.test_schema,
+ testing.requires.schemas,
+ ),
+ argnames="schema",
+ )
+ scope = testing.combinations(
+ ObjectScope.DEFAULT,
+ ObjectScope.TEMPORARY,
+ ObjectScope.ANY,
+ argnames="scope",
+ )
+ kind = testing.combinations(
+ ObjectKind.TABLE,
+ ObjectKind.VIEW,
+ ObjectKind.MATERIALIZED_VIEW,
+ ObjectKind.ANY,
+ ObjectKind.ANY_VIEW,
+ ObjectKind.TABLE | ObjectKind.VIEW,
+ ObjectKind.TABLE | ObjectKind.MATERIALIZED_VIEW,
+ argnames="kind",
+ )
+ filter_names = testing.combinations(True, False, argnames="use_filter")
+
+ return schema(scope(kind(filter_names(fn))))
+
+
+class ComponentReflectionTest(ComparesTables, fixtures.TablesTest):
run_inserts = run_deletes = None
__backend__ = True
"%susers.user_id" % schema_prefix, name="user_id_fk"
),
),
+ sa.CheckConstraint("test2 > 0", name="test2_gt_zero"),
schema=schema,
test_needs_fk=True,
)
Column("user_id", sa.INT, primary_key=True),
Column("test1", sa.CHAR(5), nullable=False),
Column("test2", sa.Float(), nullable=False),
+ Column("parent_user_id", sa.Integer),
+ sa.CheckConstraint("test2 > 0", name="test2_gt_zero"),
schema=schema,
test_needs_fk=True,
)
Column(
"address_id",
sa.Integer,
- sa.ForeignKey("%semail_addresses.address_id" % schema_prefix),
+ sa.ForeignKey(
+ "%semail_addresses.address_id" % schema_prefix,
+ name="email_add_id_fg",
+ ),
+ ),
+ Column("data", sa.String(30), unique=True),
+ sa.CheckConstraint(
+ "address_id > 0 AND address_id < 1000",
+ name="address_id_gt_zero",
+ ),
+ sa.UniqueConstraint(
+ "address_id", "dingaling_id", name="zz_dingalings_multiple"
),
- Column("data", sa.String(30)),
schema=schema,
test_needs_fk=True,
)
Column(
"remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id)
),
- Column("email_address", sa.String(20)),
+ Column("email_address", sa.String(20), index=True),
sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"),
schema=schema,
test_needs_fk=True,
schema=schema,
comment=r"""the test % ' " \ table comment""",
)
+ Table(
+ "no_constraints",
+ metadata,
+ Column("data", sa.String(20)),
+ schema=schema,
+ )
if testing.requires.cross_schema_fk_reflection.enabled:
if schema is None:
)
if testing.requires.index_reflection.enabled:
- cls.define_index(metadata, users)
+ Index("users_t_idx", users.c.test1, users.c.test2, unique=True)
+ Index(
+ "users_all_idx", users.c.user_id, users.c.test2, users.c.test1
+ )
if not schema:
# test_needs_fk is at the moment to force MySQL InnoDB
test_needs_fk=True,
)
- if testing.requires.indexes_with_ascdesc.enabled:
+ if (
+ testing.requires.indexes_with_ascdesc.enabled
+ and testing.requires.reflect_indexes_with_ascdesc.enabled
+ ):
Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc())
Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc())
if not schema and testing.requires.temp_table_reflection.enabled:
cls.define_temp_tables(metadata)
+ @classmethod
+ def temp_table_name(cls):
+ return get_temp_table_name(
+ config, config.db, f"user_tmp_{config.ident}"
+ )
+
@classmethod
def define_temp_tables(cls, metadata):
kw = temp_table_keyword_args(config, config.db)
- table_name = get_temp_table_name(
- config, config.db, "user_tmp_%s" % config.ident
- )
+ table_name = cls.temp_table_name()
user_tmp = Table(
table_name,
metadata,
# unique constraints created against temp tables in different
# databases.
# https://www.arbinada.com/en/node/1645
- sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident),
+ sa.UniqueConstraint("name", name=f"user_tmp_uq_{config.ident}"),
sa.Index("user_tmp_ix", "foo"),
**kw,
)
)
event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v"))
- @classmethod
- def define_index(cls, metadata, users):
- Index("users_t_idx", users.c.test1, users.c.test2)
- Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1)
-
@classmethod
def define_views(cls, metadata, schema):
- for table_name in ("users", "email_addresses"):
+ if testing.requires.materialized_views.enabled:
+ materialized = {"dingalings"}
+ else:
+ materialized = set()
+ for table_name in ("users", "email_addresses", "dingalings"):
fullname = table_name
if schema:
- fullname = "%s.%s" % (schema, table_name)
+ fullname = f"{schema}.{table_name}"
view_name = fullname + "_v"
- query = "CREATE VIEW %s AS SELECT * FROM %s" % (
- view_name,
- fullname,
+ prefix = "MATERIALIZED " if table_name in materialized else ""
+ query = (
+ f"CREATE {prefix}VIEW {view_name} AS SELECT * FROM {fullname}"
)
event.listen(metadata, "after_create", DDL(query))
+ if table_name in materialized:
+ index_name = "mat_index"
+ if schema and testing.against("oracle"):
+ index_name = f"{schema}.{index_name}"
+ idx = f"CREATE INDEX {index_name} ON {view_name}(data)"
+ event.listen(metadata, "after_create", DDL(idx))
event.listen(
- metadata, "before_drop", DDL("DROP VIEW %s" % view_name)
+ metadata, "before_drop", DDL(f"DROP {prefix}VIEW {view_name}")
+ )
+
+ def _resolve_kind(self, kind, tables, views, materialized):
+ res = {}
+ if ObjectKind.TABLE in kind:
+ res.update(tables)
+ if ObjectKind.VIEW in kind:
+ res.update(views)
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ res.update(materialized)
+ return res
+
+ def _resolve_views(self, views, materialized):
+ if not testing.requires.view_column_reflection.enabled:
+ materialized.clear()
+ views.clear()
+ elif not testing.requires.materialized_views.enabled:
+ views.update(materialized)
+ materialized.clear()
+
+ def _resolve_names(self, schema, scope, filter_names, values):
+ scope_filter = lambda _: True # noqa: E731
+ if scope is ObjectScope.DEFAULT:
+ scope_filter = lambda k: "tmp" not in k[1] # noqa: E731
+ if scope is ObjectScope.TEMPORARY:
+ scope_filter = lambda k: "tmp" in k[1] # noqa: E731
+
+ removed = {
+ None: {"remote_table", "remote_table_2"},
+ testing.config.test_schema: {
+ "local_table",
+ "noncol_idx_test_nopk",
+ "noncol_idx_test_pk",
+ "user_tmp_v",
+ self.temp_table_name(),
+ },
+ }
+ if not testing.requires.cross_schema_fk_reflection.enabled:
+ removed[None].add("local_table")
+ removed[testing.config.test_schema].update(
+ ["remote_table", "remote_table_2"]
+ )
+ if not testing.requires.index_reflection.enabled:
+ removed[None].update(
+ ["noncol_idx_test_nopk", "noncol_idx_test_pk"]
)
+ if (
+ not testing.requires.temp_table_reflection.enabled
+ or not testing.requires.temp_table_names.enabled
+ ):
+ removed[None].update(["user_tmp_v", self.temp_table_name()])
+ if not testing.requires.temporary_views.enabled:
+ removed[None].update(["user_tmp_v"])
+
+ res = {
+ k: v
+ for k, v in values.items()
+ if scope_filter(k)
+ and k[1] not in removed[schema]
+ and (not filter_names or k[1] in filter_names)
+ }
+ return res
+
+ def exp_options(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ materialized = {(schema, "dingalings_v"): mock.ANY}
+ views = {
+ (schema, "email_addresses_v"): mock.ANY,
+ (schema, "users_v"): mock.ANY,
+ (schema, "user_tmp_v"): mock.ANY,
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): mock.ANY,
+ (schema, "dingalings"): mock.ANY,
+ (schema, "email_addresses"): mock.ANY,
+ (schema, "comment_test"): mock.ANY,
+ (schema, "no_constraints"): mock.ANY,
+ (schema, "local_table"): mock.ANY,
+ (schema, "remote_table"): mock.ANY,
+ (schema, "remote_table_2"): mock.ANY,
+ (schema, "noncol_idx_test_nopk"): mock.ANY,
+ (schema, "noncol_idx_test_pk"): mock.ANY,
+ (schema, self.temp_table_name()): mock.ANY,
+ }
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ def exp_comments(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ empty = {"text": None}
+ materialized = {(schema, "dingalings_v"): empty}
+ views = {
+ (schema, "email_addresses_v"): empty,
+ (schema, "users_v"): empty,
+ (schema, "user_tmp_v"): empty,
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): empty,
+ (schema, "dingalings"): empty,
+ (schema, "email_addresses"): empty,
+ (schema, "comment_test"): {
+ "text": r"""the test % ' " \ table comment"""
+ },
+ (schema, "no_constraints"): empty,
+ (schema, "local_table"): empty,
+ (schema, "remote_table"): empty,
+ (schema, "remote_table_2"): empty,
+ (schema, "noncol_idx_test_nopk"): empty,
+ (schema, "noncol_idx_test_pk"): empty,
+ (schema, self.temp_table_name()): empty,
+ }
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ def exp_columns(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ def col(
+ name, auto=False, default=mock.ANY, comment=None, nullable=True
+ ):
+ res = {
+ "name": name,
+ "autoincrement": auto,
+ "type": mock.ANY,
+ "default": default,
+ "comment": comment,
+ "nullable": nullable,
+ }
+ if auto == "omit":
+ res.pop("autoincrement")
+ return res
+
+ def pk(name, **kw):
+ kw = {"auto": True, "default": mock.ANY, "nullable": False, **kw}
+ return col(name, **kw)
+
+ materialized = {
+ (schema, "dingalings_v"): [
+ col("dingaling_id", auto="omit", nullable=mock.ANY),
+ col("address_id"),
+ col("data"),
+ ]
+ }
+ views = {
+ (schema, "email_addresses_v"): [
+ col("address_id", auto="omit", nullable=mock.ANY),
+ col("remote_user_id"),
+ col("email_address"),
+ ],
+ (schema, "users_v"): [
+ col("user_id", auto="omit", nullable=mock.ANY),
+ col("test1", nullable=mock.ANY),
+ col("test2", nullable=mock.ANY),
+ col("parent_user_id"),
+ ],
+ (schema, "user_tmp_v"): [
+ col("id", auto="omit", nullable=mock.ANY),
+ col("name"),
+ col("foo"),
+ ],
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): [
+ pk("user_id"),
+ col("test1", nullable=False),
+ col("test2", nullable=False),
+ col("parent_user_id"),
+ ],
+ (schema, "dingalings"): [
+ pk("dingaling_id"),
+ col("address_id"),
+ col("data"),
+ ],
+ (schema, "email_addresses"): [
+ pk("address_id"),
+ col("remote_user_id"),
+ col("email_address"),
+ ],
+ (schema, "comment_test"): [
+ pk("id", comment="id comment"),
+ col("data", comment="data % comment"),
+ col(
+ "d2",
+ comment=r"""Comment types type speedily ' " \ '' Fun!""",
+ ),
+ ],
+ (schema, "no_constraints"): [col("data")],
+ (schema, "local_table"): [pk("id"), col("data"), col("remote_id")],
+ (schema, "remote_table"): [pk("id"), col("local_id"), col("data")],
+ (schema, "remote_table_2"): [pk("id"), col("data")],
+ (schema, "noncol_idx_test_nopk"): [col("q")],
+ (schema, "noncol_idx_test_pk"): [pk("id"), col("q")],
+ (schema, self.temp_table_name()): [
+ pk("id"),
+ col("name"),
+ col("foo"),
+ ],
+ }
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_column_keys(self):
+ return {"name", "type", "nullable", "default"}
+
+ def exp_pks(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ def pk(*cols, name=mock.ANY):
+ return {"constrained_columns": list(cols), "name": name}
+
+ empty = pk(name=None)
+ if testing.requires.materialized_views_reflect_pk.enabled:
+ materialized = {(schema, "dingalings_v"): pk("dingaling_id")}
+ else:
+ materialized = {(schema, "dingalings_v"): empty}
+ views = {
+ (schema, "email_addresses_v"): empty,
+ (schema, "users_v"): empty,
+ (schema, "user_tmp_v"): empty,
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): pk("user_id"),
+ (schema, "dingalings"): pk("dingaling_id"),
+ (schema, "email_addresses"): pk("address_id", name="email_ad_pk"),
+ (schema, "comment_test"): pk("id"),
+ (schema, "no_constraints"): empty,
+ (schema, "local_table"): pk("id"),
+ (schema, "remote_table"): pk("id"),
+ (schema, "remote_table_2"): pk("id"),
+ (schema, "noncol_idx_test_nopk"): empty,
+ (schema, "noncol_idx_test_pk"): pk("id"),
+ (schema, self.temp_table_name()): pk("id"),
+ }
+ if not testing.requires.reflects_pk_names.enabled:
+ for val in tables.values():
+ if val["name"] is not None:
+ val["name"] = mock.ANY
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_pk_keys(self):
+ return {"name", "constrained_columns"}
+
+ def exp_fks(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ class tt:
+ def __eq__(self, other):
+ return (
+ other is None
+ or config.db.dialect.default_schema_name == other
+ )
+
+ def fk(cols, ref_col, ref_table, ref_schema=schema, name=mock.ANY):
+ return {
+ "constrained_columns": cols,
+ "referred_columns": ref_col,
+ "name": name,
+ "options": mock.ANY,
+ "referred_schema": ref_schema
+ if ref_schema is not None
+ else tt(),
+ "referred_table": ref_table,
+ }
+
+ materialized = {(schema, "dingalings_v"): []}
+ views = {
+ (schema, "email_addresses_v"): [],
+ (schema, "users_v"): [],
+ (schema, "user_tmp_v"): [],
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): [
+ fk(["parent_user_id"], ["user_id"], "users", name="user_id_fk")
+ ],
+ (schema, "dingalings"): [
+ fk(
+ ["address_id"],
+ ["address_id"],
+ "email_addresses",
+ name="email_add_id_fg",
+ )
+ ],
+ (schema, "email_addresses"): [
+ fk(["remote_user_id"], ["user_id"], "users")
+ ],
+ (schema, "comment_test"): [],
+ (schema, "no_constraints"): [],
+ (schema, "local_table"): [
+ fk(
+ ["remote_id"],
+ ["id"],
+ "remote_table_2",
+ ref_schema=config.test_schema,
+ )
+ ],
+ (schema, "remote_table"): [
+ fk(["local_id"], ["id"], "local_table", ref_schema=None)
+ ],
+ (schema, "remote_table_2"): [],
+ (schema, "noncol_idx_test_nopk"): [],
+ (schema, "noncol_idx_test_pk"): [],
+ (schema, self.temp_table_name()): [],
+ }
+ if not testing.requires.self_referential_foreign_keys.enabled:
+ tables[(schema, "users")].clear()
+ if not testing.requires.named_constraints.enabled:
+ for vals in tables.values():
+ for val in vals:
+ if val["name"] is not mock.ANY:
+ val["name"] = mock.ANY
+
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_fk_keys(self):
+ return {
+ "name",
+ "constrained_columns",
+ "referred_schema",
+ "referred_table",
+ "referred_columns",
+ }
+
+ def exp_indexes(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ def idx(
+ *cols,
+ name,
+ unique=False,
+ column_sorting=None,
+ duplicates=False,
+ fk=False,
+ ):
+ fk_req = testing.requires.foreign_keys_reflect_as_index
+ dup_req = testing.requires.unique_constraints_reflect_as_index
+ if (fk and not fk_req.enabled) or (
+ duplicates and not dup_req.enabled
+ ):
+ return ()
+ res = {
+ "unique": unique,
+ "column_names": list(cols),
+ "name": name,
+ "dialect_options": mock.ANY,
+ "include_columns": [],
+ }
+ if column_sorting:
+ res["column_sorting"] = {"q": ("desc",)}
+ if duplicates:
+ res["duplicates_constraint"] = name
+ return [res]
+
+ materialized = {(schema, "dingalings_v"): []}
+ views = {
+ (schema, "email_addresses_v"): [],
+ (schema, "users_v"): [],
+ (schema, "user_tmp_v"): [],
+ }
+ self._resolve_views(views, materialized)
+ if materialized:
+ materialized[(schema, "dingalings_v")].extend(
+ idx("data", name="mat_index")
+ )
+ tables = {
+ (schema, "users"): [
+ *idx("parent_user_id", name="user_id_fk", fk=True),
+ *idx("user_id", "test2", "test1", name="users_all_idx"),
+ *idx("test1", "test2", name="users_t_idx", unique=True),
+ ],
+ (schema, "dingalings"): [
+ *idx("data", name=mock.ANY, unique=True, duplicates=True),
+ *idx(
+ "address_id",
+ "dingaling_id",
+ name="zz_dingalings_multiple",
+ unique=True,
+ duplicates=True,
+ ),
+ ],
+ (schema, "email_addresses"): [
+ *idx("email_address", name=mock.ANY),
+ *idx("remote_user_id", name=mock.ANY, fk=True),
+ ],
+ (schema, "comment_test"): [],
+ (schema, "no_constraints"): [],
+ (schema, "local_table"): [
+ *idx("remote_id", name=mock.ANY, fk=True)
+ ],
+ (schema, "remote_table"): [
+ *idx("local_id", name=mock.ANY, fk=True)
+ ],
+ (schema, "remote_table_2"): [],
+ (schema, "noncol_idx_test_nopk"): [
+ *idx(
+ "q",
+ name="noncol_idx_nopk",
+ column_sorting={"q": ("desc",)},
+ )
+ ],
+ (schema, "noncol_idx_test_pk"): [
+ *idx(
+ "q", name="noncol_idx_pk", column_sorting={"q": ("desc",)}
+ )
+ ],
+ (schema, self.temp_table_name()): [
+ *idx("foo", name="user_tmp_ix"),
+ *idx(
+ "name",
+ name=f"user_tmp_uq_{config.ident}",
+ duplicates=True,
+ unique=True,
+ ),
+ ],
+ }
+ if (
+ not testing.requires.indexes_with_ascdesc.enabled
+ or not testing.requires.reflect_indexes_with_ascdesc.enabled
+ ):
+ tables[(schema, "noncol_idx_test_nopk")].clear()
+ tables[(schema, "noncol_idx_test_pk")].clear()
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_index_keys(self):
+ return {"name", "column_names", "unique"}
+
+ def exp_ucs(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ all_=False,
+ ):
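+ # build the expected unique-constraint entry; returns an empty tuple
+ # when unique indexes are not reflected back as unique constraints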
+ def uc(*cols, name, duplicates_index=None, is_index=False):
+ req = testing.requires.unique_index_reflect_as_unique_constraints
+ if is_index and not req.enabled:
+ return ()
+ res = {
+ "column_names": list(cols),
+ "name": name,
+ }
+ if duplicates_index:
+ res["duplicates_index"] = duplicates_index
+ return [res]
+
+ materialized = {(schema, "dingalings_v"): []}
+ views = {
+ (schema, "email_addresses_v"): [],
+ (schema, "users_v"): [],
+ (schema, "user_tmp_v"): [],
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): [
+ *uc(
+ "test1",
+ "test2",
+ name="users_t_idx",
+ duplicates_index="users_t_idx",
+ is_index=True,
+ )
+ ],
+ (schema, "dingalings"): [
+ *uc("data", name=mock.ANY, duplicates_index=mock.ANY),
+ *uc(
+ "address_id",
+ "dingaling_id",
+ name="zz_dingalings_multiple",
+ duplicates_index="zz_dingalings_multiple",
+ ),
+ ],
+ (schema, "email_addresses"): [],
+ (schema, "comment_test"): [],
+ (schema, "no_constraints"): [],
+ (schema, "local_table"): [],
+ (schema, "remote_table"): [],
+ (schema, "remote_table_2"): [],
+ (schema, "noncol_idx_test_nopk"): [],
+ (schema, "noncol_idx_test_pk"): [],
+ (schema, self.temp_table_name()): [
+ *uc("name", name=f"user_tmp_uq_{config.ident}")
+ ],
+ }
+ if all_:
+ return {**materialized, **views, **tables}
+ else:
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_unique_cst_keys(self):
+ return {"name", "column_names"}
+
+ def exp_ccs(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
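+ # loose comparator for check-constraint sqltext: ignores case,
+ # parentheses and backticks so dialect-specific rendering still matches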
+ class tt(str):
+ def __eq__(self, other):
+ res = (
+ other.lower()
+ .replace("(", "")
+ .replace(")", "")
+ .replace("`", "")
+ )
+ return self in res
+
+ def cc(text, name):
+ return {"sqltext": tt(text), "name": name}
+
+ # print({1: "test2 > (0)::double precision"} == {1: tt("test2 > 0")})
+ # assert 0
+ materialized = {(schema, "dingalings_v"): []}
+ views = {
+ (schema, "email_addresses_v"): [],
+ (schema, "users_v"): [],
+ (schema, "user_tmp_v"): [],
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): [cc("test2 > 0", "test2_gt_zero")],
+ (schema, "dingalings"): [
+ cc(
+ "address_id > 0 and address_id < 1000",
+ name="address_id_gt_zero",
+ ),
+ ],
+ (schema, "email_addresses"): [],
+ (schema, "comment_test"): [],
+ (schema, "no_constraints"): [],
+ (schema, "local_table"): [],
+ (schema, "remote_table"): [],
+ (schema, "remote_table_2"): [],
+ (schema, "noncol_idx_test_nopk"): [],
+ (schema, "noncol_idx_test_pk"): [],
+ (schema, self.temp_table_name()): [],
+ }
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_cc_keys(self):
+ return {"name", "sqltext"}
@testing.requires.schema_reflection
- def test_get_schema_names(self):
- insp = inspect(self.bind)
+ def test_get_schema_names(self, connection):
+ insp = inspect(connection)
- self.assert_(testing.config.test_schema in insp.get_schema_names())
+ is_true(testing.config.test_schema in insp.get_schema_names())
+
+ @testing.requires.schema_reflection
+ def test_has_schema(self, connection):
+ insp = inspect(connection)
+
+ is_true(insp.has_schema(testing.config.test_schema))
+ is_false(insp.has_schema("sa_fake_schema_foo"))
@testing.requires.schema_reflection
def test_get_schema_names_w_translate_map(self, connection):
)
insp = inspect(connection)
- self.assert_(testing.config.test_schema in insp.get_schema_names())
+ is_true(testing.config.test_schema in insp.get_schema_names())
+
+ @testing.requires.schema_reflection
+ def test_has_schema_w_translate_map(self, connection):
+ connection = connection.execution_options(
+ schema_translate_map={
+ "foo": "bar",
+ BLANK_SCHEMA: testing.config.test_schema,
+ }
+ )
+ insp = inspect(connection)
+
+ is_true(insp.has_schema(testing.config.test_schema))
+ is_false(insp.has_schema("sa_fake_schema_foo"))
+
+ @testing.requires.schema_reflection
+ @testing.requires.schema_create_delete
+ def test_schema_cache(self, connection):
+ insp = inspect(connection)
+
+ is_false("foo_bar" in insp.get_schema_names())
+ is_false(insp.has_schema("foo_bar"))
+ connection.execute(DDL("CREATE SCHEMA foo_bar"))
+ try:
+ is_false("foo_bar" in insp.get_schema_names())
+ is_false(insp.has_schema("foo_bar"))
+ insp.clear_cache()
+ is_true("foo_bar" in insp.get_schema_names())
+ is_true(insp.has_schema("foo_bar"))
+ finally:
+ connection.execute(DDL("DROP SCHEMA foo_bar"))
@testing.requires.schema_reflection
def test_dialect_initialize(self):
assert hasattr(engine.dialect, "default_schema_name")
@testing.requires.schema_reflection
- def test_get_default_schema_name(self):
- insp = inspect(self.bind)
- eq_(insp.default_schema_name, self.bind.dialect.default_schema_name)
+ def test_get_default_schema_name(self, connection):
+ insp = inspect(connection)
+ eq_(insp.default_schema_name, connection.dialect.default_schema_name)
- @testing.requires.foreign_key_constraint_reflection
@testing.combinations(
- (None, True, False, False),
- (None, True, False, True, testing.requires.schemas),
- ("foreign_key", True, False, False),
- (None, False, True, False),
- (None, False, True, True, testing.requires.schemas),
- (None, True, True, False),
- (None, True, True, True, testing.requires.schemas),
- argnames="order_by,include_plain,include_views,use_schema",
+ None,
+ ("foreign_key", testing.requires.foreign_key_constraint_reflection),
+ argnames="order_by",
)
- def test_get_table_names(
- self, connection, order_by, include_plain, include_views, use_schema
- ):
+ @testing.combinations(
+ (True, testing.requires.schemas), False, argnames="use_schema"
+ )
+ def test_get_table_names(self, connection, order_by, use_schema):
if use_schema:
schema = config.test_schema
else:
schema = None
- _ignore_tables = [
+ _ignore_tables = {
"comment_test",
"noncol_idx_test_pk",
"noncol_idx_test_nopk",
"local_table",
"remote_table",
"remote_table_2",
- ]
+ "no_constraints",
+ }
insp = inspect(connection)
- if include_views:
- table_names = insp.get_view_names(schema)
- table_names.sort()
- answer = ["email_addresses_v", "users_v"]
- eq_(sorted(table_names), answer)
+ if order_by:
+ tables = [
+ rec[0]
+ for rec in insp.get_sorted_table_and_fkc_names(schema)
+ if rec[0]
+ ]
+ else:
+ tables = insp.get_table_names(schema)
+ table_names = [t for t in tables if t not in _ignore_tables]
- if include_plain:
- if order_by:
- tables = [
- rec[0]
- for rec in insp.get_sorted_table_and_fkc_names(schema)
- if rec[0]
- ]
- else:
- tables = insp.get_table_names(schema)
- table_names = [t for t in tables if t not in _ignore_tables]
+ if order_by == "foreign_key":
+ answer = ["users", "email_addresses", "dingalings"]
+ eq_(table_names, answer)
+ else:
+ answer = ["dingalings", "email_addresses", "users"]
+ eq_(sorted(table_names), answer)
- if order_by == "foreign_key":
- answer = ["users", "email_addresses", "dingalings"]
- eq_(table_names, answer)
- else:
- answer = ["dingalings", "email_addresses", "users"]
- eq_(sorted(table_names), answer)
+ @testing.combinations(
+ (True, testing.requires.schemas), False, argnames="use_schema"
+ )
+ def test_get_view_names(self, connection, use_schema):
+ insp = inspect(connection)
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ table_names = insp.get_view_names(schema)
+ if testing.requires.materialized_views.enabled:
+ eq_(sorted(table_names), ["email_addresses_v", "users_v"])
+ eq_(insp.get_materialized_view_names(schema), ["dingalings_v"])
+ else:
+ answer = ["dingalings_v", "email_addresses_v", "users_v"]
+ eq_(sorted(table_names), answer)
@testing.requires.temp_table_names
- def test_get_temp_table_names(self):
- insp = inspect(self.bind)
+ def test_get_temp_table_names(self, connection):
+ insp = inspect(connection)
temp_table_names = insp.get_temp_table_names()
- eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident])
+ eq_(sorted(temp_table_names), [f"user_tmp_{config.ident}"])
@testing.requires.view_reflection
- @testing.requires.temp_table_names
@testing.requires.temporary_views
- def test_get_temp_view_names(self):
- insp = inspect(self.bind)
+ def test_get_temp_view_names(self, connection):
+ insp = inspect(connection)
temp_table_names = insp.get_temp_view_names()
eq_(sorted(temp_table_names), ["user_tmp_v"])
@testing.requires.comment_reflection
- def test_get_comments(self):
- self._test_get_comments()
+ def test_get_comments(self, connection):
+ self._test_get_comments(connection)
@testing.requires.comment_reflection
@testing.requires.schemas
- def test_get_comments_with_schema(self):
- self._test_get_comments(testing.config.test_schema)
-
- def _test_get_comments(self, schema=None):
- insp = inspect(self.bind)
+ def test_get_comments_with_schema(self, connection):
+ self._test_get_comments(connection, testing.config.test_schema)
+ def _test_get_comments(self, connection, schema=None):
+ insp = inspect(connection)
+ exp = self.exp_comments(schema=schema)
eq_(
insp.get_table_comment("comment_test", schema=schema),
- {"text": r"""the test % ' " \ table comment"""},
+ exp[(schema, "comment_test")],
)
- eq_(insp.get_table_comment("users", schema=schema), {"text": None})
+ eq_(
+ insp.get_table_comment("users", schema=schema),
+ exp[(schema, "users")],
+ )
eq_(
- [
- {"name": rec["name"], "comment": rec["comment"]}
- for rec in insp.get_columns("comment_test", schema=schema)
- ],
- [
- {"comment": "id comment", "name": "id"},
- {"comment": "data % comment", "name": "data"},
- {
- "comment": (
- r"""Comment types type speedily ' " \ '' Fun!"""
- ),
- "name": "d2",
- },
- ],
+ insp.get_table_comment("comment_test", schema=schema),
+ exp[(schema, "comment_test")],
+ )
+
+ no_cst = self.tables.no_constraints.name
+ eq_(
+ insp.get_table_comment(no_cst, schema=schema),
+ exp[(schema, no_cst)],
)
@testing.combinations(
users, addresses = (self.tables.users, self.tables.email_addresses)
if use_views:
- table_names = ["users_v", "email_addresses_v"]
+ table_names = ["users_v", "email_addresses_v", "dingalings_v"]
else:
table_names = ["users", "email_addresses"]
for table_name, table in zip(table_names, (users, addresses)):
schema_name = schema
cols = insp.get_columns(table_name, schema=schema_name)
- self.assert_(len(cols) > 0, len(cols))
+ is_true(len(cols) > 0, len(cols))
# should be in order
# assert that the desired type and return type share
# a base within one of the generic types.
- self.assert_(
+ is_true(
len(
set(ctype.__mro__)
.intersection(ctype_def.__mro__)
if not col.primary_key:
assert cols[i]["default"] is None
+ # The case of a table with no columns
+ # is tested below in TableNoColumnsTest
+
@testing.requires.temp_table_reflection
- def test_get_temp_table_columns(self):
- table_name = get_temp_table_name(
- config, self.bind, "user_tmp_%s" % config.ident
+ def test_reflect_table_temp_table(self, connection):
+
+ table_name = self.temp_table_name()
+ user_tmp = self.tables[table_name]
+
+ reflected_user_tmp = Table(
+ table_name, MetaData(), autoload_with=connection
)
+ self.assert_tables_equal(
+ user_tmp, reflected_user_tmp, strict_constraints=False
+ )
+
+ @testing.requires.temp_table_reflection
+ def test_get_temp_table_columns(self, connection):
+ table_name = self.temp_table_name()
user_tmp = self.tables[table_name]
- insp = inspect(self.bind)
+ insp = inspect(connection)
cols = insp.get_columns(table_name)
- self.assert_(len(cols) > 0, len(cols))
+ is_true(len(cols) > 0, len(cols))
for i, col in enumerate(user_tmp.columns):
eq_(col.name, cols[i]["name"])
@testing.requires.temp_table_reflection
@testing.requires.view_column_reflection
@testing.requires.temporary_views
- def test_get_temp_view_columns(self):
- insp = inspect(self.bind)
+ def test_get_temp_view_columns(self, connection):
+ insp = inspect(connection)
cols = insp.get_columns("user_tmp_v")
eq_([col["name"] for col in cols], ["id", "name", "foo"])
users, addresses = self.tables.users, self.tables.email_addresses
insp = inspect(connection)
+ exp = self.exp_pks(schema=schema)
users_cons = insp.get_pk_constraint(users.name, schema=schema)
- users_pkeys = users_cons["constrained_columns"]
- eq_(users_pkeys, ["user_id"])
+ self._check_list(
+ [users_cons], [exp[(schema, users.name)]], self._required_pk_keys
+ )
addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
- addr_pkeys = addr_cons["constrained_columns"]
- eq_(addr_pkeys, ["address_id"])
+ exp_cols = exp[(schema, addresses.name)]["constrained_columns"]
+ eq_(addr_cons["constrained_columns"], exp_cols)
with testing.requires.reflects_pk_names.fail_if():
eq_(addr_cons["name"], "email_ad_pk")
+ no_cst = self.tables.no_constraints.name
+ self._check_list(
+ [insp.get_pk_constraint(no_cst, schema=schema)],
+ [exp[(schema, no_cst)]],
+ self._required_pk_keys,
+ )
+
@testing.combinations(
(False,), (True, testing.requires.schemas), argnames="use_schema"
)
eq_(fkey1["referred_schema"], expected_schema)
eq_(fkey1["referred_table"], users.name)
eq_(fkey1["referred_columns"], ["user_id"])
- if testing.requires.self_referential_foreign_keys.enabled:
- eq_(fkey1["constrained_columns"], ["parent_user_id"])
+ eq_(fkey1["constrained_columns"], ["parent_user_id"])
# addresses
addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema)
fkey1 = addr_fkeys[0]
with testing.requires.implicitly_named_constraints.fail_if():
- self.assert_(fkey1["name"] is not None)
+ is_true(fkey1["name"] is not None)
eq_(fkey1["referred_schema"], expected_schema)
eq_(fkey1["referred_table"], users.name)
eq_(fkey1["referred_columns"], ["user_id"])
eq_(fkey1["constrained_columns"], ["remote_user_id"])
+ no_cst = self.tables.no_constraints.name
+ eq_(insp.get_foreign_keys(no_cst, schema=schema), [])
+
@testing.requires.cross_schema_fk_reflection
@testing.requires.schemas
- def test_get_inter_schema_foreign_keys(self):
+ def test_get_inter_schema_foreign_keys(self, connection):
local_table, remote_table, remote_table_2 = self.tables(
- "%s.local_table" % self.bind.dialect.default_schema_name,
+ "%s.local_table" % connection.dialect.default_schema_name,
"%s.remote_table" % testing.config.test_schema,
"%s.remote_table_2" % testing.config.test_schema,
)
- insp = inspect(self.bind)
+ insp = inspect(connection)
local_fkeys = insp.get_foreign_keys(local_table.name)
eq_(len(local_fkeys), 1)
fkey2 = remote_fkeys[0]
- assert fkey2["referred_schema"] in (
- None,
- self.bind.dialect.default_schema_name,
+ is_true(
+ fkey2["referred_schema"]
+ in (
+ None,
+ connection.dialect.default_schema_name,
+ )
)
eq_(fkey2["referred_table"], local_table.name)
eq_(fkey2["referred_columns"], ["id"])
eq_(fkey2["constrained_columns"], ["local_id"])
- def _assert_insp_indexes(self, indexes, expected_indexes):
- index_names = [d["name"] for d in indexes]
- for e_index in expected_indexes:
- assert e_index["name"] in index_names
- index = indexes[index_names.index(e_index["name"])]
- for key in e_index:
- eq_(e_index[key], index[key])
-
@testing.combinations(
(False,), (True, testing.requires.schemas), argnames="use_schema"
)
+ @testing.requires.index_reflection
def test_get_indexes(self, connection, use_schema):
if use_schema:
# The database may decide to create indexes for foreign keys, etc.
# so there may be more indexes than expected.
- insp = inspect(self.bind)
+ insp = inspect(connection)
indexes = insp.get_indexes("users", schema=schema)
- expected_indexes = [
- {
- "unique": False,
- "column_names": ["test1", "test2"],
- "name": "users_t_idx",
- },
- {
- "unique": False,
- "column_names": ["user_id", "test2", "test1"],
- "name": "users_all_idx",
- },
- ]
- self._assert_insp_indexes(indexes, expected_indexes)
+ exp = self.exp_indexes(schema=schema)
+ self._check_list(
+ indexes, exp[(schema, "users")], self._required_index_keys
+ )
+
+ no_cst = self.tables.no_constraints.name
+ self._check_list(
+ insp.get_indexes(no_cst, schema=schema),
+ exp[(schema, no_cst)],
+ self._required_index_keys,
+ )
@testing.combinations(
("noncol_idx_test_nopk", "noncol_idx_nopk"),
)
@testing.requires.index_reflection
@testing.requires.indexes_with_ascdesc
+ @testing.requires.reflect_indexes_with_ascdesc
def test_get_noncol_index(self, connection, tname, ixname):
insp = inspect(connection)
indexes = insp.get_indexes(tname)
-
# reflecting an index that has "x DESC" in it as the column.
# the DB may or may not give us "x", but make sure we get the index
# back, it has a name, it's connected to the table.
- expected_indexes = [{"unique": False, "name": ixname}]
- self._assert_insp_indexes(indexes, expected_indexes)
+ expected_indexes = self.exp_indexes()[(None, tname)]
+ self._check_list(indexes, expected_indexes, self._required_index_keys)
t = Table(tname, MetaData(), autoload_with=connection)
eq_(len(t.indexes), 1)
@testing.requires.temp_table_reflection
@testing.requires.unique_constraint_reflection
- def test_get_temp_table_unique_constraints(self):
- insp = inspect(self.bind)
- reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident)
- for refl in reflected:
- # Different dialects handle duplicate index and constraints
- # differently, so ignore this flag
- refl.pop("duplicates_index", None)
- eq_(
- reflected,
- [
- {
- "column_names": ["name"],
- "name": "user_tmp_uq_%s" % config.ident,
- }
- ],
- )
+ def test_get_temp_table_unique_constraints(self, connection):
+ insp = inspect(connection)
+ name = self.temp_table_name()
+ reflected = insp.get_unique_constraints(name)
+ exp = self.exp_ucs(all_=True)[(None, name)]
+ self._check_list(reflected, exp, self._required_index_keys)
@testing.requires.temp_table_reflect_indexes
- def test_get_temp_table_indexes(self):
- insp = inspect(self.bind)
- table_name = get_temp_table_name(
- config, config.db, "user_tmp_%s" % config.ident
- )
+ def test_get_temp_table_indexes(self, connection):
+ insp = inspect(connection)
+ table_name = self.temp_table_name()
indexes = insp.get_indexes(table_name)
for ind in indexes:
ind.pop("dialect_options", None)
)
table.create(connection)
- inspector = inspect(connection)
+ insp = inspect(connection)
reflected = sorted(
- inspector.get_unique_constraints("testtbl", schema=schema),
+ insp.get_unique_constraints("testtbl", schema=schema),
key=operator.itemgetter("name"),
)
eq_(names_that_duplicate_index, idx_names)
eq_(uq_names, set())
+ no_cst = self.tables.no_constraints.name
+ eq_(insp.get_unique_constraints(no_cst, schema=schema), [])
+
@testing.requires.view_reflection
@testing.combinations(
(False,), (True, testing.requires.schemas), argnames="use_schema"
schema = config.test_schema
else:
schema = None
- view_name1 = "users_v"
- view_name2 = "email_addresses_v"
insp = inspect(connection)
- v1 = insp.get_view_definition(view_name1, schema=schema)
- self.assert_(v1)
- v2 = insp.get_view_definition(view_name2, schema=schema)
- self.assert_(v2)
+ for view in ["users_v", "email_addresses_v", "dingalings_v"]:
+ v = insp.get_view_definition(view, schema=schema)
+ is_true(bool(v))
- # why is this here if it's PG specific ?
- @testing.combinations(
- ("users", False),
- ("users", True, testing.requires.schemas),
- argnames="table_name,use_schema",
- )
- @testing.only_on("postgresql", "PG specific feature")
- def test_get_table_oid(self, connection, table_name, use_schema):
- if use_schema:
- schema = config.test_schema
- else:
- schema = None
+ @testing.requires.view_reflection
+ def test_get_view_definition_does_not_exist(self, connection):
insp = inspect(connection)
- oid = insp.get_table_oid(table_name, schema)
- self.assert_(isinstance(oid, int))
+ with expect_raises(NoSuchTableError):
+ insp.get_view_definition("view_does_not_exist")
+ with expect_raises(NoSuchTableError):
+ insp.get_view_definition("users") # a table
@testing.requires.table_reflection
- def test_autoincrement_col(self):
+ def test_autoincrement_col(self, connection):
"""test that 'autoincrement' is reflected according to sqla's policy.
Don't mark this test as unsupported for any backend !
"""
- insp = inspect(self.bind)
+ insp = inspect(connection)
for tname, cname in [
("users", "user_id"),
id_ = {c["name"]: c for c in cols}[cname]
assert id_.get("autoincrement", True)
+ @testing.combinations(
+ (True, testing.requires.schemas), (False,), argnames="use_schema"
+ )
+ def test_get_table_options(self, use_schema):
+ insp = inspect(config.db)
+ schema = config.test_schema if use_schema else None
+
+ if testing.requires.reflect_table_options.enabled:
+ res = insp.get_table_options("users", schema=schema)
+ is_true(isinstance(res, dict))
+ # NOTE: can't really create a table with no options
+ res = insp.get_table_options("no_constraints", schema=schema)
+ is_true(isinstance(res, dict))
+ else:
+ with expect_raises(NotImplementedError):
+ res = insp.get_table_options("users", schema=schema)
+
+ @testing.combinations((True, testing.requires.schemas), False)
+ def test_multi_get_table_options(self, use_schema):
+ insp = inspect(config.db)
+ if testing.requires.reflect_table_options.enabled:
+ schema = config.test_schema if use_schema else None
+ res = insp.get_multi_table_options(schema=schema)
+
+ exp = {
+ (schema, table): insp.get_table_options(table, schema=schema)
+ for table in insp.get_table_names(schema=schema)
+ }
+ eq_(res, exp)
+ else:
+ with expect_raises(NotImplementedError):
+ res = insp.get_multi_table_options()
+
+ @testing.fixture
+ def get_multi_exp(self, connection):
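+ # returns a helper that builds the Inspector, the keyword-argument
+ # variants to exercise, and the expected multi-reflection result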
+ def provide_fixture(
+ schema, scope, kind, use_filter, single_reflect_fn, exp_method
+ ):
+ insp = inspect(connection)
+ # call the reflection function at least once to avoid
+ # "Unexpected success" errors if the result is actually empty
+ # and NotImplementedError is not raised
+ single_reflect_fn(insp, "email_addresses")
+ kw = {"scope": scope, "kind": kind}
+ if schema:
+ schema = schema()
+
+ filter_names = []
+
+ if ObjectKind.TABLE in kind:
+ filter_names.extend(
+ ["comment_test", "users", "does-not-exist"]
+ )
+ if ObjectKind.VIEW in kind:
+ filter_names.extend(["email_addresses_v", "does-not-exist"])
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ filter_names.extend(["dingalings_v", "does-not-exist"])
+
+ if schema:
+ kw["schema"] = schema
+ if use_filter:
+ kw["filter_names"] = filter_names
+
+ exp = exp_method(
+ schema=schema,
+ scope=scope,
+ kind=kind,
+ filter_names=kw.get("filter_names"),
+ )
+ kws = [kw]
+ if scope == ObjectScope.DEFAULT:
+ nkw = kw.copy()
+ nkw.pop("scope")
+ kws.append(nkw)
+ if kind == ObjectKind.TABLE:
+ nkw = kw.copy()
+ nkw.pop("kind")
+ kws.append(nkw)
+
+ return inspect(connection), kws, exp
+
+ return provide_fixture
+
+ @testing.requires.reflect_table_options
+ @_multi_combination
+ def test_multi_get_table_options_tables(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_table_options,
+ self.exp_options,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_table_options(**kw)
+ eq_(result, exp)
+
+ @testing.requires.comment_reflection
+ @_multi_combination
+ def test_get_multi_table_comment(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_table_comment,
+ self.exp_comments,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ eq_(insp.get_multi_table_comment(**kw), exp)
+
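+ # compare two lists of reflected dicts; when req_keys is given, only
+ # those keys, plus keys present on both sides, are compared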
+ def _check_list(self, result, exp, req_keys=None, msg=None):
+ if req_keys is None:
+ eq_(result, exp, msg)
+ else:
+ eq_(len(result), len(exp), msg)
+ for r, e in zip(result, exp):
+ for k in set(r) | set(e):
+ if k in req_keys or (k in r and k in e):
+ eq_(r[k], e[k], f"{msg} - {k} - {r}")
+
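+ # compare per-table dictionaries keyed by (schema, name), delegating
+ # the comparison of each value to _check_list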
+ def _check_table_dict(self, result, exp, req_keys=None, make_lists=False):
+ eq_(set(result.keys()), set(exp.keys()))
+ for k in result:
+ r, e = result[k], exp[k]
+ if make_lists:
+ r, e = [r], [e]
+ self._check_list(r, e, req_keys, k)
+
+ @_multi_combination
+ def test_get_multi_columns(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_columns,
+ self.exp_columns,
+ )
+
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_columns(**kw)
+ self._check_table_dict(result, exp, self._required_column_keys)
+
+ @testing.requires.primary_key_constraint_reflection
+ @_multi_combination
+ def test_get_multi_pk_constraint(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_pk_constraint,
+ self.exp_pks,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_pk_constraint(**kw)
+ self._check_table_dict(
+ result, exp, self._required_pk_keys, make_lists=True
+ )
+
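+ # when the backend does not reflect implicit constraint names, sort
+ # both sides so that the comparison is order-independent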
+ def _adjust_sort(self, result, expected, key):
+ if not testing.requires.implicitly_named_constraints.enabled:
+ for obj in [result, expected]:
+ for val in obj.values():
+ if len(val) > 1 and any(
+ v.get("name") in (None, mock.ANY) for v in val
+ ):
+ val.sort(key=key)
+
+ @testing.requires.foreign_key_constraint_reflection
+ @_multi_combination
+ def test_get_multi_foreign_keys(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_foreign_keys,
+ self.exp_fks,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_foreign_keys(**kw)
+ self._adjust_sort(
+ result, exp, lambda d: tuple(d["constrained_columns"])
+ )
+ self._check_table_dict(result, exp, self._required_fk_keys)
+
+ @testing.requires.index_reflection
+ @_multi_combination
+ def test_get_multi_indexes(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_indexes,
+ self.exp_indexes,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_indexes(**kw)
+ self._check_table_dict(result, exp, self._required_index_keys)
+
+ @testing.requires.unique_constraint_reflection
+ @_multi_combination
+ def test_get_multi_unique_constraints(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_unique_constraints,
+ self.exp_ucs,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_unique_constraints(**kw)
+ self._adjust_sort(result, exp, lambda d: tuple(d["column_names"]))
+ self._check_table_dict(result, exp, self._required_unique_cst_keys)
+
+ @testing.requires.check_constraint_reflection
+ @_multi_combination
+ def test_get_multi_check_constraints(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_check_constraints,
+ self.exp_ccs,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_check_constraints(**kw)
+ self._adjust_sort(result, exp, lambda d: tuple(d["sqltext"]))
+ self._check_table_dict(result, exp, self._required_cc_keys)
+
+ @testing.combinations(
+ ("get_table_options", testing.requires.reflect_table_options),
+ "get_columns",
+ (
+ "get_pk_constraint",
+ testing.requires.primary_key_constraint_reflection,
+ ),
+ (
+ "get_foreign_keys",
+ testing.requires.foreign_key_constraint_reflection,
+ ),
+ ("get_indexes", testing.requires.index_reflection),
+ (
+ "get_unique_constraints",
+ testing.requires.unique_constraint_reflection,
+ ),
+ (
+ "get_check_constraints",
+ testing.requires.check_constraint_reflection,
+ ),
+ ("get_table_comment", testing.requires.comment_reflection),
+ argnames="method",
+ )
+ def test_not_existing_table(self, method, connection):
+ insp = inspect(connection)
+ meth = getattr(insp, method)
+ with expect_raises(NoSuchTableError):
+ meth("table_does_not_exists")
+
+ def test_unreflectable(self, connection):
+ mc = Inspector.get_multi_columns
+
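+ # wrap get_multi_columns so that it reports "some_table" as
+ # unreflectable, exercising the error path of reflect_table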
+ def patched(*a, **k):
+ ur = k.setdefault("unreflectable", {})
+ ur[(None, "some_table")] = UnreflectableTableError("err")
+ return mc(*a, **k)
+
+ with mock.patch.object(Inspector, "get_multi_columns", patched):
+ with expect_raises_message(UnreflectableTableError, "err"):
+ inspect(connection).reflect_table(
+ Table("some_table", MetaData()), None
+ )
+
+ @testing.combinations(True, False, argnames="use_schema")
+ @testing.combinations(
+ (True, testing.requires.views), False, argnames="views"
+ )
+ def test_metadata(self, connection, use_schema, views):
+ m = MetaData()
+ schema = config.test_schema if use_schema else None
+ m.reflect(connection, schema=schema, views=views, resolve_fks=False)
+
+ insp = inspect(connection)
+ tables = insp.get_table_names(schema)
+ if views:
+ tables += insp.get_view_names(schema)
+ try:
+ tables += insp.get_materialized_view_names(schema)
+ except NotImplementedError:
+ pass
+ if schema:
+ tables = [f"{schema}.{t}" for t in tables]
+ eq_(sorted(m.tables), sorted(tables))
+
class TableNoColumnsTest(fixtures.TestBase):
__requires__ = ("reflect_tables_no_columns",)
@testing.fixture
def view_no_columns(self, connection, metadata):
- Table("empty", metadata)
- metadata.create_all(connection)
-
Table("empty", metadata)
event.listen(
metadata,
)
metadata.create_all(connection)
- @testing.requires.reflect_tables_no_columns
def test_reflect_table_no_columns(self, connection, table_no_columns):
t2 = Table("empty", MetaData(), autoload_with=connection)
eq_(list(t2.c), [])
- @testing.requires.reflect_tables_no_columns
def test_get_columns_table_no_columns(self, connection, table_no_columns):
- eq_(inspect(connection).get_columns("empty"), [])
+ insp = inspect(connection)
+ eq_(insp.get_columns("empty"), [])
+ multi = insp.get_multi_columns()
+ eq_(multi, {(None, "empty"): []})
- @testing.requires.reflect_tables_no_columns
def test_reflect_incl_table_no_columns(self, connection, table_no_columns):
m = MetaData()
m.reflect(connection)
assert set(m.tables).intersection(["empty"])
@testing.requires.views
- @testing.requires.reflect_tables_no_columns
def test_reflect_view_no_columns(self, connection, view_no_columns):
t2 = Table("empty_v", MetaData(), autoload_with=connection)
eq_(list(t2.c), [])
@testing.requires.views
- @testing.requires.reflect_tables_no_columns
def test_get_columns_view_no_columns(self, connection, view_no_columns):
- eq_(inspect(connection).get_columns("empty_v"), [])
+ insp = inspect(connection)
+ eq_(insp.get_columns("empty_v"), [])
+ multi = insp.get_multi_columns(kind=ObjectKind.VIEW)
+ eq_(multi, {(None, "empty_v"): []})
class ComponentReflectionTestExtra(fixtures.TestBase):
),
schema=schema,
)
+ Table(
+ "no_constraints",
+ metadata,
+ Column("data", sa.String(20)),
+ schema=schema,
+ )
metadata.create_all(connection)
- inspector = inspect(connection)
+ insp = inspect(connection)
reflected = sorted(
- inspector.get_check_constraints("sa_cc", schema=schema),
+ insp.get_check_constraints("sa_cc", schema=schema),
key=operator.itemgetter("name"),
)
{"name": "cc1", "sqltext": "a > 1 and a < 5"},
],
)
+ no_cst = "no_constraints"
+ eq_(insp.get_check_constraints(no_cst, schema=schema), [])
@testing.requires.indexes_with_expressions
def test_reflect_expression_based_indexes(self, metadata, connection):
if col["name"] == "normal":
is_false("identity" in col)
elif col["name"] == "id1":
- is_true(col["autoincrement"] in (True, "auto"))
+ if "autoincrement" in col:
+ is_true(col["autoincrement"])
eq_(col["default"], None)
is_true("identity" in col)
self.check(
approx=True,
)
elif col["name"] == "id2":
- is_true(col["autoincrement"] in (True, "auto"))
+ if "autoincrement" in col:
+ is_true(col["autoincrement"])
eq_(col["default"], None)
is_true("identity" in col)
self.check(
if col["name"] == "normal":
is_false("identity" in col)
elif col["name"] == "id1":
- is_true(col["autoincrement"] in (True, "auto"))
+ if "autoincrement" in col:
+ is_true(col["autoincrement"])
eq_(col["default"], None)
is_true("identity" in col)
self.check(
)
@testing.requires.primary_key_constraint_reflection
- def test_pk_column_order(self):
+ def test_pk_column_order(self, connection):
# test for issue #5661
- insp = inspect(self.bind)
+ insp = inspect(connection)
primary_key = insp.get_pk_constraint(self.tables.tb1.name)
eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"])
@testing.requires.foreign_key_constraint_reflection
- def test_fk_column_order(self):
+ def test_fk_column_order(self, connection):
# test for issue #5661
- insp = inspect(self.bind)
+ insp = inspect(connection)
foreign_keys = insp.get_foreign_keys(self.tables.tb2.name)
eq_(len(foreign_keys), 1)
fkey1 = foreign_keys[0]
)
def test_has_sequence(self, connection):
- eq_(
- inspect(connection).has_sequence("user_id_seq"),
- True,
- )
+ eq_(inspect(connection).has_sequence("user_id_seq"), True)
+
+ def test_has_sequence_cache(self, connection, metadata):
+ insp = inspect(connection)
+ eq_(insp.has_sequence("user_id_seq"), True)
+ ss = Sequence("new_seq", metadata=metadata)
+ eq_(insp.has_sequence("new_seq"), False)
+ ss.create(connection)
+ try:
+ eq_(insp.has_sequence("new_seq"), False)
+ insp.clear_cache()
+ eq_(insp.has_sequence("new_seq"), True)
+ finally:
+ ss.drop(connection)
def test_has_sequence_other_object(self, connection):
- eq_(
- inspect(connection).has_sequence("user_id_table"),
- False,
- )
+ eq_(inspect(connection).has_sequence("user_id_table"), False)
@testing.requires.schemas
def test_has_sequence_schema(self, connection):
)
def test_has_sequence_neg(self, connection):
- eq_(
- inspect(connection).has_sequence("some_sequence"),
- False,
- )
+ eq_(inspect(connection).has_sequence("some_sequence"), False)
@testing.requires.schemas
def test_has_sequence_schemas_neg(self, connection):
@testing.requires.schemas
def test_has_sequence_remote_not_in_default(self, connection):
- eq_(
- inspect(connection).has_sequence("schema_seq"),
- False,
- )
+ eq_(inspect(connection).has_sequence("schema_seq"), False)
def test_get_sequence_names(self, connection):
exp = {"other_seq", "user_id_seq"}
go(engine_or_connection)
-def drop_all_tables(engine, inspector, schema=None, include_names=None):
+def drop_all_tables(
+ engine,
+ inspector,
+ schema=None,
+ consider_schemas=(None,),
+ include_names=None,
+):
if include_names is not None:
include_names = set(include_names)
+ if schema is not None:
+ assert consider_schemas == (
+ None,
+ ), "consider_schemas and schema are mutually exclusive"
+ consider_schemas = (schema,)
+
with engine.begin() as conn:
- for tname, fkcs in reversed(
- inspector.get_sorted_table_and_fkc_names(schema=schema)
+ for table_key, fkcs in reversed(
+ inspector.sort_tables_on_foreign_key_dependency(
+ consider_schemas=consider_schemas
+ )
):
- if tname:
- if include_names is not None and tname not in include_names:
+ if table_key:
+ if (
+ include_names is not None
+ and table_key[1] not in include_names
+ ):
continue
conn.execute(
- DropTable(Table(tname, MetaData(), schema=schema))
+ DropTable(
+ Table(table_key[1], MetaData(), schema=table_key[0])
+ )
)
elif fkcs:
if not engine.dialect.supports_alter:
continue
- for tname, fkc in fkcs:
+ for t_key, fkc in fkcs:
if (
include_names is not None
- and tname not in include_names
+ and t_key[1] not in include_names
):
continue
tb = Table(
- tname,
+ t_key[1],
MetaData(),
Column("x", Integer),
Column("y", Integer),
- schema=schema,
+ schema=t_key[0],
)
conn.execute(
DropConstraint(
from __future__ import annotations
from typing import Any
+from typing import Collection
from typing import DefaultDict
from typing import Iterable
from typing import Iterator
def sort_as_subsets(
- tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
+ tuples: Collection[Tuple[_T, _T]], allitems: Collection[_T]
) -> Iterator[Sequence[_T]]:
edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
def sort(
- tuples: Iterable[Tuple[_T, _T]],
- allitems: Iterable[_T],
+ tuples: Collection[Tuple[_T, _T]],
+ allitems: Collection[_T],
deterministic_order: bool = True,
) -> Iterator[_T]:
"""sort the given list of items by dependency.
def find_cycles(
- tuples: Iterable[Tuple[_T, _T]],
- allitems: Iterable[_T],
+ tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
) -> Set[_T]:
# adapted from:
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
from typing import Protocol as Protocol
from typing import TypedDict as TypedDict
from typing import Final as Final
+ from typing import final as final
else:
from typing_extensions import Literal as Literal # noqa: F401
from typing_extensions import Protocol as Protocol # noqa: F401
from typing_extensions import TypedDict as TypedDict # noqa: F401
from typing_extensions import Final as Final # noqa: F401
+ from typing_extensions import final as final # noqa: F401
typing_get_args = get_args
typing_get_origin = get_origin
warn_unused_ignores = false
strict = true
-[[tool.mypy.overrides]]
-
-#####################################################################
-# interim list of modules that need some level of type checking to
-# pass
-module = [
-
- "sqlalchemy.engine.reflection",
-
-]
-
-ignore_errors = true
-warn_unused_ignores = false
# create public database link test_link connect to scott identified by tiger
# using 'xe';
oracle_db_link = test_link
+# create public database link test_link2 connect to test_schema identified by tiger
+# using 'xe';
+oracle_db_link2 = test_link2
# host name of a postgres database that has the postgres_fdw extension.
# to create this run:
mariadb_connector = mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test
mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server
mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008
-docker_mssql = mssql+pymssql://scott:tiger^5HHH@127.0.0.1:1433/test
+docker_mssql = mssql+pyodbc://scott:tiger^5HHH@127.0.0.1:1433/test?driver=ODBC+Driver+17+for+SQL+Server
oracle = oracle+cx_oracle://scott:tiger@oracle18c/xe
cxoracle = oracle+cx_oracle://scott:tiger@oracle18c/xe
-oracle_oracledb = oracle+oracledb://scott:tiger@oracle18c/xe
oracledb = oracle+oracledb://scott:tiger@oracle18c/xe
+docker_oracle = oracle+cx_oracle://scott:tiger@127.0.0.1:1521/?service_name=XEPDB1
\ No newline at end of file
from sqlalchemy.dialects.mysql import base as mysql
from sqlalchemy.dialects.mysql import reflection as _reflection
from sqlalchemy.schema import CreateIndex
-from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
)
def test_skip_not_describable(self, metadata, connection):
+ """This test is the only one that test the _default_multi_reflect
+ behaviour with UnreflectableTableError
+ """
+
@event.listens_for(metadata, "before_drop")
def cleanup(*arg, **kw):
with testing.db.begin() as conn:
m.reflect(views=True, bind=conn)
eq_(m.tables["test_t2"].name, "test_t2")
- assert_raises_message(
- exc.UnreflectableTableError,
- "references invalid table",
- Table,
- "test_v",
- MetaData(),
- autoload_with=conn,
- )
+ with expect_raises_message(
+ exc.UnreflectableTableError, "references invalid table"
+ ):
+ Table("test_v", MetaData(), autoload_with=conn)
@testing.exclude("mysql", "<", (5, 0, 0), "no information_schema support")
def test_system_views(self):
from sqlalchemy.dialects.oracle.base import DOUBLE_PRECISION
from sqlalchemy.dialects.oracle.base import NUMBER
from sqlalchemy.dialects.oracle.base import REAL
+from sqlalchemy.engine import ObjectKind
from sqlalchemy.testing import assert_warns
from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import config
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_true
from sqlalchemy.testing.engines import testing_engine
from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import eq_compile_type
from sqlalchemy.testing.schema import Table
__backend__ = True
def setup_test(self):
+
with testing.db.begin() as conn:
conn.exec_driver_sql("create table my_table (id integer)")
conn.exec_driver_sql(
set(["my_table", "foo_table"]),
)
+ def test_reflect_system_table(self):
+ meta = MetaData()
+ t = Table("foo_table", meta, autoload_with=testing.db)
+ assert t.columns.keys() == ["id"]
+
+ t = Table("my_temp_table", meta, autoload_with=testing.db)
+ assert t.columns.keys() == ["id"]
+
class DontReflectIOTTest(fixtures.TestBase):
"""test that index overflow tables aren't included in
tbl = Table("test_compress", m2, autoload_with=connection)
assert tbl.dialect_options["oracle"]["compress"] == "OLTP"
+ def test_reflect_hidden_column(self):
+ with testing.db.begin() as conn:
+ conn.exec_driver_sql(
+ "CREATE TABLE my_table(id integer, hide integer INVISIBLE)"
+ )
+
+ try:
+ insp = inspect(conn)
+ cols = insp.get_columns("my_table")
+ assert len(cols) == 1
+ assert cols[0]["name"] == "id"
+ finally:
+ conn.exec_driver_sql("DROP TABLE my_table")
+
+
+class ViewReflectionTest(fixtures.TestBase):
+ __only_on__ = "oracle"
+ __backend__ = True
+
+ @classmethod
+ def setup_test_class(cls):
+ sql = """
+ CREATE TABLE tbl (
+ id INTEGER PRIMARY KEY,
+ data INTEGER
+ );
+
+ CREATE VIEW tbl_plain_v AS
+ SELECT id, data FROM tbl WHERE id > 100;
+
+ -- comments on plain views are created with "comment on table",
+ -- since Oracle has no "comment on view" statement
+ COMMENT ON TABLE tbl_plain_v IS 'view comment';
+
+ CREATE MATERIALIZED VIEW tbl_v AS
+ SELECT id, data FROM tbl WHERE id > 42;
+
+ COMMENT ON MATERIALIZED VIEW tbl_v IS 'my mat view comment';
+
+ CREATE MATERIALIZED VIEW tbl_v2 AS
+ SELECT id, data FROM tbl WHERE id < 42;
+
+ COMMENT ON MATERIALIZED VIEW tbl_v2 IS 'my other mat view comment';
+
+ CREATE SYNONYM view_syn FOR tbl_plain_v;
+ CREATE SYNONYM %(test_schema)s.ts_v_s FOR tbl_plain_v;
+
+ CREATE VIEW %(test_schema)s.schema_view AS
+ SELECT 1 AS value FROM dual;
+
+ COMMENT ON TABLE %(test_schema)s.schema_view IS 'schema view comment';
+ CREATE SYNONYM syn_schema_view FOR %(test_schema)s.schema_view;
+ """
+ if testing.requires.oracle_test_dblink.enabled:
+ cls.dblink = config.file_config.get(
+ "sqla_testing", "oracle_db_link"
+ )
+ sql += """
+ CREATE SYNONYM syn_link FOR tbl_plain_v@%(link)s;
+ """ % {
+ "link": cls.dblink
+ }
+ with testing.db.begin() as conn:
+ for stmt in (
+ sql % {"test_schema": testing.config.test_schema}
+ ).split(";"):
+ if stmt.strip():
+ conn.exec_driver_sql(stmt)
+
+ @classmethod
+ def teardown_test_class(cls):
+ sql = """
+ DROP MATERIALIZED VIEW tbl_v;
+ DROP MATERIALIZED VIEW tbl_v2;
+ DROP VIEW tbl_plain_v;
+ DROP TABLE tbl;
+ DROP VIEW %(test_schema)s.schema_view;
+ DROP SYNONYM view_syn;
+ DROP SYNONYM %(test_schema)s.ts_v_s;
+ DROP SYNONYM syn_schema_view;
+ """
+ if testing.requires.oracle_test_dblink.enabled:
+ sql += """
+ DROP SYNONYM syn_link;
+ """
+ with testing.db.begin() as conn:
+ for stmt in (
+ sql % {"test_schema": testing.config.test_schema}
+ ).split(";"):
+ if stmt.strip():
+ conn.exec_driver_sql(stmt)
+
+ def test_get_names(self, connection):
+ insp = inspect(connection)
+ eq_(insp.get_table_names(), ["tbl"])
+ eq_(insp.get_view_names(), ["tbl_plain_v"])
+ eq_(insp.get_materialized_view_names(), ["tbl_v", "tbl_v2"])
+ eq_(
+ insp.get_view_names(schema=testing.config.test_schema),
+ ["schema_view"],
+ )
+
+ def test_get_table_comment_on_view(self, connection):
+ insp = inspect(connection)
+ eq_(insp.get_table_comment("tbl_v"), {"text": "my mat view comment"})
+ eq_(insp.get_table_comment("tbl_plain_v"), {"text": "view comment"})
+
+ def test_get_multi_view_comment(self, connection):
+ insp = inspect(connection)
+ plain = {(None, "tbl_plain_v"): {"text": "view comment"}}
+ mat = {
+ (None, "tbl_v"): {"text": "my mat view comment"},
+ (None, "tbl_v2"): {"text": "my other mat view comment"},
+ }
+ eq_(insp.get_multi_table_comment(kind=ObjectKind.VIEW), plain)
+ eq_(
+ insp.get_multi_table_comment(kind=ObjectKind.MATERIALIZED_VIEW),
+ mat,
+ )
+ eq_(
+ insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW),
+ {**plain, **mat},
+ )
+ ts = testing.config.test_schema
+ eq_(
+ insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW, schema=ts),
+ {(ts, "schema_view"): {"text": "schema view comment"}},
+ )
+ eq_(insp.get_multi_table_comment(), {(None, "tbl"): {"text": None}})
+
+ def test_get_table_comment_synonym(self, connection):
+ insp = inspect(connection)
+ eq_(
+ insp.get_table_comment("view_syn", oracle_resolve_synonyms=True),
+ {"text": "view comment"},
+ )
+ eq_(
+ insp.get_table_comment(
+ "syn_schema_view", oracle_resolve_synonyms=True
+ ),
+ {"text": "schema view comment"},
+ )
+ eq_(
+ insp.get_table_comment(
+ "ts_v_s",
+ oracle_resolve_synonyms=True,
+ schema=testing.config.test_schema,
+ ),
+ {"text": "view comment"},
+ )
+
+ def test_get_multi_view_comment_synonym(self, connection):
+ insp = inspect(connection)
+ exp = {
+ (None, "view_syn"): {"text": "view comment"},
+ (None, "syn_schema_view"): {"text": "schema view comment"},
+ }
+ if testing.requires.oracle_test_dblink.enabled:
+ exp[(None, "syn_link")] = {"text": "view comment"}
+ eq_(
+ insp.get_multi_table_comment(
+ oracle_resolve_synonyms=True, kind=ObjectKind.ANY_VIEW
+ ),
+ exp,
+ )
+ ts = testing.config.test_schema
+ eq_(
+ insp.get_multi_table_comment(
+ oracle_resolve_synonyms=True,
+ schema=ts,
+ kind=ObjectKind.ANY_VIEW,
+ ),
+ {(ts, "ts_v_s"): {"text": "view comment"}},
+ )
+
+ def test_get_view_definition(self, connection):
+ insp = inspect(connection)
+ eq_(
+ insp.get_view_definition("tbl_plain_v"),
+ "SELECT id, data FROM tbl WHERE id > 100",
+ )
+ eq_(
+ insp.get_view_definition("tbl_v"),
+ "SELECT id, data FROM tbl WHERE id > 42",
+ )
+ with expect_raises(exc.NoSuchTableError):
+ eq_(insp.get_view_definition("view_syn"), None)
+ eq_(
+ insp.get_view_definition("view_syn", oracle_resolve_synonyms=True),
+ "SELECT id, data FROM tbl WHERE id > 100",
+ )
+ eq_(
+ insp.get_view_definition(
+ "syn_schema_view", oracle_resolve_synonyms=True
+ ),
+ "SELECT 1 AS value FROM dual",
+ )
+ eq_(
+ insp.get_view_definition(
+ "ts_v_s",
+ oracle_resolve_synonyms=True,
+ schema=testing.config.test_schema,
+ ),
+ "SELECT id, data FROM tbl WHERE id > 100",
+ )
+
+ @testing.requires.oracle_test_dblink
+ def test_get_view_definition_dblink(self, connection):
+ insp = inspect(connection)
+ eq_(
+ insp.get_view_definition("syn_link", oracle_resolve_synonyms=True),
+ "SELECT id, data FROM tbl WHERE id > 100",
+ )
+ eq_(
+ insp.get_view_definition("tbl_plain_v", dblink=self.dblink),
+ "SELECT id, data FROM tbl WHERE id > 100",
+ )
+ eq_(
+ insp.get_view_definition("tbl_v", dblink=self.dblink),
+ "SELECT id, data FROM tbl WHERE id > 42",
+ )
+
class RoundTripIndexTest(fixtures.TestBase):
__only_on__ = "oracle"
@classmethod
def setup_test_class(cls):
- from sqlalchemy.testing import config
-
cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link")
# note that the synonym here is still not totally functional
exp = common.copy()
exp["order"] = True
eq_(col["identity"], exp)
+
+
+class AdditionalReflectionTests(fixtures.TestBase):
+ __only_on__ = "oracle"
+ __backend__ = True
+
+ @classmethod
+ def setup_test_class(cls):
+ # currently assumes full DBA privileges for the test user; there is
+ # no obvious alternative short of connecting as the other user.
+
+ sql = """
+CREATE TABLE %(schema)sparent(
+ id INTEGER,
+ data VARCHAR2(50),
+ CONSTRAINT parent_pk_%(schema_id)s PRIMARY KEY (id)
+);
+CREATE TABLE %(schema)smy_table(
+ id INTEGER,
+ name VARCHAR2(125),
+ related INTEGER,
+ data%(schema_id)s NUMBER NOT NULL,
+ CONSTRAINT my_table_pk_%(schema_id)s PRIMARY KEY (id),
+ CONSTRAINT my_table_fk_%(schema_id)s FOREIGN KEY(related)
+ REFERENCES %(schema)sparent(id),
+ CONSTRAINT my_table_check_%(schema_id)s CHECK (data%(schema_id)s > 42),
+ CONSTRAINT data_unique%(schema_id)s UNIQUE (data%(schema_id)s)
+);
+CREATE INDEX my_table_index_%(schema_id)s on %(schema)smy_table (id, name);
+COMMENT ON TABLE %(schema)smy_table IS 'my table comment %(schema_id)s';
+COMMENT ON COLUMN %(schema)smy_table.name IS
+'my table.name comment %(schema_id)s';
+"""
+
+ with testing.db.begin() as conn:
+ for schema in ("", testing.config.test_schema):
+ dd = {
+ "schema": f"{schema}." if schema else "",
+ "schema_id": "sch" if schema else "",
+ }
+ for stmt in (sql % dd).split(";"):
+ if stmt.strip():
+ conn.exec_driver_sql(stmt)
+
+ @classmethod
+ def teardown_test_class(cls):
+ sql = """
+drop table %(schema)smy_table;
+drop table %(schema)sparent;
+"""
+ with testing.db.begin() as conn:
+ for schema in ("", testing.config.test_schema):
+ dd = {"schema": f"{schema}." if schema else ""}
+ for stmt in (sql % dd).split(";"):
+ if stmt.strip():
+ try:
+ conn.exec_driver_sql(stmt)
+ except Exception:
+ pass
+
+ def setup_test(self):
+ self.dblink = config.file_config.get("sqla_testing", "oracle_db_link")
+ self.dblink2 = config.file_config.get(
+ "sqla_testing", "oracle_db_link2"
+ )
+ self.columns = {}
+ self.indexes = {}
+ self.primary_keys = {}
+ self.comments = {}
+ self.uniques = {}
+ self.checks = {}
+ self.foreign_keys = {}
+ self.options = {}
+ self.allDicts = [
+ self.columns,
+ self.indexes,
+ self.primary_keys,
+ self.comments,
+ self.uniques,
+ self.checks,
+ self.foreign_keys,
+ self.options,
+ ]
+ for schema in (None, testing.config.test_schema):
+ suffix = "sch" if schema else ""
+
+ self.columns[schema] = {
+ (schema, "my_table"): [
+ {
+ "name": "id",
+ "nullable": False,
+ "type": eq_compile_type("INTEGER"),
+ "default": None,
+ "comment": None,
+ },
+ {
+ "name": "name",
+ "nullable": True,
+ "type": eq_compile_type("VARCHAR(125)"),
+ "default": None,
+ "comment": f"my table.name comment {suffix}",
+ },
+ {
+ "name": "related",
+ "nullable": True,
+ "type": eq_compile_type("INTEGER"),
+ "default": None,
+ "comment": None,
+ },
+ {
+ "name": f"data{suffix}",
+ "nullable": False,
+ "type": eq_compile_type("NUMBER"),
+ "default": None,
+ "comment": None,
+ },
+ ],
+ (schema, "parent"): [
+ {
+ "name": "id",
+ "nullable": False,
+ "type": eq_compile_type("INTEGER"),
+ "default": None,
+ "comment": None,
+ },
+ {
+ "name": "data",
+ "nullable": True,
+ "type": eq_compile_type("VARCHAR(50)"),
+ "default": None,
+ "comment": None,
+ },
+ ],
+ }
+ self.indexes[schema] = {
+ (schema, "my_table"): [
+ {
+ "name": f"data_unique{suffix}",
+ "column_names": [f"data{suffix}"],
+ "dialect_options": {},
+ "unique": True,
+ },
+ {
+ "name": f"my_table_index_{suffix}",
+ "column_names": ["id", "name"],
+ "dialect_options": {},
+ "unique": False,
+ },
+ ],
+ (schema, "parent"): [],
+ }
+ self.primary_keys[schema] = {
+ (schema, "my_table"): {
+ "name": f"my_table_pk_{suffix}",
+ "constrained_columns": ["id"],
+ },
+ (schema, "parent"): {
+ "name": f"parent_pk_{suffix}",
+ "constrained_columns": ["id"],
+ },
+ }
+ self.comments[schema] = {
+ (schema, "my_table"): {"text": f"my table comment {suffix}"},
+ (schema, "parent"): {"text": None},
+ }
+ self.foreign_keys[schema] = {
+ (schema, "my_table"): [
+ {
+ "name": f"my_table_fk_{suffix}",
+ "constrained_columns": ["related"],
+ "referred_schema": schema,
+ "referred_table": "parent",
+ "referred_columns": ["id"],
+ "options": {},
+ }
+ ],
+ (schema, "parent"): [],
+ }
+ self.checks[schema] = {
+ (schema, "my_table"): [
+ {
+ "name": f"my_table_check_{suffix}",
+ "sqltext": f"data{suffix} > 42",
+ }
+ ],
+ (schema, "parent"): [],
+ }
+ self.uniques[schema] = {
+ (schema, "my_table"): [
+ {
+ "name": f"data_unique{suffix}",
+ "column_names": [f"data{suffix}"],
+ "duplicates_index": f"data_unique{suffix}",
+ }
+ ],
+ (schema, "parent"): [],
+ }
+ self.options[schema] = {
+ (schema, "my_table"): {},
+ (schema, "parent"): {},
+ }
+
+ def test_tables(self, connection):
+ insp = inspect(connection)
+
+ eq_(sorted(insp.get_table_names()), ["my_table", "parent"])
+
+ def _check_reflection(self, conn, schema, res_schema=False, **kw):
+ if res_schema is False:
+ res_schema = schema
+ insp = inspect(conn)
+ eq_(
+ insp.get_multi_columns(schema=schema, **kw),
+ self.columns[res_schema],
+ )
+ eq_(
+ insp.get_multi_indexes(schema=schema, **kw),
+ self.indexes[res_schema],
+ )
+ eq_(
+ insp.get_multi_pk_constraint(schema=schema, **kw),
+ self.primary_keys[res_schema],
+ )
+ eq_(
+ insp.get_multi_table_comment(schema=schema, **kw),
+ self.comments[res_schema],
+ )
+ eq_(
+ insp.get_multi_foreign_keys(schema=schema, **kw),
+ self.foreign_keys[res_schema],
+ )
+ eq_(
+ insp.get_multi_check_constraints(schema=schema, **kw),
+ self.checks[res_schema],
+ )
+ eq_(
+ insp.get_multi_unique_constraints(schema=schema, **kw),
+ self.uniques[res_schema],
+ )
+ eq_(
+ insp.get_multi_table_options(schema=schema, **kw),
+ self.options[res_schema],
+ )
+
+ @testing.combinations(True, False, argnames="schema")
+ def test_schema_translate_map(self, connection, schema):
+ schema = testing.config.test_schema if schema else None
+ c = connection.execution_options(
+ schema_translate_map={
+ None: "foo",
+ testing.config.test_schema: "bar",
+ }
+ )
+ self._check_reflection(c, schema)
+
+ @testing.requires.oracle_test_dblink
+ def test_db_link(self, connection):
+ self._check_reflection(connection, schema=None, dblink=self.dblink)
+ self._check_reflection(
+ connection,
+ schema=testing.config.test_schema,
+ dblink=self.dblink,
+ )
+
+ def test_no_synonyms(self, connection):
+ # oracle_resolve_synonyms is ignored if there are no matching synonyms
+ self._check_reflection(
+ connection, schema=None, oracle_resolve_synonyms=True
+ )
+ connection.exec_driver_sql("CREATE SYNONYM tmp FOR parent")
+ for dict_ in self.allDicts:
+ dict_["tmp"] = {(None, "parent"): dict_[None][(None, "parent")]}
+ try:
+ self._check_reflection(
+ connection,
+ schema=None,
+ res_schema="tmp",
+ oracle_resolve_synonyms=True,
+ filter_names=["parent"],
+ )
+ finally:
+ connection.exec_driver_sql("DROP SYNONYM tmp")
+
+ @testing.requires.oracle_test_dblink
+ @testing.requires.oracle_test_dblink2
+ def test_multi_dblink_synonyms(self, connection):
+ # oracle_resolve_synonyms handles multiple dblinks at once
+ connection.exec_driver_sql(
+ f"CREATE SYNONYM s1 FOR my_table@{self.dblink}"
+ )
+ connection.exec_driver_sql(
+ f"CREATE SYNONYM s2 FOR {testing.config.test_schema}."
+ f"my_table@{self.dblink2}"
+ )
+ connection.exec_driver_sql("CREATE SYNONYM s3 FOR parent")
+ for dict_ in self.allDicts:
+ dict_["tmp"] = {
+ (None, "s1"): dict_[None][(None, "my_table")],
+ (None, "s2"): dict_[testing.config.test_schema][
+ (testing.config.test_schema, "my_table")
+ ],
+ (None, "s3"): dict_[None][(None, "parent")],
+ }
+ fk = self.foreign_keys["tmp"][(None, "s1")][0]
+ fk["referred_table"] = "s3"
+ try:
+ self._check_reflection(
+ connection,
+ schema=None,
+ res_schema="tmp",
+ oracle_resolve_synonyms=True,
+ )
+ finally:
+ connection.exec_driver_sql("DROP SYNONYM s1")
+ connection.exec_driver_sql("DROP SYNONYM s2")
+ connection.exec_driver_sql("DROP SYNONYM s3")
)
@testing.combinations(
- ((8, 1), False, False),
- ((8, 1), None, False),
- ((11, 5), True, False),
- ((11, 5), False, True),
+ (True, False),
+ (False, True),
)
- def test_backslash_escapes_detection(
- self, version, explicit_setting, expected
- ):
+ def test_backslash_escapes_detection(self, explicit_setting, expected):
engine = engines.testing_engine()
- def _server_version(conn):
- return version
-
if explicit_setting is not None:
@event.listens_for(engine, "connect", insert=True)
)
dbapi_connection.commit()
- with mock.patch.object(
- engine.dialect, "_get_server_version_info", _server_version
- ):
- with engine.connect():
- eq_(engine.dialect._backslash_escapes, expected)
+ with engine.connect():
+ eq_(engine.dialect._backslash_escapes, expected)
def test_dbapi_autocommit_attribute(self):
"""all the supported DBAPIs have an .autocommit attribute. make
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.dialects.postgresql import INTERVAL
+from sqlalchemy.dialects.postgresql import pg_catalog
from sqlalchemy.dialects.postgresql import TSRANGE
+from sqlalchemy.engine import ObjectKind
+from sqlalchemy.engine import ObjectScope
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
-from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_warns
from sqlalchemy.testing.assertions import AssertsExecutionResults
from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.assertions import expect_raises
from sqlalchemy.testing.assertions import is_
+from sqlalchemy.testing.assertions import is_false
from sqlalchemy.testing.assertions import is_true
connection.execute(target.insert(), {"id": 89, "data": "d1"})
materialized_view = sa.DDL(
- "CREATE MATERIALIZED VIEW test_mview AS " "SELECT * FROM testtable"
+ "CREATE MATERIALIZED VIEW test_mview AS SELECT * FROM testtable"
)
plain_view = sa.DDL(
- "CREATE VIEW test_regview AS " "SELECT * FROM testtable"
+ "CREATE VIEW test_regview AS SELECT data FROM testtable"
)
sa.event.listen(testtable, "after_create", plain_view)
sa.event.listen(testtable, "after_create", materialized_view)
+ sa.event.listen(
+ testtable,
+ "after_create",
+ sa.DDL("COMMENT ON VIEW test_regview IS 'regular view comment'"),
+ )
+ sa.event.listen(
+ testtable,
+ "after_create",
+ sa.DDL(
+ "COMMENT ON MATERIALIZED VIEW test_mview "
+ "IS 'materialized view comment'"
+ ),
+ )
+ sa.event.listen(
+ testtable,
+ "after_create",
+ sa.DDL("CREATE INDEX mat_index ON test_mview(data DESC)"),
+ )
+
sa.event.listen(
testtable,
"before_drop",
testtable, "before_drop", sa.DDL("DROP VIEW test_regview")
)
+ def test_has_type(self, connection):
+ insp = inspect(connection)
+ is_true(insp.has_type("test_mview"))
+ is_true(insp.has_type("test_regview"))
+ is_true(insp.has_type("testtable"))
+
def test_mview_is_reflected(self, connection):
metadata = MetaData()
table = Table("test_mview", metadata, autoload_with=connection)
def test_get_view_names(self, inspect_fixture):
insp, conn = inspect_fixture
- eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
+ eq_(set(insp.get_view_names()), set(["test_regview"]))
- def test_get_view_names_plain(self, connection):
+ def test_get_materialized_view_names(self, inspect_fixture):
+ insp, conn = inspect_fixture
+ eq_(set(insp.get_materialized_view_names()), set(["test_mview"]))
+
+ def test_get_view_names_reflection_cache_ok(self, connection):
insp = inspect(connection)
+ eq_(set(insp.get_view_names()), set(["test_regview"]))
eq_(
- set(insp.get_view_names(include=("plain",))), set(["test_regview"])
+ set(insp.get_materialized_view_names()),
+ set(["test_mview"]),
+ )
+ eq_(
+ set(insp.get_view_names()).union(
+ insp.get_materialized_view_names()
+ ),
+ set(["test_regview", "test_mview"]),
)
- def test_get_view_names_plain_string(self, connection):
+ def test_get_view_definition(self, connection):
insp = inspect(connection)
- eq_(set(insp.get_view_names(include="plain")), set(["test_regview"]))
- def test_get_view_names_materialized(self, connection):
- insp = inspect(connection)
+ def normalize(definition):
+ return re.sub(r"[\n\t ]+", " ", definition.strip())
+
eq_(
- set(insp.get_view_names(include=("materialized",))),
- set(["test_mview"]),
+ normalize(insp.get_view_definition("test_mview")),
+ "SELECT testtable.id, testtable.data FROM testtable;",
+ )
+ eq_(
+ normalize(insp.get_view_definition("test_regview")),
+ "SELECT testtable.data FROM testtable;",
)
- def test_get_view_names_reflection_cache_ok(self, connection):
+ def test_get_view_comment(self, connection):
insp = inspect(connection)
eq_(
- set(insp.get_view_names(include=("plain",))), set(["test_regview"])
+ insp.get_table_comment("test_regview"),
+ {"text": "regular view comment"},
)
eq_(
- set(insp.get_view_names(include=("materialized",))),
- set(["test_mview"]),
+ insp.get_table_comment("test_mview"),
+ {"text": "materialized view comment"},
)
- eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
- def test_get_view_names_empty(self, connection):
+ def test_get_multi_view_comment(self, connection):
insp = inspect(connection)
- assert_raises(ValueError, insp.get_view_names, include=())
+ eq_(
+ insp.get_multi_table_comment(),
+ {(None, "testtable"): {"text": None}},
+ )
+ plain = {(None, "test_regview"): {"text": "regular view comment"}}
+ mat = {(None, "test_mview"): {"text": "materialized view comment"}}
+ eq_(insp.get_multi_table_comment(kind=ObjectKind.VIEW), plain)
+ eq_(
+ insp.get_multi_table_comment(kind=ObjectKind.MATERIALIZED_VIEW),
+ mat,
+ )
+ eq_(
+ insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW),
+ {**plain, **mat},
+ )
+ eq_(
+ insp.get_multi_table_comment(
+ kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY
+ ),
+ {},
+ )
- def test_get_view_definition(self, connection):
+ def test_get_multi_view_indexes(self, connection):
insp = inspect(connection)
+ eq_(insp.get_multi_indexes(), {(None, "testtable"): []})
+
+ exp = {
+ "name": "mat_index",
+ "unique": False,
+ "column_names": ["data"],
+ "column_sorting": {"data": ("desc",)},
+ }
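+        # PostgreSQL 11 introduced covering (INCLUDE) indexes, so reflection
+        # adds these keys only on 11 and later.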
+ if connection.dialect.server_version_info >= (11, 0):
+ exp["include_columns"] = []
+ exp["dialect_options"] = {"postgresql_include": []}
+ plain = {(None, "test_regview"): []}
+ mat = {(None, "test_mview"): [exp]}
+ eq_(insp.get_multi_indexes(kind=ObjectKind.VIEW), plain)
+ eq_(insp.get_multi_indexes(kind=ObjectKind.MATERIALIZED_VIEW), mat)
+ eq_(insp.get_multi_indexes(kind=ObjectKind.ANY_VIEW), {**plain, **mat})
eq_(
- re.sub(
- r"[\n\t ]+",
- " ",
- insp.get_view_definition("test_mview").strip(),
+ insp.get_multi_indexes(
+ kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY
),
- "SELECT testtable.id, testtable.data FROM testtable;",
+ {},
)
go,
[
"Skipped unsupported reflection of "
- "expression-based index idx1",
+ "expression-based index idx1 of table party",
"Skipped unsupported reflection of "
- "expression-based index idx3",
+ "expression-based index idx3 of table party",
],
)
metadata.create_all(connection)
- ind = connection.dialect.get_indexes(connection, t1, None)
+ ind = connection.dialect.get_indexes(connection, t1.name, None)
partial_definitions = []
for ix in ind:
}
],
)
+ is_true(inspector.has_type("mood", "test_schema"))
+ is_true(inspector.has_type("mood", "*"))
+ is_false(inspector.has_type("mood"))
def test_inspect_enums(self, metadata, inspect_fixture):
enum_type = postgresql.ENUM(
"cat", "dog", "rat", name="pet", metadata=metadata
)
+ enum_type.create(conn)
+ conn.commit()
- with conn.begin():
- enum_type.create(conn)
-
- eq_(
- inspector.get_enums(),
- [
- {
- "visible": True,
- "labels": ["cat", "dog", "rat"],
- "name": "pet",
- "schema": "public",
- }
- ],
- )
-
- def test_get_table_oid(self, metadata, inspect_fixture):
-
- inspector, conn = inspect_fixture
+ res = [
+ {
+ "visible": True,
+ "labels": ["cat", "dog", "rat"],
+ "name": "pet",
+ "schema": "public",
+ }
+ ]
+ eq_(inspector.get_enums(), res)
+ is_true(inspector.has_type("pet", "*"))
+ is_true(inspector.has_type("pet"))
+ is_false(inspector.has_type("pet", "test_schema"))
+
+ enum_type.drop(conn)
+ conn.commit()
+ eq_(inspector.get_enums(), res)
+ is_true(inspector.has_type("pet"))
+ inspector.clear_cache()
+ eq_(inspector.get_enums(), [])
+ is_false(inspector.has_type("pet"))
+
+ def test_get_table_oid(self, metadata, connection):
+ Table("t1", metadata, Column("col", Integer))
+ Table("t1", metadata, Column("col", Integer), schema="test_schema")
+ metadata.create_all(connection)
+ insp = inspect(connection)
+ oid = insp.get_table_oid("t1")
+ oid_schema = insp.get_table_oid("t1", schema="test_schema")
+ is_true(isinstance(oid, int))
+ is_true(isinstance(oid_schema, int))
+ is_true(oid != oid_schema)
- with conn.begin():
- Table("some_table", metadata, Column("q", Integer)).create(conn)
+ with expect_raises(exc.NoSuchTableError):
+ insp.get_table_oid("does_not_exist")
- assert inspector.get_table_oid("some_table") is not None
+ metadata.tables["t1"].drop(connection)
+ eq_(insp.get_table_oid("t1"), oid)
+ insp.clear_cache()
+ with expect_raises(exc.NoSuchTableError):
+ insp.get_table_oid("t1")
def test_inspect_enums_case_sensitive(self, metadata, connection):
sa.event.listen(
)
def test_reflect_check_warning(self):
- rows = [("some name", "NOTCHECK foobar")]
+ rows = [("foo", "some name", "NOTCHECK foobar")]
conn = mock.Mock(
execute=lambda *arg, **kw: mock.MagicMock(
fetchall=lambda: rows, __iter__=lambda self: iter(rows)
)
)
- with mock.patch.object(
- testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1
+ with testing.expect_warnings(
+ "Could not parse CHECK constraint text: 'NOTCHECK foobar'"
):
- with testing.expect_warnings(
- "Could not parse CHECK constraint text: 'NOTCHECK foobar'"
- ):
- testing.db.dialect.get_check_constraints(conn, "foo")
+ testing.db.dialect.get_check_constraints(conn, "foo")
def test_reflect_extra_newlines(self):
rows = [
- ("some name", "CHECK (\n(a \nIS\n NOT\n\n NULL\n)\n)"),
- ("some other name", "CHECK ((b\nIS\nNOT\nNULL))"),
- ("some CRLF name", "CHECK ((c\r\n\r\nIS\r\nNOT\r\nNULL))"),
- ("some name", "CHECK (c != 'hi\nim a name\n')"),
+ ("foo", "some name", "CHECK (\n(a \nIS\n NOT\n\n NULL\n)\n)"),
+ ("foo", "some other name", "CHECK ((b\nIS\nNOT\nNULL))"),
+ ("foo", "some CRLF name", "CHECK ((c\r\n\r\nIS\r\nNOT\r\nNULL))"),
+ ("foo", "some name", "CHECK (c != 'hi\nim a name\n')"),
]
conn = mock.Mock(
execute=lambda *arg, **kw: mock.MagicMock(
fetchall=lambda: rows, __iter__=lambda self: iter(rows)
)
)
- with mock.patch.object(
- testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1
- ):
- check_constraints = testing.db.dialect.get_check_constraints(
- conn, "foo"
- )
- eq_(
- check_constraints,
- [
- {
- "name": "some name",
- "sqltext": "a \nIS\n NOT\n\n NULL\n",
- },
- {"name": "some other name", "sqltext": "b\nIS\nNOT\nNULL"},
- {
- "name": "some CRLF name",
- "sqltext": "c\r\n\r\nIS\r\nNOT\r\nNULL",
- },
- {"name": "some name", "sqltext": "c != 'hi\nim a name\n'"},
- ],
- )
+ check_constraints = testing.db.dialect.get_check_constraints(
+ conn, "foo"
+ )
+ eq_(
+ check_constraints,
+ [
+ {
+ "name": "some name",
+ "sqltext": "a \nIS\n NOT\n\n NULL\n",
+ },
+ {"name": "some other name", "sqltext": "b\nIS\nNOT\nNULL"},
+ {
+ "name": "some CRLF name",
+ "sqltext": "c\r\n\r\nIS\r\nNOT\r\nNULL",
+ },
+ {"name": "some name", "sqltext": "c != 'hi\nim a name\n'"},
+ ],
+ )
def test_reflect_with_not_valid_check_constraint(self):
- rows = [("some name", "CHECK ((a IS NOT NULL)) NOT VALID")]
+ rows = [("foo", "some name", "CHECK ((a IS NOT NULL)) NOT VALID")]
conn = mock.Mock(
execute=lambda *arg, **kw: mock.MagicMock(
fetchall=lambda: rows, __iter__=lambda self: iter(rows)
)
)
- with mock.patch.object(
- testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1
- ):
- check_constraints = testing.db.dialect.get_check_constraints(
- conn, "foo"
+ check_constraints = testing.db.dialect.get_check_constraints(
+ conn, "foo"
+ )
+ eq_(
+ check_constraints,
+ [
+ {
+ "name": "some name",
+ "sqltext": "a IS NOT NULL",
+ "dialect_options": {"not_valid": True},
+ }
+ ],
+ )
+
+ def _apply_stm(self, connection, use_map):
+ if use_map:
+ return connection.execution_options(
+ schema_translate_map={
+ None: "foo",
+ testing.config.test_schema: "bar",
+ }
)
- eq_(
- check_constraints,
- [
- {
- "name": "some name",
- "sqltext": "a IS NOT NULL",
- "dialect_options": {"not_valid": True},
- }
- ],
+ else:
+ return connection
+
+ @testing.combinations(True, False, argnames="use_map")
+ @testing.combinations(True, False, argnames="schema")
+ def test_schema_translate_map(self, metadata, connection, use_map, schema):
+ schema = testing.config.test_schema if schema else None
+ Table(
+ "foo",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("a", Integer, index=True),
+ Column(
+ "b",
+ ForeignKey(f"{schema}.foo.id" if schema else "foo.id"),
+ unique=True,
+ ),
+ CheckConstraint("a>10", name="foo_check"),
+ comment="comm",
+ schema=schema,
+ )
+ metadata.create_all(connection)
+ if use_map:
+ connection = connection.execution_options(
+ schema_translate_map={
+ None: "foo",
+ testing.config.test_schema: "bar",
+ }
)
+ insp = inspect(connection)
+ eq_(
+ [c["name"] for c in insp.get_columns("foo", schema=schema)],
+ ["id", "a", "b"],
+ )
+ eq_(
+ [
+ i["column_names"]
+ for i in insp.get_indexes("foo", schema=schema)
+ ],
+ [["b"], ["a"]],
+ )
+ eq_(
+ insp.get_pk_constraint("foo", schema=schema)[
+ "constrained_columns"
+ ],
+ ["id"],
+ )
+ eq_(insp.get_table_comment("foo", schema=schema), {"text": "comm"})
+ eq_(
+ [
+ f["constrained_columns"]
+ for f in insp.get_foreign_keys("foo", schema=schema)
+ ],
+ [["b"]],
+ )
+ eq_(
+ [
+ c["name"]
+ for c in insp.get_check_constraints("foo", schema=schema)
+ ],
+ ["foo_check"],
+ )
+ eq_(
+ [
+ u["column_names"]
+ for u in insp.get_unique_constraints("foo", schema=schema)
+ ],
+ [["b"]],
+ )
class CustomTypeReflectionTest(fixtures.TestBase):
("my_custom_type(ARG1)", ("ARG1", None)),
("my_custom_type(ARG1, ARG2)", ("ARG1", "ARG2")),
]:
- column_info = dialect._get_column_info(
- "colname", sch, None, False, {}, {}, "public", None, "", None
+ row_dict = {
+ "name": "colname",
+ "table_name": "tblname",
+ "format_type": sch,
+ "default": None,
+ "not_null": False,
+ "comment": None,
+ "generated": "",
+ "identity_options": None,
+ }
+ column_info = dialect._get_columns_info(
+ [row_dict], {}, {}, "public"
)
+ assert ("public", "tblname") in column_info
+ column_info = column_info[("public", "tblname")]
+ assert len(column_info) == 1
+ column_info = column_info[0]
assert isinstance(column_info["type"], self.CustomType)
eq_(column_info["type"].arg1, args[0])
eq_(column_info["type"].arg2, args[1])
exp = default.copy()
exp.update(maxvalue=2**15 - 1)
eq_(col["identity"], exp)
+
+
+class TestReflectDifficultColTypes(fixtures.TablesTest):
+ __only_on__ = "postgresql"
+ __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+ Table(
+ "sample_table",
+ metadata,
+ Column("c1", Integer, primary_key=True),
+ Column("c2", Integer, unique=True),
+ Column("c3", Integer),
+ Index("sample_table_index", "c2", "c3"),
+ )
+
+ def check_int_list(self, row, key):
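+        # pg_catalog columns using vector types (e.g. int2vector, oidvector)
+        # should come back as plain Python lists of integers.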
+ value = row[key]
+ is_true(isinstance(value, list))
+ is_true(len(value) > 0)
+ is_true(all(isinstance(v, int) for v in value))
+
+ def test_pg_index(self, connection):
+ insp = inspect(connection)
+
+ pgc_oid = insp.get_table_oid("sample_table")
+ cols = [
+ col
+ for col in pg_catalog.pg_index.c
+ if testing.db.dialect.server_version_info
+ >= col.info.get("server_version", (0,))
+ ]
+
+ stmt = sa.select(*cols).filter_by(indrelid=pgc_oid)
+ rows = connection.execute(stmt).mappings().all()
+ is_true(len(rows) > 0)
+ cols = [
+ col
+ for col in ["indkey", "indoption", "indclass", "indcollation"]
+ if testing.db.dialect.server_version_info
+ >= pg_catalog.pg_index.c[col].info.get("server_version", (0,))
+ ]
+ for row in rows:
+ for col in cols:
+ self.check_int_list(row, col)
+
+ def test_pg_constraint(self, connection):
+ insp = inspect(connection)
+
+ pgc_oid = insp.get_table_oid("sample_table")
+ cols = [
+ col
+ for col in pg_catalog.pg_constraint.c
+ if testing.db.dialect.server_version_info
+ >= col.info.get("server_version", (0,))
+ ]
+ stmt = sa.select(*cols).filter_by(conrelid=pgc_oid)
+ rows = connection.execute(stmt).mappings().all()
+ is_true(len(rows) > 0)
+ for row in rows:
+ self.check_int_list(row, "conkey")
asserter.assert_(
# check for table
RegexSQL(
- "select relname from pg_class c join pg_namespace.*",
+ "SELECT pg_catalog.pg_class.relname FROM pg_catalog."
+ "pg_class JOIN pg_catalog.pg_namespace.*",
dialect="postgresql",
),
# check for enum, just once
- RegexSQL(r".*SELECT EXISTS ", dialect="postgresql"),
+ RegexSQL(
+ r"SELECT pg_catalog.pg_type.typname .* WHERE "
+ "pg_catalog.pg_type.typname = ",
+ dialect="postgresql",
+ ),
RegexSQL("CREATE TYPE myenum AS ENUM .*", dialect="postgresql"),
RegexSQL(r"CREATE TABLE t .*", dialect="postgresql"),
)
asserter.assert_(
RegexSQL(
- "select relname from pg_class c join pg_namespace.*",
+ "SELECT pg_catalog.pg_class.relname FROM pg_catalog."
+ "pg_class JOIN pg_catalog.pg_namespace.*",
dialect="postgresql",
),
RegexSQL("DROP TABLE t", dialect="postgresql"),
- RegexSQL(r".*SELECT EXISTS ", dialect="postgresql"),
+ RegexSQL(
+ r"SELECT pg_catalog.pg_type.typname .* WHERE "
+ "pg_catalog.pg_type.typname = ",
+ dialect="postgresql",
+ ),
RegexSQL("DROP TYPE myenum", dialect="postgresql"),
)
connection, "fourfivesixtype"
)
- def test_no_support(self, testing_engine):
- def server_version_info(self):
- return (8, 2)
-
- e = testing_engine()
- dialect = e.dialect
- dialect._get_server_version_info = server_version_info
-
- assert dialect.supports_native_enum
- e.connect()
- assert not dialect.supports_native_enum
-
- # initialize is called again on new pool
- e.dispose()
- e.connect()
- assert not dialect.supports_native_enum
-
def test_reflection(self, metadata, connection):
etype = Enum(
"four", "five", "six", name="fourfivesixtype", metadata=metadata
from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
["foo", "bar"],
)
- eq_(
- [
- d["name"]
- for d in insp.get_columns("nonexistent", schema="test_schema")
- ],
- [],
- )
- eq_(
- [
- d["name"]
- for d in insp.get_columns("another_created", schema=None)
- ],
- [],
- )
- eq_(
- [
- d["name"]
- for d in insp.get_columns("local_only", schema="test_schema")
- ],
- [],
- )
+ with expect_raises(exc.NoSuchTableError):
+ insp.get_columns("nonexistent", schema="test_schema")
+
+ with expect_raises(exc.NoSuchTableError):
+ insp.get_columns("another_created", schema=None)
+
+ with expect_raises(exc.NoSuchTableError):
+ insp.get_columns("local_only", schema="test_schema")
+
eq_([d["name"] for d in insp.get_columns("local_only")], ["q", "p"])
def test_table_names_present(self):
import sqlalchemy as sa
from sqlalchemy import Computed
+from sqlalchemy import Connection
from sqlalchemy import DefaultClause
from sqlalchemy import event
from sqlalchemy import FetchedValue
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import UniqueConstraint
+from sqlalchemy.engine import Inspector
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
m2 = MetaData()
t2 = Table("x", m2, autoload_with=connection)
- ck = [
+ cks = [
const
for const in t2.constraints
if isinstance(const, sa.CheckConstraint)
- ][0]
-
+ ]
+ eq_(len(cks), 1)
+ ck = cks[0]
eq_regex(ck.sqltext.text, r"[\(`]*q[\)`]* > 10")
eq_(ck.name, "ck1")
sa.Index("x_ix", t.c.a, t.c.b)
metadata.create_all(connection)
- def mock_get_columns(self, connection, table_name, **kw):
- return [{"name": "b", "type": Integer, "primary_key": False}]
+ gri = Inspector._get_reflection_info
+
+ def mock_gri(self, *a, **kw):
+ res = gri(self, *a, **kw)
+ res.columns[(None, "x")] = [
+ col for col in res.columns[(None, "x")] if col["name"] == "b"
+ ]
+ return res
with testing.mock.patch.object(
- connection.dialect, "get_columns", mock_get_columns
+ Inspector, "_get_reflection_info", mock_gri
):
m = MetaData()
with testing.expect_warnings(
eq_(ua, ["users", "email_addresses"])
eq_(oi, ["orders", "items"])
- def test_checkfirst(self, connection):
+ def test_checkfirst(self, connection: Connection) -> None:
insp = inspect(connection)
+
users = self.tables.users
is_false(insp.has_table("users"))
users.create(connection)
+ insp.clear_cache()
is_true(insp.has_table("users"))
users.create(connection, checkfirst=True)
users.drop(connection)
users.drop(connection, checkfirst=True)
+ insp.clear_cache()
is_false(insp.has_table("users"))
users.create(connection, checkfirst=True)
users.drop(connection)
- def test_createdrop(self, connection):
+ def test_createdrop(self, connection: Connection) -> None:
insp = inspect(connection)
metadata = self.tables_test_metadata
+ assert metadata is not None
metadata.create_all(connection)
is_true(insp.has_table("items"))
is_true(insp.has_table("email_addresses"))
metadata.create_all(connection)
+ insp.clear_cache()
is_true(insp.has_table("items"))
metadata.drop_all(connection)
+ insp.clear_cache()
is_false(insp.has_table("items"))
is_false(insp.has_table("email_addresses"))
metadata.drop_all(connection)
+ insp.clear_cache()
is_false(insp.has_table("items"))
- def test_tablenames(self, connection):
+ def test_has_table_and_table_names(self, connection):
+ """establish that has_table and get_table_names are consistent w/
+ each other with regard to caching
+
+ """
metadata = self.tables_test_metadata
metadata.create_all(bind=connection)
insp = inspect(connection)
# ensure all tables we created are in the list.
is_true(set(insp.get_table_names()).issuperset(metadata.tables))
+ assert insp.has_table("items")
+ assert "items" in insp.get_table_names()
+
+ self.tables.items.drop(connection)
+
+ # cached
+ assert insp.has_table("items")
+ assert "items" in insp.get_table_names()
+
+ insp = inspect(connection)
+ assert not insp.has_table("items")
+ assert "items" not in insp.get_table_names()
+
class SchemaManipulationTest(fixtures.TestBase):
__backend__ = True
__backend__ = True
@testing.requires.schemas
- @testing.requires.cross_schema_fk_reflection
def test_has_schema(self):
- if not hasattr(testing.db.dialect, "has_schema"):
- testing.config.skip_test(
- "dialect %s doesn't have a has_schema method"
- % testing.db.dialect.name
- )
with testing.db.connect() as conn:
eq_(
testing.db.dialect.has_schema(
--- /dev/null
+from argparse import ArgumentDefaultsHelpFormatter
+from argparse import ArgumentParser
+from collections import defaultdict
+from contextlib import contextmanager
+from functools import wraps
+from pprint import pprint
+import random
+import time
+
+import sqlalchemy as sa
+from sqlalchemy.engine import Inspector
+
+types = (sa.Integer, sa.BigInteger, sa.String(200), sa.DateTime)
+USE_CONNECTION = False
+
+
+def generate_table(meta: sa.MetaData, min_cols, max_cols, dialect_name):
+ col_number = random.randint(min_cols, max_cols)
+ table_num = len(meta.tables)
+ add_identity = random.random() > 0.90
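+    # roughly 10% of generated tables get an Identity column; SQL Server
+    # puts it on the integer primary key, other backends append it as an
+    # extra column further below.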
+ identity = sa.Identity(
+ always=random.randint(0, 1),
+ start=random.randint(1, 100),
+ increment=random.randint(1, 7),
+ )
+ is_mssql = dialect_name == "mssql"
+ cols = []
+ for i in range(col_number - (0 if is_mssql else add_identity)):
+ args = []
+ if random.random() < 0.95 or table_num == 0:
+ if is_mssql and add_identity and i == 0:
+ args.append(sa.Integer)
+ args.append(identity)
+ else:
+ args.append(random.choice(types))
+ else:
+ args.append(
+ sa.ForeignKey(f"table_{table_num-1}.table_{table_num-1}_col_1")
+ )
+ cols.append(
+ sa.Column(
+ f"table_{table_num}_col_{i+1}",
+ *args,
+ primary_key=i == 0,
+ comment=f"primary key of table_{table_num}"
+ if i == 0
+ else None,
+ index=random.random() > 0.9 and i > 0,
+ unique=random.random() > 0.95 and i > 0,
+ )
+ )
+ if add_identity and not is_mssql:
+ cols.append(
+ sa.Column(
+ f"table_{table_num}_col_{col_number}",
+ sa.Integer,
+ identity,
+ )
+ )
+ args = ()
+ if table_num % 3 == 0:
+ # mysql can't do check constraint on PK col
+ args = (sa.CheckConstraint(cols[1].is_not(None)),)
+ return sa.Table(
+ f"table_{table_num}",
+ meta,
+ *cols,
+ *args,
+ comment=f"comment for table_{table_num}" if table_num % 2 else None,
+ )
+
+
+def generate_meta(schema_name, table_number, min_cols, max_cols, dialect_name):
+ meta = sa.MetaData(schema=schema_name)
+ log = defaultdict(int)
+ for _ in range(table_number):
+ t = generate_table(meta, min_cols, max_cols, dialect_name)
+ log["tables"] += 1
+ log["columns"] += len(t.columns)
+ log["index"] += len(t.indexes)
+ log["check_con"] += len(
+ [c for c in t.constraints if isinstance(c, sa.CheckConstraint)]
+ )
+ log["foreign_keys_con"] += len(
+ [
+ c
+ for c in t.constraints
+ if isinstance(c, sa.ForeignKeyConstraint)
+ ]
+ )
+ log["unique_con"] += len(
+ [c for c in t.constraints if isinstance(c, sa.UniqueConstraint)]
+ )
+ log["identity"] += len([c for c in t.columns if c.identity])
+
+ print("Meta info", dict(log))
+ return meta
+
+
+def log(fn):
+ @wraps(fn)
+ def wrap(*a, **kw):
+ print("Running ", fn.__name__, "...", flush=True, end="")
+ try:
+ r = fn(*a, **kw)
+ except NotImplementedError:
+ print(" [not implemented]", flush=True)
+ r = None
+ else:
+ print("... done", flush=True)
+ return r
+
+ return wrap
+
+
+tests = {}
+
+
+def define_test(fn):
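+    # Register the function in ``tests`` under its name minus the
+    # ``reflect_`` prefix, wrapped with the logging decorator above.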
+ name: str = fn.__name__
+ if name.startswith("reflect_"):
+ name = name[8:]
+ tests[name] = wfn = log(fn)
+ return wfn
+
+
+@log
+def create_tables(engine, meta):
+ tables = list(meta.tables.values())
+ for i in range(0, len(tables), 500):
+ meta.create_all(engine, tables[i : i + 500])
+
+
+@log
+def drop_tables(engine, meta, schema_name, table_names: list):
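+    # Drop the known tables in batches, then force-drop anything still left
+    # with raw DDL, highest-numbered first so FK dependents go before their
+    # referenced tables (plus CASCADE on PostgreSQL).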
+ tables = list(meta.tables.values())[::-1]
+ for i in range(0, len(tables), 500):
+ meta.drop_all(engine, tables[i : i + 500])
+
+ remaining = sa.inspect(engine).get_table_names(schema=schema_name)
+ suffix = ""
+ if engine.dialect.name.startswith("postgres"):
+ suffix = "CASCADE"
+
+ remaining = sorted(
+ remaining, key=lambda tn: int(tn.partition("_")[2]), reverse=True
+ )
+ with engine.connect() as conn:
+ for i, tn in enumerate(remaining):
+ if engine.dialect.requires_name_normalize:
+ name = engine.dialect.denormalize_name(tn)
+ else:
+ name = tn
+ if schema_name:
+ conn.execute(
+ sa.schema.DDL(
+ f'DROP TABLE {schema_name}."{name}" {suffix}'
+ )
+ )
+ else:
+ conn.execute(sa.schema.DDL(f'DROP TABLE "{name}" {suffix}'))
+ if i % 500 == 0:
+ conn.commit()
+ conn.commit()
+
+
+@log
+def reflect_tables(engine, schema_name):
+ ref_meta = sa.MetaData(schema=schema_name)
+ ref_meta.reflect(engine)
+
+
+def verify_dict(multi, single, str_compare=False):
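+    # Compare the batched ("multi") result against the per-table ("single")
+    # result key by key and print any entries that differ; str_compare
+    # compares string representations instead of the raw values.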
+ if single is None or multi is None:
+ return
+ if single != multi:
+ keys = set(single) | set(multi)
+ diff = []
+ for key in sorted(keys):
+ se, me = single.get(key), multi.get(key)
+ if str(se) != str(me) if str_compare else se != me:
+ diff.append((key, single.get(key), multi.get(key)))
+ if diff:
+ print("\nfound different result:")
+ pprint(diff)
+
+
+def _single_test(
+ singe_fn_name,
+ multi_fn_name,
+ engine,
+ schema_name,
+ table_names,
+ timing,
+ mode,
+):
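+    # Run the per-table ("single") and batched ("multi") variants of a
+    # reflection method, timing each, and return both results so callers
+    # can diff them.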
+ single = None
+ if "single" in mode:
+ singe_fn = getattr(Inspector, singe_fn_name)
+
+ def go(bind):
+ insp = sa.inspect(bind)
+ single = {}
+ with timing(singe_fn.__name__):
+ for t in table_names:
+ single[(schema_name, t)] = singe_fn(
+ insp, t, schema=schema_name
+ )
+ return single
+
+ if USE_CONNECTION:
+ with engine.connect() as c:
+ single = go(c)
+ else:
+ single = go(engine)
+
+ multi = None
+ if "multi" in mode:
+ insp = sa.inspect(engine)
+ multi_fn = getattr(Inspector, multi_fn_name)
+ with timing(multi_fn.__name__):
+ multi = multi_fn(insp, schema=schema_name)
+ return (multi, single)
+
+
+@define_test
+def reflect_columns(
+ engine, schema_name, table_names, timing, mode, ignore_diff
+):
+ multi, single = _single_test(
+ "get_columns",
+ "get_multi_columns",
+ engine,
+ schema_name,
+ table_names,
+ timing,
+ mode,
+ )
+ if not ignore_diff:
+ verify_dict(multi, single, str_compare=True)
+
+
+@define_test
+def reflect_table_options(
+ engine, schema_name, table_names, timing, mode, ignore_diff
+):
+ multi, single = _single_test(
+ "get_table_options",
+ "get_multi_table_options",
+ engine,
+ schema_name,
+ table_names,
+ timing,
+ mode,
+ )
+ if not ignore_diff:
+ verify_dict(multi, single)
+
+
+@define_test
+def reflect_pk(engine, schema_name, table_names, timing, mode, ignore_diff):
+ multi, single = _single_test(
+ "get_pk_constraint",
+ "get_multi_pk_constraint",
+ engine,
+ schema_name,
+ table_names,
+ timing,
+ mode,
+ )
+ if not ignore_diff:
+ verify_dict(multi, single)
+
+
+@define_test
+def reflect_comment(
+ engine, schema_name, table_names, timing, mode, ignore_diff
+):
+ multi, single = _single_test(
+ "get_table_comment",
+ "get_multi_table_comment",
+ engine,
+ schema_name,
+ table_names,
+ timing,
+ mode,
+ )
+ if not ignore_diff:
+ verify_dict(multi, single)
+
+
+@define_test
+def reflect_whole_tables(
+ engine, schema_name, table_names, timing, mode, ignore_diff
+):
+ single = None
+ meta = sa.MetaData(schema=schema_name)
+
+ if "single" in mode:
+
+ def go(bind):
+ single = {}
+ with timing("Table_autoload_with"):
+ for name in table_names:
+ single[(None, name)] = sa.Table(
+ name, meta, autoload_with=bind
+ )
+ return single
+
+ if USE_CONNECTION:
+ with engine.connect() as c:
+ single = go(c)
+ else:
+ single = go(engine)
+
+ multi_meta = sa.MetaData(schema=schema_name)
+ if "multi" in mode:
+ with timing("MetaData_reflect"):
+ multi_meta.reflect(engine, only=table_names)
+ return (multi_meta, single)
+
+
+@define_test
+def reflect_check_constraints(
+ engine, schema_name, table_names, timing, mode, ignore_diff
+):
+ multi, single = _single_test(
+ "get_check_constraints",
+ "get_multi_check_constraints",
+ engine,
+ schema_name,
+ table_names,
+ timing,
+ mode,
+ )
+ if not ignore_diff:
+ verify_dict(multi, single)
+
+
+@define_test
+def reflect_indexes(
+ engine, schema_name, table_names, timing, mode, ignore_diff
+):
+ multi, single = _single_test(
+ "get_indexes",
+ "get_multi_indexes",
+ engine,
+ schema_name,
+ table_names,
+ timing,
+ mode,
+ )
+ if not ignore_diff:
+ verify_dict(multi, single)
+
+
+@define_test
+def reflect_foreign_keys(
+ engine, schema_name, table_names, timing, mode, ignore_diff
+):
+ multi, single = _single_test(
+ "get_foreign_keys",
+ "get_multi_foreign_keys",
+ engine,
+ schema_name,
+ table_names,
+ timing,
+ mode,
+ )
+ if not ignore_diff:
+ verify_dict(multi, single)
+
+
+@define_test
+def reflect_unique_constraints(
+ engine, schema_name, table_names, timing, mode, ignore_diff
+):
+ multi, single = _single_test(
+ "get_unique_constraints",
+ "get_multi_unique_constraints",
+ engine,
+ schema_name,
+ table_names,
+ timing,
+ mode,
+ )
+ if not ignore_diff:
+ verify_dict(multi, single)
+
+
+def _apply_events(engine):
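+    # Time every statement executed on the engine via cursor events, keyed
+    # by statement text, so --sqlstats can report per-query statistics.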
+ queries = defaultdict(list)
+
+ now = 0
+
+ @sa.event.listens_for(engine, "before_cursor_execute")
+ def before_cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
+
+ nonlocal now
+ now = time.time()
+
+ @sa.event.listens_for(engine, "after_cursor_execute")
+ def after_cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
+ total = time.time() - now
+
+ if context and context.compiled:
+ statement_str = context.compiled.string
+ else:
+ statement_str = statement
+ queries[statement_str].append(total)
+
+ return queries
+
+
+def _print_query_stats(queries):
+ number_of_queries = sum(
+ len(query_times) for query_times in queries.values()
+ )
+ print("-" * 50)
+ q_list = list(queries.items())
+ q_list.sort(key=lambda rec: -sum(rec[1]))
+ total = sum([sum(t) for _, t in q_list])
+ print(f"total number of queries: {number_of_queries}. Total time {total}")
+ print("-" * 50)
+
+ for stmt, times in q_list:
+ total_t = sum(times)
+ max_t = max(times)
+ min_t = min(times)
+ avg_t = total_t / len(times)
+ times.sort()
+ median_t = times[len(times) // 2]
+
+ print(
+ f"Query times: {total_t=}, {max_t=}, {min_t=}, {avg_t=}, "
+ f"{median_t=} Number of calls: {len(times)}"
+ )
+ print(stmt.strip(), "\n")
+
+
+def main(db, schema_name, table_number, min_cols, max_cols, args):
+ timing = timer()
+ if args.pool_class:
+ engine = sa.create_engine(
+ db, echo=args.echo, poolclass=getattr(sa.pool, args.pool_class)
+ )
+ else:
+ engine = sa.create_engine(db, echo=args.echo)
+
+ if engine.name == "oracle":
+ # clear out oracle caches so that we get the real-world time the
+ # queries would normally take for scripts that aren't run repeatedly
+ with engine.connect() as conn:
+ # https://stackoverflow.com/questions/2147456/how-to-clear-all-cached-items-in-oracle
+ conn.exec_driver_sql("alter system flush buffer_cache")
+ conn.exec_driver_sql("alter system flush shared_pool")
+ if not args.no_create:
+ print(
+ f"Generating {table_number} using engine {engine} in "
+ f"schema {schema_name or 'default'}",
+ )
+ meta = sa.MetaData()
+ table_names = []
+ stats = {}
+ try:
+ if not args.no_create:
+ with timing("populate-meta"):
+ meta = generate_meta(
+ schema_name, table_number, min_cols, max_cols, engine.name
+ )
+ with timing("create-tables"):
+ create_tables(engine, meta)
+
+ with timing("get_table_names"):
+ with engine.connect() as conn:
+ table_names = engine.dialect.get_table_names(
+ conn, schema=schema_name
+ )
+ print(
+ f"Reflected table number {len(table_names)} in "
+ f"schema {schema_name or 'default'}"
+ )
+ mode = {"single", "multi"}
+ if args.multi_only:
+ mode.discard("single")
+ if args.single_only:
+ mode.discard("multi")
+
+ if args.sqlstats:
+ print("starting stats for subsequent tests")
+ stats = _apply_events(engine)
+ for test_name, test_fn in tests.items():
+ if test_name in args.test or "all" in args.test:
+ test_fn(
+ engine,
+ schema_name,
+ table_names,
+ timing,
+ mode,
+ args.ignore_diff,
+ )
+
+ if args.reflect:
+ with timing("reflect-tables"):
+ reflect_tables(engine, schema_name)
+ finally:
+ # copy stats to new dict
+ if args.sqlstats:
+ stats = dict(stats)
+ try:
+ if not args.no_drop:
+ with timing("drop-tables"):
+ drop_tables(engine, meta, schema_name, table_names)
+ finally:
+ pprint(timing.timing, sort_dicts=False)
+ if args.sqlstats:
+ _print_query_stats(stats)
+
+
+def timer():
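+    # Minimal named-timer factory: each ``with track_time("step"):`` block
+    # records its elapsed seconds in ``track_time.timing`` for the report.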
+ timing = {}
+
+ @contextmanager
+ def track_time(name):
+ s = time.time()
+ yield
+ timing[name] = time.time() - s
+
+ track_time.timing = timing
+ return track_time
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ "--db", help="Database url", default="sqlite:///many-table.db"
+ )
+ parser.add_argument(
+ "--schema-name",
+ help="optional schema name",
+ type=str,
+ default=None,
+ )
+ parser.add_argument(
+ "--table-number",
+ help="Number of table to generate.",
+ type=int,
+ default=250,
+ )
+ parser.add_argument(
+ "--min-cols",
+ help="Min number of column per table.",
+ type=int,
+ default=15,
+ )
+ parser.add_argument(
+ "--max-cols",
+ help="Max number of column per table.",
+ type=int,
+ default=250,
+ )
+ parser.add_argument(
+ "--no-create", help="Do not run create tables", action="store_true"
+ )
+ parser.add_argument(
+ "--no-drop", help="Do not run drop tables", action="store_true"
+ )
+ parser.add_argument("--reflect", help="Run reflect", action="store_true")
+ parser.add_argument(
+ "--test",
+ help="Run these tests. 'all' runs all tests",
+ nargs="+",
+ choices=tuple(tests) + ("all", "none"),
+ default=["all"],
+ )
+ parser.add_argument(
+ "--sqlstats",
+ help="count and time individual queries",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--multi-only", help="Only run multi table tests", action="store_true"
+ )
+ parser.add_argument(
+ "--single-only",
+ help="Only run single table tests",
+ action="store_true",
+ )
+ parser.add_argument(
+ "--echo", action="store_true", help="Enable echo on the engine"
+ )
+ parser.add_argument(
+ "--ignore-diff",
+ action="store_true",
+ help="Ignores differences in the single/multi reflections",
+ )
+ parser.add_argument(
+ "--single-inspect-conn",
+ action="store_true",
+ help="Uses inspect on a connection instead of on the engine when "
+ "using single reflections. Mainly for sqlite.",
+ )
+ parser.add_argument("--pool-class", help="The pool class to use")
+
+ args = parser.parse_args()
+ min_cols = args.min_cols
+ max_cols = args.max_cols
+ USE_CONNECTION = args.single_inspect_conn
+ assert min_cols <= max_cols and min_cols >= 1
+ assert not (args.multi_only and args.single_only)
+ main(
+ args.db, args.schema_name, args.table_number, min_cols, max_cols, args
+ )
return skip_if(no_support("sqlite", "not supported by database"))
+ @property
+ def foreign_keys_reflect_as_index(self):
+ return only_on(["mysql", "mariadb"])
+
+ @property
+ def unique_index_reflect_as_unique_constraints(self):
+ return only_on(["mysql", "mariadb"])
+
+ @property
+ def unique_constraints_reflect_as_index(self):
+ return only_on(["mysql", "mariadb", "oracle", "postgresql", "mssql"])
+
@property
def foreign_key_constraint_name_reflection(self):
return fails_if(
and not self._mariadb_105(config)
)
+ @property
+ def reflect_indexes_with_ascdesc(self):
+ return fails_if(["oracle"])
+
@property
def table_ddl_if_exists(self):
"""target platform supports IF NOT EXISTS / IF EXISTS for tables."""
return exclusions.open()
+ @property
+ def schema_create_delete(self):
+ """target database supports schema create and dropped with
+ 'CREATE SCHEMA' and 'DROP SCHEMA'"""
+ return exclusions.skip_if(["sqlite", "oracle"])
+
@property
def cross_schema_fk_reflection(self):
"""target system must support reflection of inter-schema foreign
@property
def check_constraint_reflection(self):
- return fails_on_everything_except(
- "postgresql",
- "sqlite",
- "oracle",
- self._mysql_and_check_constraints_exist,
+ return only_on(
+ [
+ "postgresql",
+ "sqlite",
+ "oracle",
+ self._mysql_and_check_constraints_exist,
+ ]
)
@property
def temp_table_names(self):
"""target dialect supports listing of temporary table names"""
- return only_on(["sqlite", "oracle"]) + skip_if(self._sqlite_file_db)
+ return only_on(["sqlite", "oracle", "postgresql"]) + skip_if(
+ self._sqlite_file_db
+ )
@property
def temporary_views(self):
@property
def views(self):
"""Target database must support VIEWs."""
-
- return skip_if("drizzle", "no VIEW support")
+ return exclusions.open()
@property
def empty_strings_varchar(self):
)
)
+ def _has_oracle_test_dblink(self, key):
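+        # Requirement helper: the dblink name must be configured in the
+        # [sqla_testing] section of the test config and must actually
+        # exist in the target Oracle database.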
+ def check(config):
+ assert config.db.dialect.name == "oracle"
+ name = config.file_config.get("sqla_testing", key)
+ if not name:
+ return False
+ with config.db.connect() as conn:
+ links = config.db.dialect._list_dblinks(conn)
+ return config.db.dialect.normalize_name(name) in links
+
+ return only_on(["oracle"]) + only_if(
+ check,
+ f"{key} option not specified in config or dblink not found in db",
+ )
+
@property
def oracle_test_dblink(self):
- return skip_if(
- lambda config: not config.file_config.has_option(
- "sqla_testing", "oracle_db_link"
- ),
- "oracle_db_link option not specified in config",
- )
+ return self._has_oracle_test_dblink("oracle_db_link")
+
+ @property
+ def oracle_test_dblink2(self):
+ return self._has_oracle_test_dblink("oracle_db_link2")
@property
def postgresql_test_dblink(self):
return only_on(["mssql"]) + only_if(check)
+ @property
+ def reflect_table_options(self):
+ return only_on(["mysql", "mariadb", "oracle"])
+
+ @property
+ def materialized_views(self):
+ """Target database must support MATERIALIZED VIEWs."""
+ return only_on(["postgresql", "oracle"])
+
+ @property
+ def materialized_views_reflect_pk(self):
+ return only_on(["oracle"])
+
@property
def uuid_data_type(self):
"""Return databases that support the UUID datatype."""