import logging
from typing import Any
from typing import cast
+from typing import Collection
from typing import Dict
from typing import Mapping
from typing import Optional
from ...util import sqla_compat
if TYPE_CHECKING:
+ from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
+ from sqlalchemy.engine.interfaces import ReflectedIndex
+ from sqlalchemy.engine.interfaces import ReflectedUniqueConstraint
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import Column
from ...operations.ops import ModifyTableOps
from ...runtime.plugins import Plugin
-
_C = TypeVar("_C", bound=Union[UniqueConstraint, ForeignKeyConstraint, Index])
metadata_unique_constraints = set()
metadata_indexes = set()
- conn_uniques = conn_indexes = frozenset() # type:ignore[var-annotated]
+ conn_uniques: Collection[UniqueConstraint] = frozenset()
+ conn_indexes: Collection[Index] = frozenset()
supports_unique_constraints = False
unique_constraints_duplicate_unique_indexes = False
if conn_table is not None:
+ conn_uniques_reflected: Collection[ReflectedUniqueConstraint] = (
+ frozenset()
+ )
+ conn_indexes_reflected: Collection[ReflectedIndex] = frozenset()
+
# 1b. ... and from connection, if the table exists
try:
- conn_uniques = _InspectorConv(inspector).get_unique_constraints(
- tname, schema=schema
- )
+ conn_uniques_reflected = _InspectorConv(
+ inspector
+ ).get_unique_constraints(tname, schema=schema)
supports_unique_constraints = True
except NotImplementedError:
# not being present
pass
else:
- conn_uniques = [ # type:ignore[assignment]
+ conn_uniques_reflected = [
uq
- for uq in conn_uniques
+ for uq in conn_uniques_reflected
if autogen_context.run_name_filters(
uq["name"],
"unique_constraint",
{"table_name": tname, "schema_name": schema},
)
]
- for uq in conn_uniques:
+ for uq in conn_uniques_reflected:
if uq.get("duplicates_index"):
unique_constraints_duplicate_unique_indexes = True
try:
- conn_indexes = _InspectorConv(inspector).get_indexes(
+ conn_indexes_reflected = _InspectorConv(inspector).get_indexes(
tname, schema=schema
)
except NotImplementedError:
pass
else:
- conn_indexes = [ # type:ignore[assignment]
+ conn_indexes_reflected = [
ix
- for ix in conn_indexes
+ for ix in conn_indexes_reflected
if autogen_context.run_name_filters(
ix["name"],
"index",
# into schema objects
if is_drop_table:
# for DROP TABLE uniques are inline, don't need them
- conn_uniques = set() # type:ignore[assignment]
+ conn_uniques = set()
else:
- conn_uniques = { # type:ignore[assignment]
+ conn_uniques = {
_make_unique_constraint(impl, uq_def, conn_table)
- for uq_def in conn_uniques
+ for uq_def in conn_uniques_reflected
}
- conn_indexes = { # type:ignore[assignment]
+ conn_indexes = {
index
for index in (
- _make_index(impl, ix, conn_table) for ix in conn_indexes
+ _make_index(impl, ix, conn_table)
+ for ix in conn_indexes_reflected
)
if index is not None
}
def _make_index(
- impl: DefaultImpl, params: Dict[str, Any], conn_table: Table
+ impl: DefaultImpl, params: ReflectedIndex, conn_table: Table
) -> Optional[Index]:
exprs: list[Union[Column[Any], TextClause]] = []
sorting = params.get("column_sorting")
def _make_unique_constraint(
- impl: DefaultImpl, params: Dict[str, Any], conn_table: Table
+ impl: DefaultImpl, params: ReflectedUniqueConstraint, conn_table: Table
) -> UniqueConstraint:
uq = sa_schema.UniqueConstraint(
*[conn_table.c[cname] for cname in params["column_names"]],
def _make_foreign_key(
- params: Dict[str, Any], conn_table: Table
+ params: ReflectedForeignKeyConstraint, conn_table: Table
) -> ForeignKeyConstraint:
tname = params["referred_table"]
if params["referred_schema"]:
version_table = autogen_context.migration_context.version_table
for schema_name in schemas:
- tables = set(inspector.get_table_names(schema=schema_name))
+ tables = available = set(inspector.get_table_names(schema=schema_name))
if schema_name == version_table_schema:
tables = tables.difference(
[autogen_context.migration_context.version_table]
)
- conn_table_names.update(
- (schema_name, tname)
+ tablenames = [
+ tname
for tname in tables
if autogen_context.run_name_filters(
tname, "table", {"schema_name": schema_name}
)
- )
+ ]
+
+ conn_table_names.update((schema_name, tname) for tname in tablenames)
+
+ inspector = autogen_context.inspector
+ insp = _InspectorConv(inspector)
+ insp.pre_cache_tables(schema_name, tablenames, available)
metadata_table_names = OrderedSet(
[(table.schema, table.name) for table in autogen_context.sorted_tables]
removal_metadata = sa_schema.MetaData()
for s, tname in conn_table_names.difference(metadata_table_names):
name = sa_schema._get_table_key(tname, s)
+
+ # a name might be present already if a previous reflection pulled
+ # this table in via foreign key constraint
exists = name in removal_metadata.tables
t = sa_schema.Table(tname, removal_metadata, schema=s)
(inspector),
# fmt: on
)
- _InspectorConv(inspector).reflect_table(t, include_columns=None)
+ _InspectorConv(inspector).reflect_table(t)
if autogen_context.run_object_filters(t, tname, "table", True, None):
modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
for s, tname in existing_tables:
name = sa_schema._get_table_key(tname, s)
exists = name in existing_metadata.tables
+
+ # a name might be present already if a previous reflection pulled
+ # this table in via foreign key constraint
t = sa_schema.Table(tname, existing_metadata, schema=s)
if not exists:
event.listen(
_compat_autogen_column_reflect(inspector),
# fmt: on
)
- _InspectorConv(inspector).reflect_table(t, include_columns=None)
+ _InspectorConv(inspector).reflect_table(t)
conn_column_info[(s, tname)] = t
def setup(plugin: Plugin) -> None:
+
plugin.add_autogenerate_comparator(
_autogen_for_tables,
"schema",
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
+from __future__ import annotations
+
+from typing import Any
+from typing import cast
+from typing import Collection
+from typing import TYPE_CHECKING
from sqlalchemy.sql.elements import conv
+from typing_extensions import Self
+
+from ...util import sqla_compat
+
+if TYPE_CHECKING:
+ from sqlalchemy import Table
+ from sqlalchemy.engine import Inspector
+ from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
+ from sqlalchemy.engine.interfaces import ReflectedIndex
+ from sqlalchemy.engine.interfaces import ReflectedUniqueConstraint
+ from sqlalchemy.engine.reflection import _ReflectionInfo
+
+_INSP_KEYS = (
+ "columns",
+ "pk_constraint",
+ "foreign_keys",
+ "indexes",
+ "unique_constraints",
+ "table_comment",
+ "check_constraints",
+ "table_options",
+)
+_CONSTRAINT_INSP_KEYS = (
+ "pk_constraint",
+ "foreign_keys",
+ "indexes",
+ "unique_constraints",
+ "check_constraints",
+)
class _InspectorConv:
+ """Version-dispatching facade over a SQLAlchemy ``Inspector``.
+
+ ``_InspectorConv(inspector)`` actually constructs a
+ ``_SQLA2InspectorConv`` (SQLAlchemy 2.x) or ``_LegacyInspectorConv``
+ instance -- see ``__new__``. Subclasses wrap reflected constraint
+ names in ``conv()`` so they compare correctly against
+ naming-convention metadata.
+ """
+
__slots__ = ("inspector",)
- def __init__(self, inspector):
+ def __new__(cls, inspector: Inspector) -> Self:
+ # choose the implementation subclass based on the running
+ # SQLAlchemy version; __init__ is invoked explicitly because the
+ # returned object is not an instance of ``cls`` itself
+ obj: Any
+ if sqla_compat.sqla_2:
+ obj = object.__new__(_SQLA2InspectorConv)
+ _SQLA2InspectorConv.__init__(obj, inspector)
+ else:
+ obj = object.__new__(_LegacyInspectorConv)
+ _LegacyInspectorConv.__init__(obj, inspector)
+ return cast(Self, obj)
+
+ def __init__(self, inspector: Inspector):
self.inspector = inspector
+ def pre_cache_tables(
+ self,
+ schema: str | None,
+ tablenames: list[str],
+ all_available_tablenames: Collection[str],
+ ) -> None:
+ # no-op by default; overridden by _SQLA2InspectorConv to warm the
+ # inspector cache using the "multi" reflection methods
+ pass
+
+ def get_unique_constraints(
+ self, tname: str, schema: str | None
+ ) -> list[ReflectedUniqueConstraint]:
+ raise NotImplementedError()
+
+ def get_indexes(
+ self, tname: str, schema: str | None
+ ) -> list[ReflectedIndex]:
+ raise NotImplementedError()
+
+ def get_foreign_keys(
+ self, tname: str, schema: str | None
+ ) -> list[ReflectedForeignKeyConstraint]:
+ raise NotImplementedError()
+
+ def reflect_table(self, table: Table) -> None:
+ raise NotImplementedError()
+
+
+class _LegacyInspectorConv(_InspectorConv):
+
def _apply_reflectinfo_conv(self, consts):
if not consts:
return consts
const.name = conv(const.name)
return consts
- def get_indexes(self, *args, **kw):
+ def get_indexes(
+ self, tname: str, schema: str | None
+ ) -> list[ReflectedIndex]:
+ # delegate to the Inspector, applying conv() to reflected names
return self._apply_reflectinfo_conv(
- self.inspector.get_indexes(*args, **kw)
+ self.inspector.get_indexes(tname, schema=schema)
)
- def get_unique_constraints(self, *args, **kw):
+ def get_unique_constraints(
+ self, tname: str, schema: str | None
+ ) -> list[ReflectedUniqueConstraint]:
+ # delegate to the Inspector, applying conv() to reflected names
return self._apply_reflectinfo_conv(
- self.inspector.get_unique_constraints(*args, **kw)
+ self.inspector.get_unique_constraints(tname, schema=schema)
)
- def get_foreign_keys(self, *args, **kw):
+ def get_foreign_keys(
+ self, tname: str, schema: str | None
+ ) -> list[ReflectedForeignKeyConstraint]:
+ # delegate to the Inspector, applying conv() to reflected names
return self._apply_reflectinfo_conv(
- self.inspector.get_foreign_keys(*args, **kw)
+ self.inspector.get_foreign_keys(tname, schema=schema)
)
- def reflect_table(self, table, *, include_columns):
- self.inspector.reflect_table(table, include_columns=include_columns)
+ def reflect_table(self, table: Table) -> None:
+ # include_columns is no longer a parameter; all columns are
+ # always reflected
+ self.inspector.reflect_table(table, include_columns=None)
- # I had a cool version of this using _ReflectInfo, however that doesn't
- # work in 1.4 and it's not public API in 2.x. Then this is just a two
- # liner. So there's no competition...
+ # apply conv() to constraint/index names after reflection.
+ # NOTE(review): _apply_constraint_conv is defined elsewhere in this
+ # module (unchanged by this patch) -- confirm it is still present.
self._apply_constraint_conv(table.constraints)
self._apply_constraint_conv(table.indexes)
+
+
+class _SQLA2InspectorConv(_InspectorConv):
+ """Inspector facade for SQLAlchemy 2.x.
+
+ Uses the 2.0 ``get_multi_*`` Inspector methods to bulk-reflect whole
+ collections of tables up front (``pre_cache_tables``), storing the
+ results in ``inspector.info_cache``. Per-table lookups are then
+ served from that cache, falling back to individual Inspector calls
+ on a cache miss.
+ """
+
+ def _pre_cache(
+ self,
+ schema: str | None,
+ tablenames: list[str],
+ all_available_tablenames: Collection[str],
+ info_key: str,
+ inspector_method: Any,
+ ) -> None:
+ # bulk-reflect one category (indexes, fks, ...) for a schema and
+ # store the result in inspector.info_cache under info_key.
+ # NOTE(review): this guard is keyed on info_key alone, not
+ # (info_key, schema) -- with multiple schemas, only the first
+ # schema gets pre-cached and later schemas fall back to per-table
+ # calls in _return_from_cache. Correct, but confirm it's intended.
+ if info_key in self.inspector.info_cache:
+ return
+
+ # heuristic vendored from SQLAlchemy 2.0:
+ # if more than 50% of the tables in the database are in
+ # filter_names, load all the tables, since that is most likely
+ # faster than filtering on that many names. Also, if the dialect
+ # doesn't override the "multi" method, just pass the filter names
+ # through.
+ if tablenames and all_available_tablenames and len(tablenames) > 100:
+ fraction = len(tablenames) / len(all_available_tablenames)
+ else:
+ fraction = None
+
+ if (
+ fraction is None
+ or fraction <= 0.5
+ or not self.inspector.dialect._overrides_default(
+ inspector_method.__name__
+ )
+ ):
+ optimized_filter_names = tablenames
+ else:
+ optimized_filter_names = None
+
+ try:
+ elements = inspector_method(
+ schema=schema, filter_names=optimized_filter_names
+ )
+ except NotImplementedError:
+ # the exception class itself is stored as a sentinel so that
+ # lookups can re-raise (or skip) later
+ self.inspector.info_cache[info_key] = NotImplementedError
+ else:
+ self.inspector.info_cache[info_key] = elements
+
+ def _return_from_cache(
+ self,
+ tname: str,
+ schema: str | None,
+ info_key: str,
+ inspector_method: Any,
+ apply_constraint_conv: bool = False,
+ optional=True,
+ ) -> Any:
+ # return one table's reflection data, preferring the pre-cache and
+ # falling back to a per-table Inspector call on a cache miss.
+ # optional=True swallows NotImplementedError and returns an empty
+ # mapping instead.
+ not_in_cache = object()
+
+ if info_key in self.inspector.info_cache:
+ cache = self.inspector.info_cache[info_key]
+ if cache is NotImplementedError:
+ if optional:
+ # NOTE(review): {} is returned even for list-returning
+ # methods; callers appear to only iterate -- confirm
+ return {}
+ else:
+ # maintain NotImplementedError as alembic compare
+ # uses these to determine classes of construct that it
+ # should not compare to DB elements
+ raise NotImplementedError()
+
+ individual = cache.get((schema, tname), not_in_cache)
+
+ if individual is not not_in_cache:
+ if apply_constraint_conv and individual is not None:
+ return self._apply_reflectinfo_conv(individual)
+ else:
+ return individual
+
+ try:
+ data = inspector_method(tname, schema=schema)
+ except NotImplementedError:
+ if optional:
+ return {}
+ else:
+ raise
+
+ if apply_constraint_conv:
+ return self._apply_reflectinfo_conv(data)
+ else:
+ return data
+
+ def get_unique_constraints(
+ self, tname: str, schema: str | None
+ ) -> list[ReflectedUniqueConstraint]:
+ # optional=False: alembic compare relies on NotImplementedError
+ # to know the dialect cannot report unique constraints
+ return self._return_from_cache(
+ tname,
+ schema,
+ "alembic_unique_constraints",
+ self.inspector.get_unique_constraints,
+ apply_constraint_conv=True,
+ optional=False,
+ )
+
+ def get_indexes(
+ self, tname: str, schema: str | None
+ ) -> list[ReflectedIndex]:
+ return self._return_from_cache(
+ tname,
+ schema,
+ "alembic_indexes",
+ self.inspector.get_indexes,
+ apply_constraint_conv=True,
+ optional=False,
+ )
+
+ def get_foreign_keys(
+ self, tname: str, schema: str | None
+ ) -> list[ReflectedForeignKeyConstraint]:
+ return self._return_from_cache(
+ tname,
+ schema,
+ "alembic_foreign_keys",
+ self.inspector.get_foreign_keys,
+ apply_constraint_conv=True,
+ )
+
+ def _apply_reflectinfo_conv(self, consts):
+ # wrap reflected constraint names in conv() so comparison against
+ # naming-convention metadata treats them as already converted.
+ # a single dict (e.g. pk_constraint) is handled as well as a list
+ if not consts:
+ return consts
+ for const in consts if not isinstance(consts, dict) else [consts]:
+ if const["name"] is not None and not isinstance(
+ const["name"], conv
+ ):
+ const["name"] = conv(const["name"])
+ return consts
+
+ def pre_cache_tables(
+ self,
+ schema: str | None,
+ tablenames: list[str],
+ all_available_tablenames: Collection[str],
+ ) -> None:
+ # warm the info_cache for every reflection category via the
+ # corresponding get_multi_* Inspector method
+ for key in _INSP_KEYS:
+ keyname = f"alembic_{key}"
+ meth = getattr(self.inspector, f"get_multi_{key}")
+
+ self._pre_cache(
+ schema,
+ tablenames,
+ all_available_tablenames,
+ keyname,
+ meth,
+ )
+
+ def _make_reflection_info(
+ self, tname: str, schema: str | None
+ ) -> _ReflectionInfo:
+ # assemble a _ReflectionInfo for one table from cached data.
+ # NOTE(review): _ReflectionInfo is private SQLAlchemy API -- this
+ # couples to internals and may need updating across SQLA releases
+ from sqlalchemy.engine.reflection import _ReflectionInfo
+
+ table_key = (schema, tname)
+
+ return _ReflectionInfo(
+ unreflectable={},
+ **{
+ key: {
+ table_key: self._return_from_cache(
+ tname,
+ schema,
+ f"alembic_{key}",
+ getattr(self.inspector, f"get_{key}"),
+ apply_constraint_conv=(key in _CONSTRAINT_INSP_KEYS),
+ )
+ }
+ for key in _INSP_KEYS
+ },
+ )
+
+ def reflect_table(self, table: Table) -> None:
+ # reflect from the pre-assembled _ReflectionInfo so no fresh
+ # queries are emitted; resolve_fks=False avoids pulling in
+ # referenced tables
+ ri = self._make_reflection_info(table.name, table.schema)
+
+ self.inspector.reflect_table(
+ table,
+ include_columns=None,
+ resolve_fks=False,
+ _reflect_info=ri,
+ )
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine.cursor import CursorResult
+ from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
+ from sqlalchemy.engine.interfaces import ReflectedIndex
+ from sqlalchemy.engine.interfaces import ReflectedPrimaryKeyConstraint
+ from sqlalchemy.engine.interfaces import ReflectedUniqueConstraint
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql import Executable
from ..operations.batch import ApplyBatchImpl
from ..operations.batch import BatchOperationsImpl
+ _ReflectedConstraint = (
+ ReflectedForeignKeyConstraint
+ | ReflectedPrimaryKeyConstraint
+ | ReflectedIndex
+ | ReflectedUniqueConstraint
+ )
log = logging.getLogger(__name__)
metadata_indexes.discard(idx)
def adjust_reflected_dialect_options(
- self, reflected_object: Dict[str, Any], kind: str
+ self, reflected_object: _ReflectedConstraint, kind: str
) -> Dict[str, Any]:
+ # base implementation: return dialect options unmodified; dialect
+ # impls override this to normalize (e.g. drop empty include lists)
- return reflected_object.get("dialect_options", {})
+ return reflected_object.get("dialect_options", {}) # type: ignore[return-value] # noqa: E501
class Params(NamedTuple):
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
+ from .impl import _ReflectedConstraint
class MSSQLImpl(DefaultImpl):
return diff, ignored, is_alter
def adjust_reflected_dialect_options(
- self, reflected_object: Dict[str, Any], kind: str
+ self, reflected_object: _ReflectedConstraint, kind: str
) -> Dict[str, Any]:
options: Dict[str, Any]
- options = reflected_object.get("dialect_options", {}).copy()
+ options = reflected_object.get("dialect_options", {}).copy() # type: ignore[attr-defined] # noqa: E501
if not options.get("mssql_include"):
options.pop("mssql_include", None)
if not options.get("mssql_clustered"):
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
+ from .impl import _ReflectedConstraint
from ..autogenerate.api import AutogenContext
from ..autogenerate.render import _f_name
from ..runtime.migration import MigrationContext
-
log = logging.getLogger(__name__)
return ComparisonResult.Equal()
def adjust_reflected_dialect_options(
- self, reflected_options: Dict[str, Any], kind: str
+ self, reflected_object: _ReflectedConstraint, kind: str
) -> Dict[str, Any]:
+ # NOTE(review): parameter renamed reflected_options ->
+ # reflected_object, presumably to match the base-class signature
+ # -- confirm no keyword callers exist.
+ # drop an empty postgresql_include entry, presumably so it does
+ # not surface as a spurious difference -- confirm
options: Dict[str, Any]
- options = reflected_options.get("dialect_options", {}).copy()
+ options = reflected_object.get("dialect_options", {}).copy() # type: ignore[attr-defined] # noqa: E501
if not options.get("postgresql_include"):
options.pop("postgresql_include", None)
return options
--- /dev/null
+.. change::
+ :tags: feature, autogenerate
+ :tickets: 1771
+
+ Autogenerate reflection sweeps now use the "bulk" inspector methods
+ introduced in SQLAlchemy 2.0. For dialects that implement them, including
+ PostgreSQL and Oracle, whole collections of tables are reflected with a
+ fixed number of batched queries rather than one set of queries per table.
["id"],
name=expected_name,
)
+
+
+class AutogenInspectorCacheTest(AutogenFixtureTest, TestBase):
+ """Test for the new inspector caching added for #1771.
+
+ Counts the SQL statements emitted during a full autogenerate sweep,
+ with and without pre-caching, and compares against per-dialect
+ expected counts.
+ """
+
+ __only_on__ = ("sqlite", "postgresql", "oracle")
+ __requires__ = ("sqlalchemy_2",)
+
+ @testing.fixture
+ def instrument_inspector_conv(self, connection):
+ # record every SQL statement the connection emits so the test can
+ # count round trips
+ from sqlalchemy import event
+
+ shared_mock = mock.MagicMock()
+
+ def track_before_cursor_execute(
+ conn, cursor, statement, parameters, context, executemany
+ ):
+ shared_mock.sql_call(statement)
+
+ event.listen(
+ connection,
+ "before_cursor_execute",
+ track_before_cursor_execute,
+ )
+
+ yield shared_mock
+
+ event.remove(
+ connection,
+ "before_cursor_execute",
+ track_before_cursor_execute,
+ )
+
+ @testing.fixture
+ def models(self, metadata, connection):
+ # m1 is created in the database; m2 (empty) is the target, so
+ # autogenerate must reflect every table in m1
+ m1 = metadata
+ m2 = MetaData()
+ from sqlalchemy import Identity
+
+ Table(
+ "parent",
+ m1,
+ Column("id", Integer, Identity(), primary_key=True),
+ Column("name", String(50)),
+ Column("x", String(50)),
+ Column("y", String(50)),
+ )
+
+ Table(
+ "child",
+ m1,
+ Column("id", Integer, Identity(), primary_key=True),
+ Column("name", String(50)),
+ Column("parent_id", Integer),
+ (ForeignKeyConstraint(["parent_id"], ["parent.id"])),
+ )
+
+ Table(
+ "t1",
+ m1,
+ Column("id", Integer, Identity(), primary_key=True),
+ Column("name", String(50)),
+ )
+
+ Table(
+ "t2",
+ m1,
+ Column("id", Integer, Identity(), primary_key=True),
+ Column("name", String(50)),
+ )
+
+ Table(
+ "t3",
+ m1,
+ Column("id", Integer, Identity(), primary_key=True),
+ Column("name", String(50)),
+ )
+
+ m1.create_all(connection)
+
+ ctx_opts = {
+ "compare_type": True,
+ "compare_server_default": True,
+ "target_metadata": m2,
+ "upgrade_token": "upgrades",
+ "downgrade_token": "downgrades",
+ "alembic_module_prefix": "op.",
+ "sqlalchemy_module_prefix": "sa.",
+ }
+ context = MigrationContext.configure(
+ connection=connection, opts=ctx_opts
+ )
+
+ autogen_context = api.AutogenContext(context, m2)
+ uo = ops.UpgradeOps(ops=[])
+
+ yield autogen_context, uo
+
+ @testing.fixture(params=[True, False])
+ def disable_pre_cache(self, request):
+ # param True patches pre_cache_tables into a no-op so the
+ # non-cached SQL count can be measured
+ from alembic.autogenerate.compare.util import _SQLA2InspectorConv
+
+ patcher = mock.patch.object(_SQLA2InspectorConv, "pre_cache_tables")
+
+ param = request.param
+
+ if param:
+ patcher.start()
+
+ yield param
+
+ if param:
+ patcher.stop()
+
+ class Approx:
+ # compares equal within a tolerance of 5, to absorb small
+ # statement-count differences between backend versions
+ def __init__(self, value):
+ self.value = value
+
+ def __eq__(self, other):
+ return abs(other - self.value) < 5
+
+ def __repr__(self):
+ return f"Approximately({self.value})"
+
+ expected = {
+ # no savings for SQLite, which does a query per table no matter
+ # what, but we can at least see that pre-caching issues all the
+ # SQL up front
+ "sqlite": {
+ "pre_cache": Approx(60),
+ "no_pre_cache": Approx(95),
+ },
+ # for PG and Oracle which use multi-queries, big savings
+ "postgresql": {"pre_cache": Approx(12), "no_pre_cache": Approx(57)},
+ "oracle": {"pre_cache": Approx(12), "no_pre_cache": Approx(32)},
+ }
+
+ def test_run_compare(
+ self, connection, models, instrument_inspector_conv, disable_pre_cache
+ ):
+ autogen_context, uo = models
+
+ autogenerate._produce_net_changes(autogen_context, uo)
+
+ sql_calls = len(instrument_inspector_conv.mock_calls)
+
+ eq_(
+ sql_calls,
+ self.expected[connection.dialect.name][
+ "pre_cache" if not disable_pre_cache else "no_pre_cache"
+ ],
+ )