new pre-caching architecture for autogenerate
author    Mike Bayer <mike_mp@zzzcomputing.com>
          Mon, 22 Dec 2025 15:40:06 +0000 (10:40 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Thu, 8 Jan 2026 22:50:49 +0000 (17:50 -0500)
Autogenerate reflection sweeps now use the "bulk" inspector methods
introduced in SQLAlchemy 2.0, which for selected dialects including
PostgreSQL and Oracle use batched queries to reflect whole collections of
tables using O(1) queries rather than O(N).
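
For reference, the 2.0 "bulk" API this builds on looks roughly like the
following (a standalone sketch, not code from this commit; the in-memory
SQLite database and table names are just for illustration):

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy import create_engine, inspect

    engine = create_engine("sqlite://")
    metadata = MetaData()
    Table("a", metadata, Column("id", Integer, primary_key=True))
    Table("b", metadata, Column("id", Integer, primary_key=True))
    metadata.create_all(engine)

    inspector = inspect(engine)

    # one batched call per collection instead of one query per table;
    # results are keyed by (schema, table_name)
    indexes = inspector.get_multi_indexes(schema=None, filter_names=["a", "b"])
    for (schema, tname), idx_list in indexes.items():
        print(tname, idx_list)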

This is the original proposed version, which uses the Inspector
entirely through its public API, with the exception of reflect_table(),
which builds a _ReflectionInfo on a per-table basis.  Beyond that,
no private API assumptions are made.

If SQLAlchemy needs to add new fields to _ReflectionInfo, it just
has to make sure they have defaults (which all the current fields
should have anyway, since there is even a ReflectionDefaults
constant that already provides these!)
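
For context, the forward-compatibility argument above is just ordinary
dataclass defaulting; a hypothetical sketch, not SQLAlchemy's actual
definition:

    from dataclasses import dataclass, field

    @dataclass
    class HypotheticalReflectionInfo:
        # existing fields, each carrying a default
        columns: dict = field(default_factory=dict)
        indexes: dict = field(default_factory=dict)
        # a field added in a later release stays compatible with callers
        # that pass only the fields they know about, since it defaults too
        new_collection: dict = field(default_factory=dict)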

This version implies no particular changes in SQLAlchemy and adds
no new sqla_compat logic, so that Alembic can ship the new
performance enhancements now while leaving SQLAlchemy free to
improve its API in a later release.

Additionally, typing of the reflection functions is improved.

Fixes: #1771
Change-Id: I7b9a75fa81cefc377fdb1a22fc1cfc3da1765769

alembic/autogenerate/compare/constraints.py
alembic/autogenerate/compare/tables.py
alembic/autogenerate/compare/util.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/ddl/postgresql.py
docs/build/unreleased/1771.rst [new file with mode: 0644]
tests/test_autogen_diffs.py

index 0b524b975ee994efa9535c1091ca31fdf574ff84..90934adfd19459c064375fbfdd6e391e71bd5edd 100644 (file)
@@ -5,6 +5,7 @@ from __future__ import annotations
 import logging
 from typing import Any
 from typing import cast
+from typing import Collection
 from typing import Dict
 from typing import Mapping
 from typing import Optional
@@ -28,6 +29,9 @@ from ...util import PriorityDispatchResult
 from ...util import sqla_compat
 
 if TYPE_CHECKING:
+    from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
+    from sqlalchemy.engine.interfaces import ReflectedIndex
+    from sqlalchemy.engine.interfaces import ReflectedUniqueConstraint
     from sqlalchemy.sql.elements import quoted_name
     from sqlalchemy.sql.elements import TextClause
     from sqlalchemy.sql.schema import Column
@@ -40,7 +44,6 @@ if TYPE_CHECKING:
     from ...operations.ops import ModifyTableOps
     from ...runtime.plugins import Plugin
 
-
 _C = TypeVar("_C", bound=Union[UniqueConstraint, ForeignKeyConstraint, Index])
 
 
@@ -72,18 +75,24 @@ def _compare_indexes_and_uniques(
         metadata_unique_constraints = set()
         metadata_indexes = set()
 
-    conn_uniques = conn_indexes = frozenset()  # type:ignore[var-annotated]
+    conn_uniques: Collection[UniqueConstraint] = frozenset()
+    conn_indexes: Collection[Index] = frozenset()
 
     supports_unique_constraints = False
 
     unique_constraints_duplicate_unique_indexes = False
 
     if conn_table is not None:
+        conn_uniques_reflected: Collection[ReflectedUniqueConstraint] = (
+            frozenset()
+        )
+        conn_indexes_reflected: Collection[ReflectedIndex] = frozenset()
+
         # 1b. ... and from connection, if the table exists
         try:
-            conn_uniques = _InspectorConv(inspector).get_unique_constraints(
-                tname, schema=schema
-            )
+            conn_uniques_reflected = _InspectorConv(
+                inspector
+            ).get_unique_constraints(tname, schema=schema)
 
             supports_unique_constraints = True
         except NotImplementedError:
@@ -94,28 +103,28 @@ def _compare_indexes_and_uniques(
             # not being present
             pass
         else:
-            conn_uniques = [  # type:ignore[assignment]
+            conn_uniques_reflected = [
                 uq
-                for uq in conn_uniques
+                for uq in conn_uniques_reflected
                 if autogen_context.run_name_filters(
                     uq["name"],
                     "unique_constraint",
                     {"table_name": tname, "schema_name": schema},
                 )
             ]
-            for uq in conn_uniques:
+            for uq in conn_uniques_reflected:
                 if uq.get("duplicates_index"):
                     unique_constraints_duplicate_unique_indexes = True
         try:
-            conn_indexes = _InspectorConv(inspector).get_indexes(
+            conn_indexes_reflected = _InspectorConv(inspector).get_indexes(
                 tname, schema=schema
             )
         except NotImplementedError:
             pass
         else:
-            conn_indexes = [  # type:ignore[assignment]
+            conn_indexes_reflected = [
                 ix
-                for ix in conn_indexes
+                for ix in conn_indexes_reflected
                 if autogen_context.run_name_filters(
                     ix["name"],
                     "index",
@@ -127,17 +136,18 @@ def _compare_indexes_and_uniques(
         # into schema objects
         if is_drop_table:
             # for DROP TABLE uniques are inline, don't need them
-            conn_uniques = set()  # type:ignore[assignment]
+            conn_uniques = set()
         else:
-            conn_uniques = {  # type:ignore[assignment]
+            conn_uniques = {
                 _make_unique_constraint(impl, uq_def, conn_table)
-                for uq_def in conn_uniques
+                for uq_def in conn_uniques_reflected
             }
 
-        conn_indexes = {  # type:ignore[assignment]
+        conn_indexes = {
             index
             for index in (
-                _make_index(impl, ix, conn_table) for ix in conn_indexes
+                _make_index(impl, ix, conn_table)
+                for ix in conn_indexes_reflected
             )
             if index is not None
         }
@@ -507,7 +517,7 @@ _IndexColumnSortingOps: Mapping[str, Any] = util.immutabledict(
 
 
 def _make_index(
-    impl: DefaultImpl, params: Dict[str, Any], conn_table: Table
+    impl: DefaultImpl, params: ReflectedIndex, conn_table: Table
 ) -> Optional[Index]:
     exprs: list[Union[Column[Any], TextClause]] = []
     sorting = params.get("column_sorting")
@@ -539,7 +549,7 @@ def _make_index(
 
 
 def _make_unique_constraint(
-    impl: DefaultImpl, params: Dict[str, Any], conn_table: Table
+    impl: DefaultImpl, params: ReflectedUniqueConstraint, conn_table: Table
 ) -> UniqueConstraint:
     uq = sa_schema.UniqueConstraint(
         *[conn_table.c[cname] for cname in params["column_names"]],
@@ -553,7 +563,7 @@ def _make_unique_constraint(
 
 
 def _make_foreign_key(
-    params: Dict[str, Any], conn_table: Table
+    params: ReflectedForeignKeyConstraint, conn_table: Table
 ) -> ForeignKeyConstraint:
     tname = params["referred_table"]
     if params["referred_schema"]:
index 0847ff5e12c84dd00f517fab24bf707e8f45c5ac..31eddc6b5993cfab09335453d98dfef45a75b5a5 100644 (file)
@@ -48,19 +48,25 @@ def _autogen_for_tables(
     version_table = autogen_context.migration_context.version_table
 
     for schema_name in schemas:
-        tables = set(inspector.get_table_names(schema=schema_name))
+        tables = available = set(inspector.get_table_names(schema=schema_name))
         if schema_name == version_table_schema:
             tables = tables.difference(
                 [autogen_context.migration_context.version_table]
             )
 
-        conn_table_names.update(
-            (schema_name, tname)
+        tablenames = [
+            tname
             for tname in tables
             if autogen_context.run_name_filters(
                 tname, "table", {"schema_name": schema_name}
             )
-        )
+        ]
+
+        conn_table_names.update((schema_name, tname) for tname in tablenames)
+
+        inspector = autogen_context.inspector
+        insp = _InspectorConv(inspector)
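+        # batch-reflect every collection for these tables up front, so that
+        # later per-table inspector calls are served from the cache (#1771)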
+        insp.pre_cache_tables(schema_name, tablenames, available)
 
     metadata_table_names = OrderedSet(
         [(table.schema, table.name) for table in autogen_context.sorted_tables]
@@ -139,6 +145,9 @@ def _compare_tables(
     removal_metadata = sa_schema.MetaData()
     for s, tname in conn_table_names.difference(metadata_table_names):
         name = sa_schema._get_table_key(tname, s)
+
+        # a name might be present already if a previous reflection pulled
+        # this table in via foreign key constraint
         exists = name in removal_metadata.tables
         t = sa_schema.Table(tname, removal_metadata, schema=s)
 
@@ -152,7 +161,7 @@ def _compare_tables(
                 (inspector),
                 # fmt: on
             )
-            _InspectorConv(inspector).reflect_table(t, include_columns=None)
+            _InspectorConv(inspector).reflect_table(t)
         if autogen_context.run_object_filters(t, tname, "table", True, None):
             modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
 
@@ -172,6 +181,9 @@ def _compare_tables(
     for s, tname in existing_tables:
         name = sa_schema._get_table_key(tname, s)
         exists = name in existing_metadata.tables
+
+        # a name might be present already if a previous reflection pulled
+        # this table in via foreign key constraint
         t = sa_schema.Table(tname, existing_metadata, schema=s)
         if not exists:
             event.listen(
@@ -182,7 +194,7 @@ def _compare_tables(
                 _compat_autogen_column_reflect(inspector),
                 # fmt: on
             )
-            _InspectorConv(inspector).reflect_table(t, include_columns=None)
+            _InspectorConv(inspector).reflect_table(t)
 
         conn_column_info[(s, tname)] = t
 
@@ -296,6 +308,7 @@ def _compare_columns(
 
 
 def setup(plugin: Plugin) -> None:
+
     plugin.add_autogenerate_comparator(
         _autogen_for_tables,
         "schema",
index 199d8280e00d1d07af397eba396bf31ad0806712..41829c0e0b846b5d88f68c2db01199d560858c94 100644 (file)
@@ -1,15 +1,89 @@
 # mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
 # mypy: no-warn-return-any, allow-any-generics
+from __future__ import annotations
+
+from typing import Any
+from typing import cast
+from typing import Collection
+from typing import TYPE_CHECKING
 
 from sqlalchemy.sql.elements import conv
+from typing_extensions import Self
+
+from ...util import sqla_compat
+
+if TYPE_CHECKING:
+    from sqlalchemy import Table
+    from sqlalchemy.engine import Inspector
+    from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
+    from sqlalchemy.engine.interfaces import ReflectedIndex
+    from sqlalchemy.engine.interfaces import ReflectedUniqueConstraint
+    from sqlalchemy.engine.reflection import _ReflectionInfo
+
+_INSP_KEYS = (
+    "columns",
+    "pk_constraint",
+    "foreign_keys",
+    "indexes",
+    "unique_constraints",
+    "table_comment",
+    "check_constraints",
+    "table_options",
+)
+_CONSTRAINT_INSP_KEYS = (
+    "pk_constraint",
+    "foreign_keys",
+    "indexes",
+    "unique_constraints",
+    "check_constraints",
+)
 
 
 class _InspectorConv:
     __slots__ = ("inspector",)
 
-    def __init__(self, inspector):
+    def __new__(cls, inspector: Inspector) -> Self:
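+        # construction dispatches to the subclass matching the installed
+        # SQLAlchemy version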
+        obj: Any
+        if sqla_compat.sqla_2:
+            obj = object.__new__(_SQLA2InspectorConv)
+            _SQLA2InspectorConv.__init__(obj, inspector)
+        else:
+            obj = object.__new__(_LegacyInspectorConv)
+            _LegacyInspectorConv.__init__(obj, inspector)
+        return cast(Self, obj)
+
+    def __init__(self, inspector: Inspector):
         self.inspector = inspector
 
+    def pre_cache_tables(
+        self,
+        schema: str | None,
+        tablenames: list[str],
+        all_available_tablenames: Collection[str],
+    ) -> None:
+        pass
+
+    def get_unique_constraints(
+        self, tname: str, schema: str | None
+    ) -> list[ReflectedUniqueConstraint]:
+        raise NotImplementedError()
+
+    def get_indexes(
+        self, tname: str, schema: str | None
+    ) -> list[ReflectedIndex]:
+        raise NotImplementedError()
+
+    def get_foreign_keys(
+        self, tname: str, schema: str | None
+    ) -> list[ReflectedForeignKeyConstraint]:
+        raise NotImplementedError()
+
+    def reflect_table(self, table: Table) -> None:
+        raise NotImplementedError()
+
+
+class _LegacyInspectorConv(_InspectorConv):
+
     def _apply_reflectinfo_conv(self, consts):
         if not consts:
             return consts
@@ -28,26 +102,213 @@ class _InspectorConv:
                 const.name = conv(const.name)
         return consts
 
-    def get_indexes(self, *args, **kw):
+    def get_indexes(
+        self, tname: str, schema: str | None
+    ) -> list[ReflectedIndex]:
         return self._apply_reflectinfo_conv(
-            self.inspector.get_indexes(*args, **kw)
+            self.inspector.get_indexes(tname, schema=schema)
         )
 
-    def get_unique_constraints(self, *args, **kw):
+    def get_unique_constraints(
+        self, tname: str, schema: str | None
+    ) -> list[ReflectedUniqueConstraint]:
         return self._apply_reflectinfo_conv(
-            self.inspector.get_unique_constraints(*args, **kw)
+            self.inspector.get_unique_constraints(tname, schema=schema)
         )
 
-    def get_foreign_keys(self, *args, **kw):
+    def get_foreign_keys(
+        self, tname: str, schema: str | None
+    ) -> list[ReflectedForeignKeyConstraint]:
         return self._apply_reflectinfo_conv(
-            self.inspector.get_foreign_keys(*args, **kw)
+            self.inspector.get_foreign_keys(tname, schema=schema)
         )
 
-    def reflect_table(self, table, *, include_columns):
-        self.inspector.reflect_table(table, include_columns=include_columns)
+    def reflect_table(self, table: Table) -> None:
+        self.inspector.reflect_table(table, include_columns=None)
 
-        # I had a cool version of this using _ReflectInfo, however that doesn't
-        # work in 1.4 and it's not public API in 2.x.  Then this is just a two
-        # liner.  So there's no competition...
         self._apply_constraint_conv(table.constraints)
         self._apply_constraint_conv(table.indexes)
+
+
+class _SQLA2InspectorConv(_InspectorConv):
+
+    def _pre_cache(
+        self,
+        schema: str | None,
+        tablenames: list[str],
+        all_available_tablenames: Collection[str],
+        info_key: str,
+        inspector_method: Any,
+    ) -> None:
+
+        if info_key in self.inspector.info_cache:
+            return
+
+        # heuristic vendored from SQLAlchemy 2.0: once more than 100 table
+        # names are involved, if more than 50% of the tables in the db are
+        # in filter_names, load all the tables, since it's most likely
+        # faster than filtering on that many names.  also, if a dialect
+        # doesn't override the "multi" method, just pass the filter names
+        # through
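+        # e.g. reflecting 150 of 200 tables gives fraction 0.75, so with a
+        # dialect-level multi method available it's cheaper to load all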
+        if tablenames and all_available_tablenames and len(tablenames) > 100:
+            fraction = len(tablenames) / len(all_available_tablenames)
+        else:
+            fraction = None
+
+        if (
+            fraction is None
+            or fraction <= 0.5
+            or not self.inspector.dialect._overrides_default(
+                inspector_method.__name__
+            )
+        ):
+            optimized_filter_names = tablenames
+        else:
+            optimized_filter_names = None
+
+        try:
+            elements = inspector_method(
+                schema=schema, filter_names=optimized_filter_names
+            )
+        except NotImplementedError:
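+            # cache the miss itself; _return_from_cache turns it back into
+            # NotImplementedError for the non-optional getters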
+            self.inspector.info_cache[info_key] = NotImplementedError
+        else:
+            self.inspector.info_cache[info_key] = elements
+
+    def _return_from_cache(
+        self,
+        tname: str,
+        schema: str | None,
+        info_key: str,
+        inspector_method: Any,
+        apply_constraint_conv: bool = False,
+        optional: bool = True,
+    ) -> Any:
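+        # sentinel distinguishing "not in cache" from a cached empty value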
+        not_in_cache = object()
+
+        if info_key in self.inspector.info_cache:
+            cache = self.inspector.info_cache[info_key]
+            if cache is NotImplementedError:
+                if optional:
+                    return {}
+                else:
+                    # maintain NotImplementedError as alembic compare
+                    # uses these to determine classes of construct that it
+                    # should not compare to DB elements
+                    raise NotImplementedError()
+
+            individual = cache.get((schema, tname), not_in_cache)
+
+            if individual is not not_in_cache:
+                if apply_constraint_conv and individual is not None:
+                    return self._apply_reflectinfo_conv(individual)
+                else:
+                    return individual
+
+        try:
+            data = inspector_method(tname, schema=schema)
+        except NotImplementedError:
+            if optional:
+                return {}
+            else:
+                raise
+
+        if apply_constraint_conv:
+            return self._apply_reflectinfo_conv(data)
+        else:
+            return data
+
+    def get_unique_constraints(
+        self, tname: str, schema: str | None
+    ) -> list[ReflectedUniqueConstraint]:
+        return self._return_from_cache(
+            tname,
+            schema,
+            "alembic_unique_constraints",
+            self.inspector.get_unique_constraints,
+            apply_constraint_conv=True,
+            optional=False,
+        )
+
+    def get_indexes(
+        self, tname: str, schema: str | None
+    ) -> list[ReflectedIndex]:
+        return self._return_from_cache(
+            tname,
+            schema,
+            "alembic_indexes",
+            self.inspector.get_indexes,
+            apply_constraint_conv=True,
+            optional=False,
+        )
+
+    def get_foreign_keys(
+        self, tname: str, schema: str | None
+    ) -> list[ReflectedForeignKeyConstraint]:
+        return self._return_from_cache(
+            tname,
+            schema,
+            "alembic_foreign_keys",
+            self.inspector.get_foreign_keys,
+            apply_constraint_conv=True,
+        )
+
+    def _apply_reflectinfo_conv(self, consts):
+        if not consts:
+            return consts
+        for const in consts if not isinstance(consts, dict) else [consts]:
+            if const["name"] is not None and not isinstance(
+                const["name"], conv
+            ):
+                const["name"] = conv(const["name"])
+        return consts
+
+    def pre_cache_tables(
+        self,
+        schema: str | None,
+        tablenames: list[str],
+        all_available_tablenames: Collection[str],
+    ) -> None:
+        for key in _INSP_KEYS:
+            keyname = f"alembic_{key}"
+            meth = getattr(self.inspector, f"get_multi_{key}")
+
+            self._pre_cache(
+                schema,
+                tablenames,
+                all_available_tablenames,
+                keyname,
+                meth,
+            )
+
+    def _make_reflection_info(
+        self, tname: str, schema: str | None
+    ) -> _ReflectionInfo:
+        from sqlalchemy.engine.reflection import _ReflectionInfo
+
+        table_key = (schema, tname)
+
+        return _ReflectionInfo(
+            unreflectable={},
+            **{
+                key: {
+                    table_key: self._return_from_cache(
+                        tname,
+                        schema,
+                        f"alembic_{key}",
+                        getattr(self.inspector, f"get_{key}"),
+                        apply_constraint_conv=(key in _CONSTRAINT_INSP_KEYS),
+                    )
+                }
+                for key in _INSP_KEYS
+            },
+        )
+
+    def reflect_table(self, table: Table) -> None:
+        ri = self._make_reflection_info(table.name, table.schema)
+
+        self.inspector.reflect_table(
+            table,
+            include_columns=None,
+            resolve_fks=False,
+            _reflect_info=ri,
+        )
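
The intended call pattern for the new class, roughly (a sketch, not code
from this commit; inspector, schema, and the name lists come from the
autogenerate sweep shown above):

    insp = _InspectorConv(inspector)  # picks the SQLA-2 subclass when available

    # one up-front sweep batches every collection into inspector.info_cache
    insp.pre_cache_tables(schema, tablenames, all_available_tablenames)

    # later per-table lookups are served from that cache
    uqs = insp.get_unique_constraints("some_table", schema=schema)
    insp.reflect_table(table)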
index 4e3f29ae7d29b29d19e56b844f23c2e7a76503b6..00dd7d86e1b458aac9458b225061b500554a42e7 100644 (file)
@@ -43,6 +43,10 @@ if TYPE_CHECKING:
     from sqlalchemy.engine import Connection
     from sqlalchemy.engine import Dialect
     from sqlalchemy.engine.cursor import CursorResult
+    from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
+    from sqlalchemy.engine.interfaces import ReflectedIndex
+    from sqlalchemy.engine.interfaces import ReflectedPrimaryKeyConstraint
+    from sqlalchemy.engine.interfaces import ReflectedUniqueConstraint
     from sqlalchemy.engine.reflection import Inspector
     from sqlalchemy.sql import ClauseElement
     from sqlalchemy.sql import Executable
@@ -59,6 +63,12 @@ if TYPE_CHECKING:
     from ..operations.batch import ApplyBatchImpl
     from ..operations.batch import BatchOperationsImpl
 
+    _ReflectedConstraint = (
+        ReflectedForeignKeyConstraint
+        | ReflectedPrimaryKeyConstraint
+        | ReflectedIndex
+        | ReflectedUniqueConstraint
+    )
 log = logging.getLogger(__name__)
 
 
@@ -843,9 +853,9 @@ class DefaultImpl(metaclass=ImplMeta):
                 metadata_indexes.discard(idx)
 
     def adjust_reflected_dialect_options(
-        self, reflected_object: Dict[str, Any], kind: str
+        self, reflected_object: _ReflectedConstraint, kind: str
     ) -> Dict[str, Any]:
-        return reflected_object.get("dialect_options", {})
+        return reflected_object.get("dialect_options", {})  # type: ignore[return-value]   # noqa: E501
 
 
 class Params(NamedTuple):
index a236e41f2dffedbeb5200767345111466662bf51..22bd0e4b0b45d213f6cde77f40986bda0f3442de 100644 (file)
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.type_api import TypeEngine
 
     from .base import _ServerDefault
+    from .impl import _ReflectedConstraint
 
 
 class MSSQLImpl(DefaultImpl):
@@ -282,10 +283,10 @@ class MSSQLImpl(DefaultImpl):
         return diff, ignored, is_alter
 
     def adjust_reflected_dialect_options(
-        self, reflected_object: Dict[str, Any], kind: str
+        self, reflected_object: _ReflectedConstraint, kind: str
     ) -> Dict[str, Any]:
         options: Dict[str, Any]
-        options = reflected_object.get("dialect_options", {}).copy()
+        options = reflected_object.get("dialect_options", {}).copy()  # type: ignore[attr-defined]  # noqa: E501
         if not options.get("mssql_include"):
             options.pop("mssql_include", None)
         if not options.get("mssql_clustered"):
index 18f95e4aa80acc07de76e04081b065095a579993..d55664bb75e814dbea56659d97fd407728839e5e 100644 (file)
@@ -71,11 +71,11 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.type_api import TypeEngine
 
     from .base import _ServerDefault
+    from .impl import _ReflectedConstraint
     from ..autogenerate.api import AutogenContext
     from ..autogenerate.render import _f_name
     from ..runtime.migration import MigrationContext
 
-
 log = logging.getLogger(__name__)
 
 
@@ -421,10 +421,10 @@ class PostgresqlImpl(DefaultImpl):
         return ComparisonResult.Equal()
 
     def adjust_reflected_dialect_options(
-        self, reflected_options: Dict[str, Any], kind: str
+        self, reflected_object: _ReflectedConstraint, kind: str
     ) -> Dict[str, Any]:
         options: Dict[str, Any]
-        options = reflected_options.get("dialect_options", {}).copy()
+        options = reflected_object.get("dialect_options", {}).copy()  # type: ignore[attr-defined]  # noqa: E501
         if not options.get("postgresql_include"):
             options.pop("postgresql_include", None)
         return options
diff --git a/docs/build/unreleased/1771.rst b/docs/build/unreleased/1771.rst
new file mode 100644 (file)
index 0000000..829dac9
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: feature, autogenerate
+    :tickets: 1771
+
+    Autogenerate reflection sweeps now use the "bulk" inspector methods
+    introduced in SQLAlchemy 2.0, which for selected dialects including
+    PostgreSQL and Oracle use batched queries to reflect whole collections of
+    tables using O(1) queries rather than O(N).
index c7dd8a3cd93561485a8dc3c09194bf73de127451..5b66fb2308c56d3f0e3ca863e135387766fd1133 100644 (file)
@@ -2068,3 +2068,154 @@ class AutogenFKTest(AutogenFixtureTest, TestBase):
             ["id"],
             name=expected_name,
         )
+
+
+class AutogenInspectorCacheTest(AutogenFixtureTest, TestBase):
+    """test for the new inspector caching added for #1771."""
+
+    __only_on__ = ("sqlite", "postgresql", "oracle")
+    __requires__ = ("sqlalchemy_2",)
+
+    @testing.fixture
+    def instrument_inspector_conv(self, connection):
+        from sqlalchemy import event
+
+        shared_mock = mock.MagicMock()
+
+        def track_before_cursor_execute(
+            conn, cursor, statement, parameters, context, executemany
+        ):
+            shared_mock.sql_call(statement)
+
+        event.listen(
+            connection,
+            "before_cursor_execute",
+            track_before_cursor_execute,
+        )
+
+        yield shared_mock
+
+        event.remove(
+            connection,
+            "before_cursor_execute",
+            track_before_cursor_execute,
+        )
+
+    @testing.fixture
+    def models(self, metadata, connection):
+        m1 = metadata
+        m2 = MetaData()
+        from sqlalchemy import Identity
+
+        Table(
+            "parent",
+            m1,
+            Column("id", Integer, Identity(), primary_key=True),
+            Column("name", String(50)),
+            Column("x", String(50)),
+            Column("y", String(50)),
+        )
+
+        Table(
+            "child",
+            m1,
+            Column("id", Integer, Identity(), primary_key=True),
+            Column("name", String(50)),
+            Column("parent_id", Integer),
+            ForeignKeyConstraint(["parent_id"], ["parent.id"]),
+        )
+
+        Table(
+            "t1",
+            m1,
+            Column("id", Integer, Identity(), primary_key=True),
+            Column("name", String(50)),
+        )
+
+        Table(
+            "t2",
+            m1,
+            Column("id", Integer, Identity(), primary_key=True),
+            Column("name", String(50)),
+        )
+
+        Table(
+            "t3",
+            m1,
+            Column("id", Integer, Identity(), primary_key=True),
+            Column("name", String(50)),
+        )
+
+        m1.create_all(connection)
+
+        ctx_opts = {
+            "compare_type": True,
+            "compare_server_default": True,
+            "target_metadata": m2,
+            "upgrade_token": "upgrades",
+            "downgrade_token": "downgrades",
+            "alembic_module_prefix": "op.",
+            "sqlalchemy_module_prefix": "sa.",
+        }
+        context = MigrationContext.configure(
+            connection=connection, opts=ctx_opts
+        )
+
+        autogen_context = api.AutogenContext(context, m2)
+        uo = ops.UpgradeOps(ops=[])
+
+        yield autogen_context, uo
+
+    @testing.fixture(params=[True, False])
+    def disable_pre_cache(self, request):
+        from alembic.autogenerate.compare.util import _SQLA2InspectorConv
+
+        patcher = mock.patch.object(_SQLA2InspectorConv, "pre_cache_tables")
+
+        param = request.param
+
+        if param:
+            patcher.start()
+
+        yield param
+
+        if param:
+            patcher.stop()
+
+    class Approx:
+        def __init__(self, value):
+            self.value = value
+
+        def __eq__(self, other):
+            return abs(other - self.value) < 5
+
+        def __repr__(self):
+            return f"Approximately({self.value})"
+
+    expected = {
+        # no savings for SQLite, which queries per table no matter what,
+        # but we can at least see that pre-caching does all the SQL up front
+        "sqlite": {
+            "pre_cache": Approx(60),
+            "no_pre_cache": Approx(95),
+        },
+        # for PG and Oracle which use multi-queries, big savings
+        "postgresql": {"pre_cache": Approx(12), "no_pre_cache": Approx(57)},
+        "oracle": {"pre_cache": Approx(12), "no_pre_cache": Approx(32)},
+    }
+
+    def test_run_compare(
+        self, connection, models, instrument_inspector_conv, disable_pre_cache
+    ):
+        autogen_context, uo = models
+
+        autogenerate._produce_net_changes(autogen_context, uo)
+
+        sql_calls = len(instrument_inspector_conv.mock_calls)
+
+        eq_(
+            sql_calls,
+            self.expected[connection.dialect.name][
+                "pre_cache" if not disable_pre_cache else "no_pre_cache"
+            ],
+        )