]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
More PostgreSQL expression index compare fixes
authorFederico Caselli <cfederico87@gmail.com>
Sat, 21 Oct 2023 19:06:57 +0000 (21:06 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Nov 2023 20:17:35 +0000 (15:17 -0500)
Additional fixes to the PostgreSQL expression index comparison feature.
The comparison now correctly accommodates casts and differences in
spacing.
Added detection logic for operator clauses inside the expression,
skipping the comparison of these expressions.
To accommodate these changes the logic for the comparison of the
indexes and unique constraints was moved to the dialect
implementation, allowing greater flexibility.

Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Fixes: #1321
Fixes: #1327
Fixes: #1356
Change-Id: Icad15bc556a63bfa55b84779e7691c745d943c63

14 files changed:
.pre-commit-config.yaml
alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/ddl/_autogen.py [new file with mode: 0644]
alembic/ddl/impl.py
alembic/ddl/mysql.py
alembic/ddl/postgresql.py
alembic/operations/ops.py
alembic/testing/schemacompare.py
alembic/util/sqla_compat.py
docs/build/unreleased/more_index_fixes.rst [new file with mode: 0644]
setup.cfg
tests/test_autogen_indexes.py
tests/test_postgresql.py

index 8d68141e0379763fba3bf8bb41e8d70aa2d4d2fb..f1a8b41838d716ef9fc21dbeb0b95f5c3a7675e5 100644 (file)
@@ -14,7 +14,7 @@ repos:
             - --keep-unused-type-checking
 
 -   repo: https://github.com/pycqa/flake8
-    rev: 6.0.0
+    rev: 6.1.0
     hooks:
     -   id: flake8
         additional_dependencies:
index 7282487be240d54c1e40ff04ede5e8b4e70ae91b..b7f43b1936886edc09d9a88f55b3de86cb6ffdf7 100644 (file)
@@ -17,6 +17,7 @@ from . import compare
 from . import render
 from .. import util
 from ..operations import ops
+from ..util import sqla_compat
 
 """Provide the 'autogenerate' feature which can produce migration operations
 automatically."""
@@ -440,7 +441,7 @@ class AutogenContext:
     def run_object_filters(
         self,
         object_: SchemaItem,
-        name: Optional[str],
+        name: sqla_compat._ConstraintName,
         type_: NameFilterType,
         reflected: bool,
         compare_to: Optional[SchemaItem],
index a24a75d1c92c14ce801717e2c687cb5b32a16239..a50d8b8186b7c0c96cd1b81596ae944f7ddde45b 100644 (file)
@@ -7,12 +7,12 @@ from typing import Any
 from typing import cast
 from typing import Dict
 from typing import Iterator
-from typing import List
 from typing import Mapping
 from typing import Optional
 from typing import Set
 from typing import Tuple
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from sqlalchemy import event
@@ -21,10 +21,14 @@ from sqlalchemy import schema as sa_schema
 from sqlalchemy import text
 from sqlalchemy import types as sqltypes
 from sqlalchemy.sql import expression
+from sqlalchemy.sql.schema import ForeignKeyConstraint
+from sqlalchemy.sql.schema import Index
+from sqlalchemy.sql.schema import UniqueConstraint
 from sqlalchemy.util import OrderedSet
 
-from alembic.ddl.base import _fk_spec
 from .. import util
+from ..ddl._autogen import is_index_sig
+from ..ddl._autogen import is_uq_sig
 from ..operations import ops
 from ..util import sqla_compat
 
@@ -35,10 +39,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.elements import quoted_name
     from sqlalchemy.sql.elements import TextClause
     from sqlalchemy.sql.schema import Column
-    from sqlalchemy.sql.schema import ForeignKeyConstraint
-    from sqlalchemy.sql.schema import Index
     from sqlalchemy.sql.schema import Table
-    from sqlalchemy.sql.schema import UniqueConstraint
 
     from alembic.autogenerate.api import AutogenContext
     from alembic.ddl.impl import DefaultImpl
@@ -46,6 +47,8 @@ if TYPE_CHECKING:
     from alembic.operations.ops import MigrationScript
     from alembic.operations.ops import ModifyTableOps
     from alembic.operations.ops import UpgradeOps
+    from ..ddl._autogen import _constraint_sig
+
 
 log = logging.getLogger(__name__)
 
@@ -429,102 +432,7 @@ def _compare_columns(
             log.info("Detected removed column '%s.%s'", name, cname)
 
 
-class _constraint_sig:
-    const: Union[UniqueConstraint, ForeignKeyConstraint, Index]
-
-    def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
-        return sqla_compat._get_constraint_final_name(
-            self.const, context.dialect
-        )
-
-    def __eq__(self, other):
-        return self.const == other.const
-
-    def __ne__(self, other):
-        return self.const != other.const
-
-    def __hash__(self) -> int:
-        return hash(self.const)
-
-
-class _uq_constraint_sig(_constraint_sig):
-    is_index = False
-    is_unique = True
-
-    def __init__(self, const: UniqueConstraint, impl: DefaultImpl) -> None:
-        self.const = const
-        self.name = const.name
-        self.sig = ("UNIQUE_CONSTRAINT",) + impl.create_unique_constraint_sig(
-            const
-        )
-
-    @property
-    def column_names(self) -> List[str]:
-        return [col.name for col in self.const.columns]
-
-
-class _ix_constraint_sig(_constraint_sig):
-    is_index = True
-
-    def __init__(self, const: Index, impl: DefaultImpl) -> None:
-        self.const = const
-        self.name = const.name
-        self.sig = ("INDEX",) + impl.create_index_sig(const)
-        self.is_unique = bool(const.unique)
-
-    def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
-        return sqla_compat._get_constraint_final_name(
-            self.const, context.dialect
-        )
-
-    @property
-    def column_names(self) -> Union[List[quoted_name], List[None]]:
-        return sqla_compat._get_index_column_names(self.const)
-
-
-class _fk_constraint_sig(_constraint_sig):
-    def __init__(
-        self, const: ForeignKeyConstraint, include_options: bool = False
-    ) -> None:
-        self.const = const
-        self.name = const.name
-
-        (
-            self.source_schema,
-            self.source_table,
-            self.source_columns,
-            self.target_schema,
-            self.target_table,
-            self.target_columns,
-            onupdate,
-            ondelete,
-            deferrable,
-            initially,
-        ) = _fk_spec(const)
-
-        self.sig: Tuple[Any, ...] = (
-            self.source_schema,
-            self.source_table,
-            tuple(self.source_columns),
-            self.target_schema,
-            self.target_table,
-            tuple(self.target_columns),
-        )
-        if include_options:
-            self.sig += (
-                (None if onupdate.lower() == "no action" else onupdate.lower())
-                if onupdate
-                else None,
-                (None if ondelete.lower() == "no action" else ondelete.lower())
-                if ondelete
-                else None,
-                # convert initially + deferrable into one three-state value
-                "initially_deferrable"
-                if initially and initially.lower() == "deferred"
-                else "deferrable"
-                if deferrable
-                else "not deferrable",
-            )
+_C = TypeVar("_C", bound=Union[UniqueConstraint, ForeignKeyConstraint, Index])
 
 
 @comparators.dispatch_for("table")
@@ -561,32 +469,31 @@ def _compare_indexes_and_uniques(
 
     if conn_table is not None:
         # 1b. ... and from connection, if the table exists
-        if hasattr(inspector, "get_unique_constraints"):
-            try:
-                conn_uniques = inspector.get_unique_constraints(  # type:ignore[assignment] # noqa
-                    tname, schema=schema
+        try:
+            conn_uniques = inspector.get_unique_constraints(  # type:ignore[assignment] # noqa
+                tname, schema=schema
+            )
+            supports_unique_constraints = True
+        except NotImplementedError:
+            pass
+        except TypeError:
+            # number of arguments is off for the base
+            # method in SQLAlchemy due to the cache decorator
+            # not being present
+            pass
+        else:
+            conn_uniques = [  # type:ignore[assignment]
+                uq
+                for uq in conn_uniques
+                if autogen_context.run_name_filters(
+                    uq["name"],
+                    "unique_constraint",
+                    {"table_name": tname, "schema_name": schema},
                 )
-                supports_unique_constraints = True
-            except NotImplementedError:
-                pass
-            except TypeError:
-                # number of arguments is off for the base
-                # method in SQLAlchemy due to the cache decorator
-                # not being present
-                pass
-            else:
-                conn_uniques = [  # type:ignore[assignment]
-                    uq
-                    for uq in conn_uniques
-                    if autogen_context.run_name_filters(
-                        uq["name"],
-                        "unique_constraint",
-                        {"table_name": tname, "schema_name": schema},
-                    )
-                ]
-                for uq in conn_uniques:
-                    if uq.get("duplicates_index"):
-                        unique_constraints_duplicate_unique_indexes = True
+            ]
+            for uq in conn_uniques:
+                if uq.get("duplicates_index"):
+                    unique_constraints_duplicate_unique_indexes = True
         try:
             conn_indexes = inspector.get_indexes(  # type:ignore[assignment]
                 tname, schema=schema
@@ -639,7 +546,7 @@ def _compare_indexes_and_uniques(
     # 3. give the dialect a chance to omit indexes and constraints that
     # we know are either added implicitly by the DB or that the DB
     # can't accurately report on
-    autogen_context.migration_context.impl.correct_for_autogen_constraints(
+    impl.correct_for_autogen_constraints(
         conn_uniques,  # type: ignore[arg-type]
         conn_indexes,  # type: ignore[arg-type]
         metadata_unique_constraints,
@@ -651,18 +558,21 @@ def _compare_indexes_and_uniques(
     # Index and UniqueConstraint so we can easily work with them
     # interchangeably
     metadata_unique_constraints_sig = {
-        _uq_constraint_sig(uq, impl) for uq in metadata_unique_constraints
+        impl._create_metadata_constraint_sig(uq)
+        for uq in metadata_unique_constraints
     }
 
     metadata_indexes_sig = {
-        _ix_constraint_sig(ix, impl) for ix in metadata_indexes
+        impl._create_metadata_constraint_sig(ix) for ix in metadata_indexes
     }
 
     conn_unique_constraints = {
-        _uq_constraint_sig(uq, impl) for uq in conn_uniques
+        impl._create_reflected_constraint_sig(uq) for uq in conn_uniques
     }
 
-    conn_indexes_sig = {_ix_constraint_sig(ix, impl) for ix in conn_indexes}
+    conn_indexes_sig = {
+        impl._create_reflected_constraint_sig(ix) for ix in conn_indexes
+    }
 
     # 5. index things by name, for those objects that have names
     metadata_names = {
@@ -670,12 +580,11 @@ def _compare_indexes_and_uniques(
         for c in metadata_unique_constraints_sig.union(
             metadata_indexes_sig  # type:ignore[arg-type]
         )
-        if isinstance(c, _ix_constraint_sig)
-        or sqla_compat._constraint_is_named(c.const, autogen_context.dialect)
+        if c.is_named
     }
 
-    conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _uq_constraint_sig]
-    conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _ix_constraint_sig]
+    conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig]
+    conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig]
 
     conn_uniques_by_name = {c.name: c for c in conn_unique_constraints}
     conn_indexes_by_name = {c.name: c for c in conn_indexes_sig}
@@ -694,13 +603,12 @@ def _compare_indexes_and_uniques(
 
     # 6. index things by "column signature", to help with unnamed unique
     # constraints.
-    conn_uniques_by_sig = {uq.sig: uq for uq in conn_unique_constraints}
+    conn_uniques_by_sig = {uq.unnamed: uq for uq in conn_unique_constraints}
     metadata_uniques_by_sig = {
-        uq.sig: uq for uq in metadata_unique_constraints_sig
+        uq.unnamed: uq for uq in metadata_unique_constraints_sig
     }
-    metadata_indexes_by_sig = {ix.sig: ix for ix in metadata_indexes_sig}
     unnamed_metadata_uniques = {
-        uq.sig: uq
+        uq.unnamed: uq
         for uq in metadata_unique_constraints_sig
         if not sqla_compat._constraint_is_named(
             uq.const, autogen_context.dialect
@@ -715,18 +623,18 @@ def _compare_indexes_and_uniques(
     # 4. The backend may double up indexes as unique constraints and
     #    vice versa (e.g. MySQL, Postgresql)
 
-    def obj_added(obj):
-        if obj.is_index:
+    def obj_added(obj: _constraint_sig):
+        if is_index_sig(obj):
             if autogen_context.run_object_filters(
                 obj.const, obj.name, "index", False, None
             ):
                 modify_ops.ops.append(ops.CreateIndexOp.from_index(obj.const))
                 log.info(
-                    "Detected added index '%s' on %s",
+                    "Detected added index '%r' on '%s'",
                     obj.name,
-                    ", ".join(["'%s'" % obj.column_names]),
+                    obj.column_names,
                 )
-        else:
+        elif is_uq_sig(obj):
             if not supports_unique_constraints:
                 # can't report unique indexes as added if we don't
                 # detect them
@@ -741,13 +649,15 @@ def _compare_indexes_and_uniques(
                     ops.AddConstraintOp.from_constraint(obj.const)
                 )
                 log.info(
-                    "Detected added unique constraint '%s' on %s",
+                    "Detected added unique constraint %r on '%s'",
                     obj.name,
-                    ", ".join(["'%s'" % obj.column_names]),
+                    obj.column_names,
                 )
+        else:
+            assert False
 
-    def obj_removed(obj):
-        if obj.is_index:
+    def obj_removed(obj: _constraint_sig):
+        if is_index_sig(obj):
             if obj.is_unique and not supports_unique_constraints:
                 # many databases double up unique constraints
                 # as unique indexes.  without that list we can't
@@ -758,10 +668,8 @@ def _compare_indexes_and_uniques(
                 obj.const, obj.name, "index", True, None
             ):
                 modify_ops.ops.append(ops.DropIndexOp.from_index(obj.const))
-                log.info(
-                    "Detected removed index '%s' on '%s'", obj.name, tname
-                )
-        else:
+                log.info("Detected removed index %r on %r", obj.name, tname)
+        elif is_uq_sig(obj):
             if is_create_table or is_drop_table:
                 # if the whole table is being dropped, we don't need to
                 # consider unique constraint separately
@@ -773,33 +681,40 @@ def _compare_indexes_and_uniques(
                     ops.DropConstraintOp.from_constraint(obj.const)
                 )
                 log.info(
-                    "Detected removed unique constraint '%s' on '%s'",
+                    "Detected removed unique constraint %r on %r",
                     obj.name,
                     tname,
                 )
+        else:
+            assert False
+
+    def obj_changed(
+        old: _constraint_sig,
+        new: _constraint_sig,
+        msg: str,
+    ):
+        if is_index_sig(old):
+            assert is_index_sig(new)
 
-    def obj_changed(old, new, msg):
-        if old.is_index:
             if autogen_context.run_object_filters(
                 new.const, new.name, "index", False, old.const
             ):
                 log.info(
-                    "Detected changed index '%s' on '%s':%s",
-                    old.name,
-                    tname,
-                    ", ".join(msg),
+                    "Detected changed index %r on %r: %s", old.name, tname, msg
                 )
                 modify_ops.ops.append(ops.DropIndexOp.from_index(old.const))
                 modify_ops.ops.append(ops.CreateIndexOp.from_index(new.const))
-        else:
+        elif is_uq_sig(old):
+            assert is_uq_sig(new)
+
             if autogen_context.run_object_filters(
                 new.const, new.name, "unique_constraint", False, old.const
             ):
                 log.info(
-                    "Detected changed unique constraint '%s' on '%s':%s",
+                    "Detected changed unique constraint %r on %r: %s",
                     old.name,
                     tname,
-                    ", ".join(msg),
+                    msg,
                 )
                 modify_ops.ops.append(
                     ops.DropConstraintOp.from_constraint(old.const)
@@ -807,18 +722,24 @@ def _compare_indexes_and_uniques(
                 modify_ops.ops.append(
                     ops.AddConstraintOp.from_constraint(new.const)
                 )
+        else:
+            assert False
 
     for removed_name in sorted(set(conn_names).difference(metadata_names)):
-        conn_obj: Union[_ix_constraint_sig, _uq_constraint_sig] = conn_names[
-            removed_name
-        ]
-        if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques:
+        conn_obj = conn_names[removed_name]
+        if (
+            is_uq_sig(conn_obj)
+            and conn_obj.unnamed in unnamed_metadata_uniques
+        ):
             continue
         elif removed_name in doubled_constraints:
             conn_uq, conn_idx = doubled_constraints[removed_name]
             if (
-                conn_idx.sig not in metadata_indexes_by_sig
-                and conn_uq.sig not in metadata_uniques_by_sig
+                all(
+                    conn_idx.unnamed != meta_idx.unnamed
+                    for meta_idx in metadata_indexes_sig
+                )
+                and conn_uq.unnamed not in metadata_uniques_by_sig
             ):
                 obj_removed(conn_uq)
                 obj_removed(conn_idx)
@@ -830,30 +751,36 @@ def _compare_indexes_and_uniques(
 
         if existing_name in doubled_constraints:
             conn_uq, conn_idx = doubled_constraints[existing_name]
-            if metadata_obj.is_index:
+            if is_index_sig(metadata_obj):
                 conn_obj = conn_idx
             else:
                 conn_obj = conn_uq
         else:
             conn_obj = conn_names[existing_name]
 
-        if conn_obj.is_index != metadata_obj.is_index:
+        if type(conn_obj) != type(metadata_obj):
             obj_removed(conn_obj)
             obj_added(metadata_obj)
         else:
-            msg = []
-            if conn_obj.is_unique != metadata_obj.is_unique:
-                msg.append(
-                    " unique=%r to unique=%r"
-                    % (conn_obj.is_unique, metadata_obj.is_unique)
+            comparison = metadata_obj.compare_to_reflected(conn_obj)
+
+            if comparison.is_different:
+                # constraint are different
+                obj_changed(conn_obj, metadata_obj, comparison.message)
+            elif comparison.is_skip:
+                # constraint cannot be compared, skip them
+                thing = (
+                    "index" if is_index_sig(conn_obj) else "unique constraint"
                 )
-            if conn_obj.sig != metadata_obj.sig:
-                msg.append(
-                    " expression %r to %r" % (conn_obj.sig, metadata_obj.sig)
+                log.info(
+                    "Cannot compare %s %r, assuming equal and skipping. %s",
+                    thing,
+                    conn_obj.name,
+                    comparison.message,
                 )
-
-            if msg:
-                obj_changed(conn_obj, metadata_obj, msg)
+            else:
+                # constraint are equal
+                assert comparison.is_equal
 
     for added_name in sorted(set(metadata_names).difference(conn_names)):
         obj = metadata_names[added_name]
@@ -893,7 +820,7 @@ def _correct_for_uq_duplicates_uix(
     }
 
     unnamed_metadata_uqs = {
-        _uq_constraint_sig(cons, impl).sig
+        impl._create_metadata_constraint_sig(cons).unnamed
         for name, cons in metadata_cons_names
         if name is None
     }
@@ -917,7 +844,9 @@ def _correct_for_uq_duplicates_uix(
     for overlap in uqs_dupe_indexes:
         if overlap not in metadata_uq_names:
             if (
-                _uq_constraint_sig(uqs_dupe_indexes[overlap], impl).sig
+                impl._create_reflected_constraint_sig(
+                    uqs_dupe_indexes[overlap]
+                ).unnamed
                 not in unnamed_metadata_uqs
             ):
                 conn_unique_constraints.discard(uqs_dupe_indexes[overlap])
@@ -1243,8 +1172,8 @@ def _compare_foreign_keys(
     modify_table_ops: ModifyTableOps,
     schema: Optional[str],
     tname: Union[quoted_name, str],
-    conn_table: Optional[Table],
-    metadata_table: Optional[Table],
+    conn_table: Table,
+    metadata_table: Table,
 ) -> None:
     # if we're doing CREATE TABLE, all FKs are created
     # inline within the table def
@@ -1268,15 +1197,13 @@ def _compare_foreign_keys(
         )
     ]
 
-    backend_reflects_fk_options = bool(
-        conn_fks_list and "options" in conn_fks_list[0]
-    )
-
     conn_fks = {
         _make_foreign_key(const, conn_table)  # type: ignore[arg-type]
         for const in conn_fks_list
     }
 
+    impl = autogen_context.migration_context.impl
+
     # give the dialect a chance to correct the FKs to match more
     # closely
     autogen_context.migration_context.impl.correct_for_autogen_foreignkeys(
@@ -1284,17 +1211,24 @@ def _compare_foreign_keys(
     )
 
     metadata_fks_sig = {
-        _fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
-        for fk in metadata_fks
+        impl._create_metadata_constraint_sig(fk) for fk in metadata_fks
     }
 
     conn_fks_sig = {
-        _fk_constraint_sig(fk, include_options=backend_reflects_fk_options)
-        for fk in conn_fks
+        impl._create_reflected_constraint_sig(fk) for fk in conn_fks
     }
 
-    conn_fks_by_sig = {c.sig: c for c in conn_fks_sig}
-    metadata_fks_by_sig = {c.sig: c for c in metadata_fks_sig}
+    # check if reflected FKs include options, indicating the backend
+    # can reflect FK options
+    if conn_fks_list and "options" in conn_fks_list[0]:
+        conn_fks_by_sig = {c.unnamed: c for c in conn_fks_sig}
+        metadata_fks_by_sig = {c.unnamed: c for c in metadata_fks_sig}
+    else:
+        # otherwise compare by sig without options added
+        conn_fks_by_sig = {c.unnamed_no_options: c for c in conn_fks_sig}
+        metadata_fks_by_sig = {
+            c.unnamed_no_options: c for c in metadata_fks_sig
+        }
 
     metadata_fks_by_name = {
         c.name: c for c in metadata_fks_sig if c.name is not None
diff --git a/alembic/ddl/_autogen.py b/alembic/ddl/_autogen.py
new file mode 100644 (file)
index 0000000..cc1a1fc
--- /dev/null
@@ -0,0 +1,323 @@
+from __future__ import annotations
+
+from typing import Any
+from typing import ClassVar
+from typing import Dict
+from typing import Generic
+from typing import NamedTuple
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+
+from sqlalchemy.sql.schema import Constraint
+from sqlalchemy.sql.schema import ForeignKeyConstraint
+from sqlalchemy.sql.schema import Index
+from sqlalchemy.sql.schema import UniqueConstraint
+from typing_extensions import TypeGuard
+
+from alembic.ddl.base import _fk_spec
+from .. import util
+from ..util import sqla_compat
+
+if TYPE_CHECKING:
+    from typing import Literal
+
+    from alembic.autogenerate.api import AutogenContext
+    from alembic.ddl.impl import DefaultImpl
+
+CompareConstraintType = Union[Constraint, Index]
+
+_C = TypeVar("_C", bound=CompareConstraintType)
+
+_clsreg: Dict[str, Type[_constraint_sig]] = {}
+
+
+class ComparisonResult(NamedTuple):
+    status: Literal["equal", "different", "skip"]
+    message: str
+
+    @property
+    def is_equal(self) -> bool:
+        return self.status == "equal"
+
+    @property
+    def is_different(self) -> bool:
+        return self.status == "different"
+
+    @property
+    def is_skip(self) -> bool:
+        return self.status == "skip"
+
+    @classmethod
+    def Equal(cls) -> ComparisonResult:
+        """the constraints are equal."""
+        return cls("equal", "The two constraints are equal")
+
+    @classmethod
+    def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
+        """the constraints are different for the provided reason(s)."""
+        return cls("different", ", ".join(util.to_list(reason)))
+
+    @classmethod
+    def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
+        """the constraint cannot be compared for the provided reason(s).
+
+        The message is logged, but the constraints will be otherwise
+        considered equal, meaning that no migration command will be
+        generated.
+        """
+        return cls("skip", ", ".join(util.to_list(reason)))
+
+
+class _constraint_sig(Generic[_C]):
+    const: _C
+
+    _sig: Tuple[Any, ...]
+    name: Optional[sqla_compat._ConstraintNameDefined]
+
+    impl: DefaultImpl
+
+    _is_index: ClassVar[bool] = False
+    _is_fk: ClassVar[bool] = False
+    _is_uq: ClassVar[bool] = False
+
+    _is_metadata: bool
+
+    def __init_subclass__(cls) -> None:
+        cls._register()
+
+    @classmethod
+    def _register(cls):
+        raise NotImplementedError()
+
+    def __init__(
+        self, is_metadata: bool, impl: DefaultImpl, const: _C
+    ) -> None:
+        raise NotImplementedError()
+
+    def compare_to_reflected(
+        self, other: _constraint_sig[Any]
+    ) -> ComparisonResult:
+        assert self.impl is other.impl
+        assert self._is_metadata
+        assert not other._is_metadata
+
+        return self._compare_to_reflected(other)
+
+    def _compare_to_reflected(
+        self, other: _constraint_sig[_C]
+    ) -> ComparisonResult:
+        raise NotImplementedError()
+
+    @classmethod
+    def from_constraint(
+        cls, is_metadata: bool, impl: DefaultImpl, constraint: _C
+    ) -> _constraint_sig[_C]:
+        # these could be cached by constraint/impl, however, if the
+        # constraint is modified in place, then the sig is wrong.  the mysql
+        # impl currently does this, and if we fixed that we can't be sure
+        # someone else might do it too, so play it safe.
+        sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint)
+        return sig
+
+    def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
+        return sqla_compat._get_constraint_final_name(
+            self.const, context.dialect
+        )
+
+    @util.memoized_property
+    def is_named(self):
+        return sqla_compat._constraint_is_named(self.const, self.impl.dialect)
+
+    @util.memoized_property
+    def unnamed(self) -> Tuple[Any, ...]:
+        return self._sig
+
+    @util.memoized_property
+    def unnamed_no_options(self) -> Tuple[Any, ...]:
+        raise NotImplementedError()
+
+    @util.memoized_property
+    def _full_sig(self) -> Tuple[Any, ...]:
+        return (self.name,) + self.unnamed
+
+    def __eq__(self, other) -> bool:
+        return self._full_sig == other._full_sig
+
+    def __ne__(self, other) -> bool:
+        return self._full_sig != other._full_sig
+
+    def __hash__(self) -> int:
+        return hash(self._full_sig)
+
+
+class _uq_constraint_sig(_constraint_sig[UniqueConstraint]):
+    _is_uq = True
+
+    @classmethod
+    def _register(cls) -> None:
+        _clsreg["unique_constraint"] = cls
+
+    is_unique = True
+
+    def __init__(
+        self,
+        is_metadata: bool,
+        impl: DefaultImpl,
+        const: UniqueConstraint,
+    ) -> None:
+        self.impl = impl
+        self.const = const
+        self.name = sqla_compat.constraint_name_or_none(const.name)
+        self._sig = tuple(sorted([col.name for col in const.columns]))
+        self._is_metadata = is_metadata
+
+    @property
+    def column_names(self) -> Tuple[str, ...]:
+        return tuple([col.name for col in self.const.columns])
+
+    def _compare_to_reflected(
+        self, other: _constraint_sig[_C]
+    ) -> ComparisonResult:
+        assert self._is_metadata
+        metadata_obj = self
+        conn_obj = other
+
+        assert is_uq_sig(conn_obj)
+        return self.impl.compare_unique_constraint(
+            metadata_obj.const, conn_obj.const
+        )
+
+
+class _ix_constraint_sig(_constraint_sig[Index]):
+    _is_index = True
+
+    name: sqla_compat._ConstraintName
+
+    @classmethod
+    def _register(cls) -> None:
+        _clsreg["index"] = cls
+
+    def __init__(
+        self, is_metadata: bool, impl: DefaultImpl, const: Index
+    ) -> None:
+        self.impl = impl
+        self.const = const
+        self.name = const.name
+        self.is_unique = bool(const.unique)
+        self._is_metadata = is_metadata
+
+    def _compare_to_reflected(
+        self, other: _constraint_sig[_C]
+    ) -> ComparisonResult:
+        assert self._is_metadata
+        metadata_obj = self
+        conn_obj = other
+
+        assert is_index_sig(conn_obj)
+        return self.impl.compare_indexes(metadata_obj.const, conn_obj.const)
+
+    @util.memoized_property
+    def has_expressions(self):
+        return sqla_compat.is_expression_index(self.const)
+
+    @util.memoized_property
+    def column_names(self) -> Tuple[str, ...]:
+        return tuple([col.name for col in self.const.columns])
+
+    @util.memoized_property
+    def column_names_optional(self) -> Tuple[Optional[str], ...]:
+        return tuple(
+            [getattr(col, "name", None) for col in self.const.expressions]
+        )
+
+    @util.memoized_property
+    def is_named(self):
+        return True
+
+    @util.memoized_property
+    def unnamed(self):
+        return (self.is_unique,) + self.column_names_optional
+
+
+class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
+    _is_fk = True
+
+    @classmethod
+    def _register(cls) -> None:
+        _clsreg["foreign_key_constraint"] = cls
+
+    def __init__(
+        self,
+        is_metadata: bool,
+        impl: DefaultImpl,
+        const: ForeignKeyConstraint,
+    ) -> None:
+        self._is_metadata = is_metadata
+
+        self.impl = impl
+        self.const = const
+
+        self.name = sqla_compat.constraint_name_or_none(const.name)
+
+        (
+            self.source_schema,
+            self.source_table,
+            self.source_columns,
+            self.target_schema,
+            self.target_table,
+            self.target_columns,
+            onupdate,
+            ondelete,
+            deferrable,
+            initially,
+        ) = _fk_spec(const)
+
+        self._sig: Tuple[Any, ...] = (
+            self.source_schema,
+            self.source_table,
+            tuple(self.source_columns),
+            self.target_schema,
+            self.target_table,
+            tuple(self.target_columns),
+        ) + (
+            (None if onupdate.lower() == "no action" else onupdate.lower())
+            if onupdate
+            else None,
+            (None if ondelete.lower() == "no action" else ondelete.lower())
+            if ondelete
+            else None,
+            # convert initially + deferrable into one three-state value
+            "initially_deferrable"
+            if initially and initially.lower() == "deferred"
+            else "deferrable"
+            if deferrable
+            else "not deferrable",
+        )
+
+    @util.memoized_property
+    def unnamed_no_options(self):
+        return (
+            self.source_schema,
+            self.source_table,
+            tuple(self.source_columns),
+            self.target_schema,
+            self.target_table,
+            tuple(self.target_columns),
+        )
+
+
+def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]:
+    return sig._is_index
+
+
+def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
+    return sig._is_uq
+
+
+def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
+    return sig._is_fk
index 8a7c75d46170ae2738a931b387f5080b404f010c..571a3041cc66ba52f4fe20d87402fb2789547a37 100644 (file)
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from collections import namedtuple
+import logging
 import re
 from typing import Any
 from typing import Callable
@@ -8,6 +8,7 @@ from typing import Dict
 from typing import Iterable
 from typing import List
 from typing import Mapping
+from typing import NamedTuple
 from typing import Optional
 from typing import Sequence
 from typing import Set
@@ -20,7 +21,10 @@ from sqlalchemy import cast
 from sqlalchemy import schema
 from sqlalchemy import text
 
+from . import _autogen
 from . import base
+from ._autogen import _constraint_sig
+from ._autogen import ComparisonResult
 from .. import util
 from ..util import sqla_compat
 
@@ -50,6 +54,8 @@ if TYPE_CHECKING:
     from ..operations.batch import ApplyBatchImpl
     from ..operations.batch import BatchOperationsImpl
 
+log = logging.getLogger(__name__)
+
 
 class ImplMeta(type):
     def __init__(
@@ -66,8 +72,6 @@ class ImplMeta(type):
 
 _impls: Dict[str, Type[DefaultImpl]] = {}
 
-Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
-
 
 class DefaultImpl(metaclass=ImplMeta):
 
@@ -438,6 +442,7 @@ class DefaultImpl(metaclass=ImplMeta):
                         )
 
     def _tokenize_column_type(self, column: Column) -> Params:
+        definition: str
         definition = self.dialect.type_compiler.process(column.type).lower()
 
         # tokenize the SQLAlchemy-generated version of a type, so that
@@ -452,9 +457,9 @@ class DefaultImpl(metaclass=ImplMeta):
         # varchar character set utf8
         #
 
-        tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition)
+        tokens: List[str] = re.findall(r"[\w\-_]+|\(.+?\)", definition)
 
-        term_tokens = []
+        term_tokens: List[str] = []
         paren_term = None
 
         for token in tokens:
@@ -466,6 +471,7 @@ class DefaultImpl(metaclass=ImplMeta):
         params = Params(term_tokens[0], term_tokens[1:], [], {})
 
         if paren_term:
+            term: str
             for term in re.findall("[^(),]+", paren_term):
                 if "=" in term:
                     key, val = term.split("=")
@@ -664,15 +670,96 @@ class DefaultImpl(metaclass=ImplMeta):
             bool(diff) or bool(metadata_identity) != bool(inspector_identity),
         )
 
-    def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
-        # order of col matters in an index
-        return tuple(col.name for col in index.columns)
+    def _compare_index_unique(
+        self, metadata_index: Index, reflected_index: Index
+    ) -> Optional[str]:
+        conn_unique = bool(reflected_index.unique)
+        meta_unique = bool(metadata_index.unique)
+        if conn_unique != meta_unique:
+            return f"unique={conn_unique} to unique={meta_unique}"
+        else:
+            return None
+
+    def _create_metadata_constraint_sig(
+        self, constraint: _autogen._C, **opts: Any
+    ) -> _constraint_sig[_autogen._C]:
+        return _constraint_sig.from_constraint(True, self, constraint, **opts)
+
+    def _create_reflected_constraint_sig(
+        self, constraint: _autogen._C, **opts: Any
+    ) -> _constraint_sig[_autogen._C]:
+        return _constraint_sig.from_constraint(False, self, constraint, **opts)
+
+    def compare_indexes(
+        self,
+        metadata_index: Index,
+        reflected_index: Index,
+    ) -> ComparisonResult:
+        """Compare two indexes by comparing their column expressions
+        and the ``unique`` flag.
 
-    def create_unique_constraint_sig(
-        self, const: UniqueConstraint
-    ) -> Tuple[Any, ...]:
-        # order of col does not matters in an unique constraint
-        return tuple(sorted([col.name for col in const.columns]))
+        This method returns a ``ComparisonResult``.
+        """
+        msg: List[str] = []
+        unique_msg = self._compare_index_unique(
+            metadata_index, reflected_index
+        )
+        if unique_msg:
+            msg.append(unique_msg)
+        m_sig = self._create_metadata_constraint_sig(metadata_index)
+        r_sig = self._create_reflected_constraint_sig(reflected_index)
+
+        assert _autogen.is_index_sig(m_sig)
+        assert _autogen.is_index_sig(r_sig)
+
+        # The assumption is that the indexes have no expressions
+        for sig in m_sig, r_sig:
+            if sig.has_expressions:
+                log.warning(
+                    "Generating approximate signature for index %s. "
+                    "The dialect "
+                    "implementation should either skip expression indexes "
+                    "or provide a custom implementation.",
+                    sig.const,
+                )
+
+        if m_sig.column_names != r_sig.column_names:
+            msg.append(
+                f"expression {r_sig.column_names} to {m_sig.column_names}"
+            )
+
+        if msg:
+            return ComparisonResult.Different(msg)
+        else:
+            return ComparisonResult.Equal()
+
+    def compare_unique_constraint(
+        self,
+        metadata_constraint: UniqueConstraint,
+        reflected_constraint: UniqueConstraint,
+    ) -> ComparisonResult:
+        """Compare two unique constraints by comparing the two signatures.
+
+        The two constraints are compared using the signatures generated
+        by ``_create_metadata_constraint_sig`` and its reflected counterpart.
+
+        This method returns a ``ComparisonResult``.
+        """
+        metadata_tup = self._create_metadata_constraint_sig(
+            metadata_constraint
+        )
+        reflected_tup = self._create_reflected_constraint_sig(
+            reflected_constraint
+        )
+
+        meta_sig = metadata_tup.unnamed
+        conn_sig = reflected_tup.unnamed
+        if conn_sig != meta_sig:
+            return ComparisonResult.Different(
+                f"expression {conn_sig} to {meta_sig}"
+            )
+        else:
+            return ComparisonResult.Equal()
 
     def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
         conn_indexes_by_name = {c.name: c for c in conn_indexes}
@@ -697,6 +784,13 @@ class DefaultImpl(metaclass=ImplMeta):
         return reflected_object.get("dialect_options", {})
 
 
+class Params(NamedTuple):
+    token0: str
+    tokens: List[str]
+    args: List[str]
+    kwargs: Dict[str, str]
+
+
 def _compare_identity_options(
     metadata_io: Union[schema.Identity, schema.Sequence, None],
     inspector_io: Union[schema.Identity, schema.Sequence, None],
index 32ced498b13017d823e5d3ab25d7faca659d5c23..5a2af5ce7b773b12295a02afad6d8614137d3a77 100644 (file)
@@ -20,7 +20,6 @@ from .base import format_column_name
 from .base import format_server_default
 from .impl import DefaultImpl
 from .. import util
-from ..autogenerate import compare
 from ..util import sqla_compat
 from ..util.sqla_compat import _is_mariadb
 from ..util.sqla_compat import _is_type_bound
@@ -272,10 +271,12 @@ class MySQLImpl(DefaultImpl):
 
     def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
         conn_fk_by_sig = {
-            compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks
+            self._create_reflected_constraint_sig(fk).unnamed_no_options: fk
+            for fk in conn_fks
         }
         metadata_fk_by_sig = {
-            compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks
+            self._create_metadata_constraint_sig(fk).unnamed_no_options: fk
+            for fk in metadata_fks
         }
 
         for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
index 949e256260b44b8306535f77762662d67974864f..68628c8ecfda7f41f967d83deba054e60fbe1cef 100644 (file)
@@ -21,10 +21,8 @@ from sqlalchemy.dialects.postgresql import BIGINT
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import INTEGER
 from sqlalchemy.schema import CreateIndex
-from sqlalchemy.sql import operators
 from sqlalchemy.sql.elements import ColumnClause
 from sqlalchemy.sql.elements import TextClause
-from sqlalchemy.sql.elements import UnaryExpression
 from sqlalchemy.sql.functions import FunctionElement
 from sqlalchemy.types import NULLTYPE
 
@@ -38,6 +36,7 @@ from .base import format_table_name
 from .base import format_type
 from .base import IdentityColumnDefault
 from .base import RenameTable
+from .impl import ComparisonResult
 from .impl import DefaultImpl
 from .. import util
 from ..autogenerate import render
@@ -252,62 +251,60 @@ class PostgresqlImpl(DefaultImpl):
         if not sqla_compat.sqla_2:
             self._skip_functional_indexes(metadata_indexes, conn_indexes)
 
-    def _cleanup_index_expr(
-        self, index: Index, expr: str, remove_suffix: str
-    ) -> str:
-        # start = expr
+    # pg behavior regarding modifiers
+    # | # | compiled sql     | returned sql     | regexp. group is removed |
+    # | - | ---------------- | -----------------| ------------------------ |
+    # | 1 | nulls first      | nulls first      | -                        |
+    # | 2 | nulls last       |                  | (?<! desc)( nulls last)$ |
+    # | 3 | asc              |                  | ( asc)$                  |
+    # | 4 | asc nulls first  | nulls first      | ( asc) nulls first$      |
+    # | 5 | asc nulls last   |                  | ( asc nulls last)$       |
+    # | 6 | desc             | desc             | -                        |
+    # | 7 | desc nulls first | desc             | desc( nulls first)$      |
+    # | 8 | desc nulls last  | desc nulls last  | -                        |
+    _default_modifiers_re = (  # order of case 2 and 5 matters
+        re.compile("( asc nulls last)$"),  # case 5
+        re.compile("(?<! desc)( nulls last)$"),  # case 2
+        re.compile("( asc)$"),  # case 3
+        re.compile("( asc) nulls first$"),  # case 4
+        re.compile(" desc( nulls first)$"),  # case 7
+    )
+
+    def _cleanup_index_expr(self, index: Index, expr: str) -> str:
         expr = expr.lower().replace('"', "").replace("'", "")
         if index.table is not None:
             # should not be needed, since include_table=False is in compile
             expr = expr.replace(f"{index.table.name.lower()}.", "")
 
-        while expr and expr[0] == "(" and expr[-1] == ")":
-            expr = expr[1:-1]
         if "::" in expr:
             # strip :: cast. types can have spaces in them
             expr = re.sub(r"(::[\w ]+\w)", "", expr)
 
-        if remove_suffix and expr.endswith(remove_suffix):
-            expr = expr[: -len(remove_suffix)]
-
-        # print(f"START: {start} END: {expr}")
-        return expr
+        while expr and expr[0] == "(" and expr[-1] == ")":
+            expr = expr[1:-1]
 
-    def _default_modifiers(self, exp: ClauseElement) -> str:
-        to_remove = ""
-        while isinstance(exp, UnaryExpression):
-            if exp.modifier is None:
-                exp = exp.element
-            else:
-                op = exp.modifier
-                if isinstance(exp.element, UnaryExpression):
-                    inner_op = exp.element.modifier
-                else:
-                    inner_op = None
-                if inner_op is None:
-                    if op == operators.asc_op:
-                        # default is asc
-                        to_remove = " asc"
-                    elif op == operators.nullslast_op:
-                        # default is nulls last
-                        to_remove = " nulls last"
-                else:
-                    if (
-                        inner_op == operators.asc_op
-                        and op == operators.nullslast_op
-                    ):
-                        # default is asc nulls last
-                        to_remove = " asc nulls last"
-                    elif (
-                        inner_op == operators.desc_op
-                        and op == operators.nullsfirst_op
-                    ):
-                        # default for desc is nulls first
-                        to_remove = " nulls first"
+        # NOTE: when parsing the connection expression this cleanup could
+        # be skipped
+        for rs in self._default_modifiers_re:
+            if match := rs.search(expr):
+                start, end = match.span(1)
+                expr = expr[:start] + expr[end:]
                 break
-        return to_remove
 
-    def _dialect_sig(
+        while expr and expr[0] == "(" and expr[-1] == ")":
+            expr = expr[1:-1]
+
+        # strip casts
+        cast_re = re.compile(r"cast\s*\(")
+        if cast_re.match(expr):
+            expr = cast_re.sub("", expr)
+            # remove the as type
+            expr = re.sub(r"as\s+[^)]+\)", "", expr)
+        # remove spaces
+        expr = expr.replace(" ", "")
+        return expr
+
+    def _dialect_options(
         self, item: Union[Index, UniqueConstraint]
     ) -> Tuple[Any, ...]:
         # only the positive case is returned by sqlalchemy reflection so
@@ -316,25 +313,93 @@ class PostgresqlImpl(DefaultImpl):
             return ("nulls_not_distinct",)
         return ()
 
-    def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
-        return tuple(
-            self._cleanup_index_expr(
-                index,
-                *(
-                    (e, "")
-                    if isinstance(e, str)
-                    else (self._compile_element(e), self._default_modifiers(e))
-                ),
+    def compare_indexes(
+        self,
+        metadata_index: Index,
+        reflected_index: Index,
+    ) -> ComparisonResult:
+        msg = []
+        unique_msg = self._compare_index_unique(
+            metadata_index, reflected_index
+        )
+        if unique_msg:
+            msg.append(unique_msg)
+        m_exprs = metadata_index.expressions
+        r_exprs = reflected_index.expressions
+        if len(m_exprs) != len(r_exprs):
+            msg.append(f"expression number {len(r_exprs)} to {len(m_exprs)}")
+        if msg:
+            # no point going further, return early
+            return ComparisonResult.Different(msg)
+        skip = []
+        for pos, (m_e, r_e) in enumerate(zip(m_exprs, r_exprs), 1):
+            m_compile = self._compile_element(m_e)
+            m_text = self._cleanup_index_expr(metadata_index, m_compile)
+            # print(f"META ORIG: {m_compile!r} CLEANUP: {m_text!r}")
+            r_compile = self._compile_element(r_e)
+            r_text = self._cleanup_index_expr(metadata_index, r_compile)
+            # print(f"CONN ORIG: {r_compile!r} CLEANUP: {r_text!r}")
+            if m_text == r_text:
+                continue  # expressions these are equal
+            elif m_compile.strip().endswith("_ops") and (
+                " " in m_compile or ")" in m_compile  # is an expression
+            ):
+                skip.append(
+                    f"expression #{pos} {m_compile!r} detected "
+                    "as including operator clause."
+                )
+                util.warn(
+                    f"Expression #{pos} {m_compile!r} in index "
+                    f"{reflected_index.name!r} detected to include "
+                    "an operator clause. Expression compare cannot proceed. "
+                    "Please move the operator clause to the "
+                    "``postgresql_ops`` dict to enable proper compare "
+                    "of the index expressions: "
+                    "https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#operator-classes",  # noqa: E501
+                )
+            else:
+                msg.append(f"expression #{pos} {r_compile!r} to {m_compile!r}")
+
+        m_options = self._dialect_options(metadata_index)
+        r_options = self._dialect_options(reflected_index)
+        if m_options != r_options:
+            msg.extend(f"options {r_options} to {m_options}")
+
+        if msg:
+            return ComparisonResult.Different(msg)
+        elif skip:
+            # if there are other changes detected don't skip the index
+            return ComparisonResult.Skip(skip)
+        else:
+            return ComparisonResult.Equal()
+
+    def compare_unique_constraint(
+        self,
+        metadata_constraint: UniqueConstraint,
+        reflected_constraint: UniqueConstraint,
+    ) -> ComparisonResult:
+        metadata_tup = self._create_metadata_constraint_sig(
+            metadata_constraint
+        )
+        reflected_tup = self._create_reflected_constraint_sig(
+            reflected_constraint
+        )
+
+        meta_sig = metadata_tup.unnamed
+        conn_sig = reflected_tup.unnamed
+        if conn_sig != meta_sig:
+            return ComparisonResult.Different(
+                f"expression {conn_sig} to {meta_sig}"
             )
-            for e in index.expressions
-        ) + self._dialect_sig(index)
 
-    def create_unique_constraint_sig(
-        self, const: UniqueConstraint
-    ) -> Tuple[Any, ...]:
-        return tuple(
-            sorted([col.name for col in const.columns])
-        ) + self._dialect_sig(const)
+        metadata_do = self._dialect_options(metadata_tup.const)
+        conn_do = self._dialect_options(reflected_tup.const)
+        if metadata_do != conn_do:
+            return ComparisonResult.Different(
+                f"expression {conn_do} to {metadata_do}"
+            )
+
+        return ComparisonResult.Equal()
 
     def adjust_reflected_dialect_options(
         self, reflected_options: Dict[str, Any], kind: str
@@ -345,7 +410,9 @@ class PostgresqlImpl(DefaultImpl):
             options.pop("postgresql_include", None)
         return options
 
-    def _compile_element(self, element: ClauseElement) -> str:
+    def _compile_element(self, element: Union[ClauseElement, str]) -> str:
+        if isinstance(element, str):
+            return element
         return element.compile(
             dialect=self.dialect,
             compile_kwargs={"literal_binds": True, "include_table": False},
index 711d7aba33f0136e7752b6311d366fa3883c88c7..8bedcd87825ce4089fb76866d9b84fa7090babc1 100644 (file)
@@ -899,7 +899,7 @@ class CreateIndexOp(MigrateOperation):
         return cls(
             index.name,  # type: ignore[arg-type]
             index.table.name,
-            sqla_compat._get_index_expressions(index),
+            index.expressions,
             schema=index.table.schema,
             unique=index.unique,
             **index.kwargs,
index c06349957649613bf5a57c0514033d7d4809fe69..204cc4ddc15b1457cdbacb2c238a625e19c49100 100644 (file)
@@ -1,6 +1,7 @@
 from itertools import zip_longest
 
 from sqlalchemy import schema
+from sqlalchemy.sql.elements import ClauseList
 
 
 class CompareTable:
@@ -60,6 +61,14 @@ class CompareIndex:
     def __ne__(self, other):
         return not self.__eq__(other)
 
+    def __repr__(self):
+        expr = ClauseList(*self.index.expressions)
+        try:
+            expr_str = expr.compile().string
+        except Exception:
+            expr_str = str(expr)
+        return f"<CompareIndex {self.index.name}({expr_str})>"
+
 
 class CompareCheckConstraint:
     def __init__(self, constraint):
index 3f175cf5747f90ca5d09f12396c99c0bce186300..9332a062563c89a634257966f3427d372f5f22e4 100644 (file)
@@ -524,14 +524,6 @@ def _render_literal_bindparam(
     return compiler.render_literal_bindparam(element, **kw)
 
 
-def _get_index_expressions(idx):
-    return list(idx.expressions)
-
-
-def _get_index_column_names(idx):
-    return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]
-
-
 def _column_kwargs(col: Column) -> Mapping:
     if sqla_13:
         return col.kwargs
@@ -630,10 +622,15 @@ else:
 
 
 def is_expression_index(index: Index) -> bool:
-    expr: Any
     for expr in index.expressions:
-        while isinstance(expr, UnaryExpression):
-            expr = expr.element
-        if not isinstance(expr, ColumnClause) or expr.is_literal:
+        if is_expression(expr):
             return True
     return False
+
+
+def is_expression(expr: Any) -> bool:
+    while isinstance(expr, UnaryExpression):
+        expr = expr.element
+    if not isinstance(expr, ColumnClause) or expr.is_literal:
+        return True
+    return False
diff --git a/docs/build/unreleased/more_index_fixes.rst b/docs/build/unreleased/more_index_fixes.rst
new file mode 100644 (file)
index 0000000..4645c9c
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 1321, 1327, 1356
+
+    Additional fixes to PostgreSQL expression index compare feature.
+    The compare now correctly accommodates casts and differences in
+    spacing.
+    Added detection logic for operation clauses inside the expression,
+    skipping the compare of these expressions.
+    To accommodate these changes the logic for the comparison of the
+    indexes and unique constraints was moved to the dialect
+    implementation, allowing greater flexibility.
index e3e67ea971c014f476be3f78f72eae86e82f5198..b90b6ade05220cafd4e7191ed02934b858f93775 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -110,9 +110,10 @@ default=sqlite:///:memory:
 sqlite=sqlite:///:memory:
 sqlite_file=sqlite:///querytest.db
 postgresql=postgresql://scott:tiger@127.0.0.1:5432/test
+psycopg=postgresql+psycopg://scott:tiger@127.0.0.1:5432/test
 mysql=mysql://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
-mariadb = mariadb://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
-mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server
+mariadb=mariadb://scott:tiger@127.0.0.1:3306/test?charset=utf8mb4
+mssql=mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes
 oracle=oracle://scott:tiger@127.0.0.1:1521
 oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
 
index 01a19d59c6731263f55a740f3ab1987ef812d9a9..b06e7c90c2eb470ca1debe3416b878647c53f46b 100644 (file)
@@ -1346,105 +1346,139 @@ class AutogenerateIndexTest(AutogenFixtureTest, TestBase):
 def _lots_of_indexes(flatten: bool = False):
     diff_pairs = [
         (
-            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", func.lower(t.c.x)),
+            lambda t: Index("idx", "y", func.lower(t.c.x)),
+            lambda t: Index("idx", func.lower(t.c.x)),
         ),
         (
-            lambda CapT: Index("SomeIndex", "y", func.lower(CapT.c.XCol)),
-            lambda CapT: Index("SomeIndex", func.lower(CapT.c.XCol)),
+            lambda CapT: Index("idx", "y", func.lower(CapT.c.XCol)),
+            lambda CapT: Index("idx", func.lower(CapT.c.XCol)),
         ),
         (
-            lambda t: Index(
-                "SomeIndex", "y", func.lower(column("x")), _table=t
-            ),
-            lambda t: Index("SomeIndex", func.lower(column("x")), _table=t),
+            lambda t: Index("idx", "y", func.lower(column("x")), _table=t),
+            lambda t: Index("idx", func.lower(column("x")), _table=t),
+        ),
+        (
+            lambda t: Index("idx", t.c.y, func.lower(t.c.x)),
+            lambda t: Index("idx", func.lower(t.c.x), t.c.y),
         ),
         (
-            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.y),
+            lambda t: Index("idx", t.c.y, func.lower(t.c.x)),
+            lambda t: Index("idx", t.c.y, func.lower(t.c.q)),
         ),
         (
-            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.q)),
+            lambda t: Index("idx", t.c.z, func.lower(t.c.x)),
+            lambda t: Index("idx", t.c.y, func.lower(t.c.x)),
         ),
         (
-            lambda t: Index("SomeIndex", t.c.z, func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+            lambda t: Index("idx", func.lower(t.c.x)),
+            lambda t: Index("idx", t.c.y, func.lower(t.c.x)),
         ),
         (
-            lambda t: Index("SomeIndex", func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+            lambda t: Index("idx", t.c.y, func.upper(t.c.x)),
+            lambda t: Index("idx", t.c.y, func.lower(t.c.x)),
         ),
         (
-            lambda t: Index("SomeIndex", t.c.y, func.upper(t.c.x)),
-            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+            lambda t: Index("idx", t.c.y, t.c.ff + 1),
+            lambda t: Index("idx", t.c.y, t.c.ff + 3),
         ),
         (
-            lambda t: Index("SomeIndex", t.c.y, t.c.ff + 1),
-            lambda t: Index("SomeIndex", t.c.y, t.c.ff + 3),
+            lambda t: Index("idx", t.c.y, func.lower(t.c.x)),
+            lambda t: Index("idx", t.c.y, func.lower(t.c.x + t.c.q)),
         ),
         (
-            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x + t.c.q)),
+            lambda t: Index("idx", t.c.y, t.c.z + 3),
+            lambda t: Index("idx", t.c.y, t.c.z * 3),
         ),
         (
-            lambda t: Index("SomeIndex", t.c.y, t.c.z + 3),
-            lambda t: Index("SomeIndex", t.c.y, t.c.z * 3),
+            lambda t: Index("idx", func.lower(t.c.x), t.c.q + "42"),
+            lambda t: Index("idx", func.lower(t.c.q), t.c.x + "42"),
         ),
         (
-            lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.q + "42"),
-            lambda t: Index("SomeIndex", func.lower(t.c.q), t.c.x + "42"),
+            lambda t: Index("idx", func.lower(t.c.x), t.c.z + 42),
+            lambda t: Index("idx", t.c.z + 42, func.lower(t.c.q)),
         ),
         (
-            lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.z + 42),
-            lambda t: Index("SomeIndex", t.c.z + 42, func.lower(t.c.q)),
+            lambda t: Index("idx", t.c.ff + 42),
+            lambda t: Index("idx", 42 + t.c.ff),
         ),
         (
-            lambda t: Index("SomeIndex", t.c.ff + 42),
-            lambda t: Index("SomeIndex", 42 + t.c.ff),
+            lambda t: Index("idx", text("coalesce(z, -1)"), _table=t),
+            lambda t: Index("idx", text("coalesce(q, '-1')"), _table=t),
         ),
         (
-            lambda t: Index("SomeIndex", text("coalesce(z, -1)"), _table=t),
-            lambda t: Index("SomeIndex", text("coalesce(q, '-1')"), _table=t),
+            lambda t: Index("idx", t.c.y.cast(Integer)),
+            lambda t: Index("idx", t.c.x.cast(Integer)),
         ),
     ]
 
     with_sort = [
         (
-            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", "y", desc(func.lower(t.c.x))),
+            lambda t: Index("idx", "y", func.lower(t.c.x)),
+            lambda t: Index("idx", "y", desc(func.lower(t.c.x))),
+        ),
+        (
+            lambda t: Index("idx", "y", func.lower(t.c.x)),
+            lambda t: Index("idx", desc("y"), func.lower(t.c.x)),
         ),
         (
-            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", desc("y"), func.lower(t.c.x)),
+            lambda t: Index("idx", "y", func.lower(t.c.x)),
+            lambda t: Index("idx", "y", nullsfirst(func.lower(t.c.x))),
         ),
         (
-            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", "y", nullsfirst(func.lower(t.c.x))),
+            lambda t: Index("idx", "y", func.lower(t.c.x)),
+            lambda t: Index("idx", nullsfirst("y"), func.lower(t.c.x)),
         ),
         (
-            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", nullsfirst("y"), func.lower(t.c.x)),
+            lambda t: Index("idx", asc(func.lower(t.c.x))),
+            lambda t: Index("idx", desc(func.lower(t.c.x))),
         ),
         (
-            lambda t: Index("SomeIndex", asc(func.lower(t.c.x))),
-            lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+            lambda t: Index("idx", desc(func.lower(t.c.x))),
+            lambda t: Index("idx", asc(func.lower(t.c.x))),
         ),
         (
-            lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
-            lambda t: Index("SomeIndex", asc(func.lower(t.c.x))),
+            lambda t: Index("idx", nullslast(asc(func.lower(t.c.x)))),
+            lambda t: Index("idx", nullslast(desc(func.lower(t.c.x)))),
         ),
         (
-            lambda t: Index("SomeIndex", nullslast(asc(func.lower(t.c.x)))),
-            lambda t: Index("SomeIndex", nullslast(desc(func.lower(t.c.x)))),
+            lambda t: Index("idx", nullslast(desc(func.lower(t.c.x)))),
+            lambda t: Index("idx", nullsfirst(desc(func.lower(t.c.x)))),
         ),
         (
-            lambda t: Index("SomeIndex", nullslast(desc(func.lower(t.c.x)))),
-            lambda t: Index("SomeIndex", nullsfirst(desc(func.lower(t.c.x)))),
+            lambda t: Index("idx", nullsfirst(func.lower(t.c.x))),
+            lambda t: Index("idx", desc(func.lower(t.c.x))),
+        ),
+        (
+            lambda t: Index(
+                "idx", text("x nulls first"), text("lower(y)"), _table=t
+            ),
+            lambda t: Index(
+                "idx", text("x nulls last"), text("lower(y)"), _table=t
+            ),
+        ),
+        (
+            lambda t: Index(
+                "idx", text("x nulls last"), text("lower(y)"), _table=t
+            ),
+            lambda t: Index(
+                "idx", text("x nulls first"), text("lower(y)"), _table=t
+            ),
+        ),
+        (
+            lambda t: Index(
+                "idx", text("x nulls first"), text("lower(y)"), _table=t
+            ),
+            lambda t: Index(
+                "idx", text("y nulls first"), text("lower(x)"), _table=t
+            ),
         ),
         (
-            lambda t: Index("SomeIndex", nullsfirst(func.lower(t.c.x))),
-            lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+            lambda t: Index(
+                "idx", text("x nulls last"), text("lower(y)"), _table=t
+            ),
+            lambda t: Index(
+                "idx", text("y nulls last"), text("lower(x)"), _table=t
+            ),
         ),
     ]
 
@@ -1464,32 +1498,50 @@ def _lost_of_equal_indexes(_lots_of_indexes):
         (fn, fn) if not isinstance(fn, tuple) else (fn[0], fn[0], fn[1])
         for fn in _lots_of_indexes(flatten=True)
     ]
+
     equal_pairs += [
         (
-            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", "y", asc(func.lower(t.c.x))),
+            lambda t: Index("idx", "y", func.lower(t.c.x)),
+            lambda t: Index("idx", "y", asc(func.lower(t.c.x))),
             config.requirements.reflects_indexes_column_sorting,
         ),
         (
-            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
-            lambda t: Index("SomeIndex", "y", nullslast(func.lower(t.c.x))),
+            lambda t: Index("idx", "y", func.lower(t.c.x)),
+            lambda t: Index("idx", "y", nullslast(func.lower(t.c.x))),
             config.requirements.reflects_indexes_column_sorting,
         ),
         (
-            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
-            lambda t: Index(
-                "SomeIndex", "y", nullslast(asc(func.lower(t.c.x)))
-            ),
+            lambda t: Index("idx", "y", func.lower(t.c.x)),
+            lambda t: Index("idx", "y", nullslast(asc(func.lower(t.c.x)))),
             config.requirements.reflects_indexes_column_sorting,
         ),
         (
-            lambda t: Index("SomeIndex", "y", desc(func.lower(t.c.x))),
-            lambda t: Index(
-                "SomeIndex", "y", nullsfirst(desc(func.lower(t.c.x)))
-            ),
+            lambda t: Index("idx", "y", desc(func.lower(t.c.x))),
+            lambda t: Index("idx", "y", nullsfirst(desc(func.lower(t.c.x)))),
             config.requirements.reflects_indexes_column_sorting,
         ),
     ]
+
+    # textual_sorting
+    equal_pairs += [
+        (
+            # use eval to avoid resolve lambda complaining about the closure
+            eval(f'lambda t: Index("idx", text({conn!r}), _table=t)'),
+            eval(f'lambda t: Index("idx", text({meta!r}), _table=t)'),
+            config.requirements.reflects_indexes_column_sorting,
+        )
+        for meta, conn in (
+            ("z nulls first", "z nulls first"),
+            ("z nulls last", "z"),
+            ("z asc", "z"),
+            ("z asc nulls first", "z nulls first"),
+            ("z asc nulls last", "z"),
+            ("z desc", "z desc"),
+            ("z desc nulls first", "z desc"),
+            ("z desc nulls last", "z desc nulls last"),
+        )
+    ]
+
     return equal_pairs
 
 
@@ -1595,9 +1647,8 @@ class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase):
     ):
         m1, m2, old_fixture_tables, new_fixture_tables = index_changed_tables
 
-        old, new = resolve_lambda(
-            old_fn, **old_fixture_tables
-        ), resolve_lambda(new_fn, **new_fixture_tables)
+        old = resolve_lambda(old_fn, **old_fixture_tables)
+        new = resolve_lambda(new_fn, **new_fixture_tables)
 
         if self.has_reflection:
             diffs = self._fixture(m1, m2)
@@ -1611,9 +1662,9 @@ class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase):
         else:
             with expect_warnings(
                 r"Skipped unsupported reflection of expression-based index "
-                r"SomeIndex",
+                r"idx",
                 r"autogenerate skipping metadata-specified expression-based "
-                r"index 'SomeIndex'; dialect '.*' under SQLAlchemy .* "
+                r"index 'idx'; dialect '.*' under SQLAlchemy .* "
                 r"can't reflect these "
                 r"indexes so they can't be compared",
             ):
index 5c5e4f3a4f7a9d269df6bf5c71f5c941f94c7f4f..07db25187cd281b32358612843cd964fcfa9d9d0 100644 (file)
@@ -1,3 +1,5 @@
+import itertools
+
 from sqlalchemy import BigInteger
 from sqlalchemy import Boolean
 from sqlalchemy import Column
@@ -33,6 +35,7 @@ from sqlalchemy.sql.expression import literal_column
 from alembic import autogenerate
 from alembic import command
 from alembic import op
+from alembic import testing
 from alembic import util
 from alembic.autogenerate import api
 from alembic.autogenerate.compare import _compare_server_default
@@ -47,6 +50,9 @@ from alembic.testing import config
 from alembic.testing import eq_
 from alembic.testing import eq_ignore_whitespace
 from alembic.testing import provide_metadata
+from alembic.testing import resolve_lambda
+from alembic.testing import schemacompare
+from alembic.testing.assertions import expect_warnings
 from alembic.testing.env import _no_sql_testing_config
 from alembic.testing.env import clear_staging_env
 from alembic.testing.env import staging_env
@@ -1492,6 +1498,217 @@ class PGUniqueIndexAutogenerateTest(AutogenFixtureTest, TestBase):
         eq_(len(diffs), 1)
 
 
+def _lots_of_indexes(flatten: bool = False):
+    diff_pairs = [
+        (
+            lambda t: Index("idx", t.c.jb["foo"]),
+            lambda t: Index("idx", t.c.jb["bar"]),
+        ),
+        (
+            lambda t: Index("idx", t.c.jb["foo"]),
+            lambda t: Index("idx", t.c.jb["not_jsonb_path_ops"]),
+        ),
+        (
+            lambda t: Index("idx", t.c.jb["not_jsonb_path_ops"]),
+            lambda t: Index("idx", t.c.jb["bar"]),
+        ),
+        (
+            lambda t: Index("idx", t.c.aa),
+            lambda t: Index("idx", t.c.not_jsonb_path_ops),
+        ),
+        (
+            lambda t: Index("idx", t.c.not_jsonb_path_ops),
+            lambda t: Index("idx", t.c.aa),
+        ),
+        (
+            lambda t: Index(
+                "idx",
+                t.c.jb["foo"].label("x"),
+                postgresql_using="gin",
+                postgresql_ops={"x": "jsonb_path_ops"},
+            ),
+            lambda t: Index(
+                "idx",
+                t.c.jb["bar"].label("x"),
+                postgresql_using="gin",
+                postgresql_ops={"x": "jsonb_path_ops"},
+            ),
+        ),
+        (
+            lambda t: Index("idx", t.c.jb["foo"].astext),
+            lambda t: Index("idx", t.c.jb["bar"].astext),
+        ),
+        (
+            lambda t: Index("idx", t.c.jb["foo"].as_integer()),
+            lambda t: Index("idx", t.c.jb["bar"].as_integer()),
+        ),
+        (
+            lambda t: Index("idx", text("(jb->'x')"), _table=t),
+            lambda t: Index("idx", text("(jb->'y')"), _table=t),
+        ),
+    ]
+    if flatten:
+        return list(itertools.chain.from_iterable(diff_pairs))
+    else:
+        return diff_pairs
+
+
def _equal_indexes():
    """Build (old, new) pairs of index factories that must compare equal.

    Covers every factory from ``_lots_of_indexes`` paired with itself,
    plus textual expressions whose spellings differ only in whitespace
    or cast syntax.
    """
    # any index compared against itself is trivially unchanged
    pairs = [(fn, fn) for fn in _lots_of_indexes(True)]
    pairs.extend(
        [
            # whitespace around the -> operator must not register as a change
            (
                lambda t: Index("idx", text("(jb->'x')"), _table=t),
                lambda t: Index("idx", text("(jb -> 'x')"), _table=t),
            ),
            # cast(... as integer) and ::integer are equivalent spellings
            (
                lambda t: Index("idx", text("cast(jb->'x' as integer)"), _table=t),
                lambda t: Index("idx", text("(jb -> 'x')::integer"), _table=t),
            ),
        ]
    )
    return pairs
+
+
+def _index_op_clause():
+    def make_idx(t, *expr):
+        return Index(
+            "idx",
+            *(text(e) if isinstance(e, str) else e for e in expr),
+            postgresql_using="gin",
+            _table=t,
+        )
+
+    return [
+        (
+            False,
+            lambda t: make_idx(t, "(jb->'x')jsonb_path_ops"),
+            lambda t: make_idx(t, "(jb->'x')jsonb_path_ops"),
+        ),
+        (
+            False,
+            lambda t: make_idx(t, "aa array_ops"),
+            lambda t: make_idx(t, "aa array_ops"),
+        ),
+        (
+            False,
+            lambda t: make_idx(t, "(jb->'x')jsonb_path_ops"),
+            lambda t: make_idx(t, "(jb->'y')jsonb_path_ops"),
+        ),
+        (
+            False,
+            lambda t: make_idx(t, "aa array_ops"),
+            lambda t: make_idx(t, "jb array_ops"),
+        ),
+        (
+            False,
+            lambda t: make_idx(t, "aa array_ops", "(jb->'y')jsonb_path_ops"),
+            lambda t: make_idx(t, "(jb->'y')jsonb_path_ops", "aa array_ops"),
+        ),
+        (
+            True,
+            lambda t: make_idx(t, "aa array_ops", text("(jb->'x')")),
+            lambda t: make_idx(t, "aa array_ops", text("(jb->'y')")),
+        ),
+        (
+            True,
+            lambda t: make_idx(t, text("(jb->'x')"), "aa array_ops"),
+            lambda t: make_idx(t, text("(jb->'y')"), "aa array_ops"),
+        ),
+        (
+            True,
+            lambda t: make_idx(t, "aa array_ops", text("(jb->'x')")),
+            lambda t: make_idx(t, "jb array_ops", text("(jb->'y')")),
+        ),
+        (
+            True,
+            lambda t: make_idx(t, text("(jb->'x')"), "aa array_ops"),
+            lambda t: make_idx(t, text("(jb->'y')"), "jb array_ops"),
+        ),
+    ]
+
+
class PGIndexAutogenerateTest(AutogenFixtureTest, TestBase):
    """Round-trip autogenerate comparison of PostgreSQL expression indexes.

    Exercises the expression-index compare feature: changed expressions
    produce a remove/add diff, equivalent spellings produce no diff, and
    expressions embedding an inline operator clause emit a warning.
    """

    __backend__ = True
    __only_on__ = "postgresql"
    __requires__ = ("reflect_indexes_with_expressions",)

    @testing.fixture
    def index_tables(self):
        """Return ``(m1, m2, t_old, t_new)``: two metadata collections
        holding structurally identical ``exp_index`` tables; each test
        attaches its indexes to the old/new table before comparing.
        """
        m1 = MetaData()
        m2 = MetaData()

        # NOTE: the column deliberately named "not_jsonb_path_ops" is used
        # by the index fixtures — presumably to guard against a column name
        # resembling an operator-class keyword; confirm against #1321/#1327.
        t_old = Table(
            "exp_index",
            m1,
            Column("id", Integer, primary_key=True),
            Column("aa", ARRAY(Integer)),
            Column("jb", JSONB),
            Column("not_jsonb_path_ops", Integer),
        )

        t_new = Table(
            "exp_index",
            m2,
            Column("id", Integer, primary_key=True),
            Column("aa", ARRAY(Integer)),
            Column("jb", JSONB),
            Column("not_jsonb_path_ops", Integer),
        )

        return m1, m2, t_old, t_new

    @combinations(*_lots_of_indexes(), argnames="old_fn, new_fn")
    def test_expression_indexes_changed(self, index_tables, old_fn, new_fn):
        """A changed index expression yields a remove_index/add_index pair."""
        m1, m2, old_table, new_table = index_tables

        # building the Index attaches it to its table as a side effect
        old = resolve_lambda(old_fn, t=old_table)
        new = resolve_lambda(new_fn, t=new_table)

        diffs = self._fixture(m1, m2)
        eq_(
            diffs,
            [
                # second CompareIndex arg: presumably compares the reflected
                # index by name only — TODO confirm against schemacompare
                ("remove_index", schemacompare.CompareIndex(old, True)),
                ("add_index", schemacompare.CompareIndex(new)),
            ],
        )

    @combinations(*_equal_indexes(), argnames="fn1, fn2")
    def test_expression_indexes_no_change(self, index_tables, fn1, fn2):
        """Equivalent index expression spellings produce no diff."""
        m1, m2, old_table, new_table = index_tables

        # return values unused: the indexes only need to be attached
        resolve_lambda(fn1, t=old_table)
        resolve_lambda(fn2, t=new_table)

        diffs = self._fixture(m1, m2)
        eq_(diffs, [])

    @combinations(*_index_op_clause(), argnames="changed, old_fn, new_fn")
    def test_expression_indexes_warn_operator(
        self, index_tables, changed, old_fn, new_fn
    ):
        """Expressions embedding an operator clause warn; a diff appears
        only when a comparable companion expression also changed."""
        m1, m2, old_table, new_table = index_tables

        old = old_fn(t=old_table)
        new = new_fn(t=new_table)

        with expect_warnings(
            r"Expression #\d .+ in index 'idx' detected to include "
            "an operator clause. Expression compare cannot proceed. "
            "Please move the operator clause to the "
        ):
            diffs = self._fixture(m1, m2)
        if changed:
            eq_(
                diffs,
                [
                    ("remove_index", schemacompare.CompareIndex(old, True)),
                    ("add_index", schemacompare.CompareIndex(new)),
                ],
            )
        else:
            eq_(diffs, [])
+
 case = combinations(
     ("nulls_not_distinct=False", False),
     ("nulls_not_distinct=True", True),