]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Use column sort in index compare on postgresql
authorCaselIT <cfederico87@gmail.com>
Thu, 6 Apr 2023 20:16:41 +0000 (22:16 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 10 Apr 2023 19:58:17 +0000 (21:58 +0200)
Added support for autogenerate comparison of indexes on PostgreSQL which
include SQL sort option, such as ``ASC`` or ``NULLS FIRST``.

Fixes: #1213
Change-Id: I3ddcb647928d948e41462b1c889b1cbb515ace4f

alembic/autogenerate/compare.py
alembic/ddl/postgresql.py
docs/build/unreleased/1213.rst [new file with mode: 0644]
tests/requirements.py
tests/test_autogen_indexes.py

index 4f5126f53d635db3e21e8d9e4b77446bfb6b2e95..85cb426ed1cf742909ba5039e0f84ce6c8d9203f 100644 (file)
@@ -8,6 +8,7 @@ from typing import cast
 from typing import Dict
 from typing import Iterator
 from typing import List
+from typing import Mapping
 from typing import Optional
 from typing import Set
 from typing import Tuple
@@ -19,6 +20,7 @@ from sqlalchemy import inspect
 from sqlalchemy import schema as sa_schema
 from sqlalchemy import text
 from sqlalchemy import types as sqltypes
+from sqlalchemy.sql import expression
 from sqlalchemy.util import OrderedSet
 
 from alembic.ddl.base import _fk_spec
@@ -278,15 +280,35 @@ def _compare_tables(
                 upgrade_ops.ops.append(modify_table_ops)
 
 
+_IndexColumnSortingOps: Mapping[str, Any] = util.immutabledict(
+    {
+        "asc": expression.asc,
+        "desc": expression.desc,
+        "nulls_first": expression.nullsfirst,
+        "nulls_last": expression.nullslast,
+        "nullsfirst": expression.nullsfirst,  # 1_3 name
+        "nullslast": expression.nullslast,  # 1_3 name
+    }
+)
+
+
 def _make_index(params: Dict[str, Any], conn_table: Table) -> Optional[Index]:
     exprs: list[Union[Column[Any], TextClause]] = []
+    sorting = params.get("column_sorting")
+
     for num, col_name in enumerate(params["column_names"]):
         item: Union[Column[Any], TextClause]
         if col_name is None:
             assert "expressions" in params
-            item = text(params["expressions"][num])
+            name = params["expressions"][num]
+            item = text(name)
         else:
+            name = col_name
             item = conn_table.c[col_name]
+        if sorting and name in sorting:
+            for operator in sorting[name]:
+                if operator in _IndexColumnSortingOps:
+                    item = _IndexColumnSortingOps[operator](item)
         exprs.append(item)
     ix = sa_schema.Index(
         params["name"], *exprs, unique=params["unique"], _table=conn_table
index 4ffc2eb99bd1e437104a60c398757cf423aab205..247838bff28152e43d7ff15da856d4846ae10404 100644 (file)
@@ -21,8 +21,10 @@ from sqlalchemy.dialects.postgresql import BIGINT
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import INTEGER
 from sqlalchemy.schema import CreateIndex
+from sqlalchemy.sql import operators
 from sqlalchemy.sql.elements import ColumnClause
 from sqlalchemy.sql.elements import TextClause
+from sqlalchemy.sql.elements import UnaryExpression
 from sqlalchemy.types import NULLTYPE
 
 from .base import alter_column
@@ -53,6 +55,7 @@ if TYPE_CHECKING:
     from sqlalchemy.dialects.postgresql.json import JSON
     from sqlalchemy.dialects.postgresql.json import JSONB
     from sqlalchemy.sql.elements import BinaryExpression
+    from sqlalchemy.sql.elements import ClauseElement
     from sqlalchemy.sql.elements import quoted_name
     from sqlalchemy.sql.schema import MetaData
     from sqlalchemy.sql.schema import Table
@@ -248,11 +251,14 @@ class PostgresqlImpl(DefaultImpl):
         if not sqla_compat.sqla_2:
             self._skip_functional_indexes(metadata_indexes, conn_indexes)
 
-    def _cleanup_index_expr(self, index: Index, expr: str) -> str:
+    def _cleanup_index_expr(
+        self, index: Index, expr: str, remove_suffix: str
+    ) -> str:
         # start = expr
         expr = expr.lower()
         expr = expr.replace('"', "")
         if index.table is not None:
+            # should not be needed, since include_table=False is in compile
             expr = expr.replace(f"{index.table.name.lower()}.", "")
 
         while expr and expr[0] == "(" and expr[-1] == ")":
@@ -261,25 +267,64 @@ class PostgresqlImpl(DefaultImpl):
             # strip :: cast. types can have spaces in them
             expr = re.sub(r"(::[\w ]+\w)", "", expr)
 
+        if remove_suffix and expr.endswith(remove_suffix):
+            expr = expr[: -len(remove_suffix)]
+
         # print(f"START: {start} END: {expr}")
         return expr
 
+    def _default_modifiers(self, exp: ClauseElement) -> str:
+        to_remove = ""
+        while isinstance(exp, UnaryExpression):
+            if exp.modifier is None:
+                exp = exp.element
+            else:
+                op = exp.modifier
+                if isinstance(exp.element, UnaryExpression):
+                    inner_op = exp.element.modifier
+                else:
+                    inner_op = None
+                if inner_op is None:
+                    if op == operators.asc_op:
+                        # default is asc
+                        to_remove = " asc"
+                    elif op == operators.nullslast_op:
+                        # default is nulls last
+                        to_remove = " nulls last"
+                else:
+                    if (
+                        inner_op == operators.asc_op
+                        and op == operators.nullslast_op
+                    ):
+                        # default is asc nulls last
+                        to_remove = " asc nulls last"
+                    elif (
+                        inner_op == operators.desc_op
+                        and op == operators.nullsfirst_op
+                    ):
+                        # default for desc is nulls first
+                        to_remove = " nulls first"
+                break
+        return to_remove
+
     def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
-        if sqla_compat.is_expression_index(index):
-            return tuple(
-                self._cleanup_index_expr(
-                    index,
-                    e
+        return tuple(
+            self._cleanup_index_expr(
+                index,
+                *(
+                    (e, "")
                     if isinstance(e, str)
-                    else e.compile(
-                        dialect=self.dialect,
-                        compile_kwargs={"literal_binds": True},
-                    ).string,
-                )
-                for e in index.expressions
+                    else (self._compile_element(e), self._default_modifiers(e))
+                ),
             )
-        else:
-            return super().create_index_sig(index)
+            for e in index.expressions
+        )
+
+    def _compile_element(self, element: ClauseElement) -> str:
+        return element.compile(
+            dialect=self.dialect,
+            compile_kwargs={"literal_binds": True, "include_table": False},
+        ).string
 
     def render_type(
         self, type_: TypeEngine, autogen_context: AutogenContext
diff --git a/docs/build/unreleased/1213.rst b/docs/build/unreleased/1213.rst
new file mode 100644 (file)
index 0000000..29b88a4
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: postgresql, autogenerate
+    :tickets: 1213
+
+    Added support for autogenerate comparison of indexes on PostgreSQL which
+    include SQL sort option, such as ``ASC`` or ``NULLS FIRST``.
index 1a100ddb6159afc82166ab182a7b9df733061423..dbbb88a5676037ad4ca03efda985b8ded267ec7a 100644 (file)
@@ -138,8 +138,14 @@ class DefaultRequirements(SuiteRequirements):
     def reflects_indexes_w_sorting(self):
         # TODO: figure out what's happening on the SQLAlchemy side
         # when we reflect an index that has asc() / desc() on the column
+        # Tracked by https://github.com/sqlalchemy/sqlalchemy/issues/9597
         return exclusions.fails_on(["oracle"])
 
+    @property
+    def reflects_indexes_column_sorting(self):
+        "Actually reflect column_sorting on the indexes"
+        return exclusions.only_on(["postgresql"])
+
     @property
     def long_names(self):
         if sqla_compat.sqla_14:
index 30b7d9029ecc2ff3612ea2e1d5dedee804b2395e..f697e5a624edf4a25a590b424ee1f65439ae4ffc 100644 (file)
@@ -15,8 +15,11 @@ from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION
+from sqlalchemy.sql.expression import asc
 from sqlalchemy.sql.expression import column
 from sqlalchemy.sql.expression import desc
+from sqlalchemy.sql.expression import nullsfirst
+from sqlalchemy.sql.expression import nullslast
 
 from alembic import testing
 from alembic.testing import combinations
@@ -1130,11 +1133,6 @@ class AutogenerateIndexTest(AutogenFixtureTest, TestBase):
 
         eq_(diffs, [])
 
-    # fails in the 0.8 series where we have truncation rules,
-    # but no control over quoting. passes in 0.7.9 where we don't have
-    # truncation rules either.    dropping these ancient versions
-    # is long overdue.
-
     def test_unchanged_case_sensitive_implicit_idx(self):
         m1 = MetaData()
         m2 = MetaData()
@@ -1216,6 +1214,206 @@ class AutogenerateIndexTest(AutogenFixtureTest, TestBase):
             ],
         )
 
+    @config.requirements.reflects_indexes_column_sorting
+    @testing.combinations(
+        (desc, asc),
+        (asc, desc),
+        (desc, lambda x: nullslast(desc(x))),
+        (nullslast, nullsfirst),
+        (nullsfirst, nullslast),
+        (lambda x: nullslast(desc(x)), lambda x: nullsfirst(asc(x))),
+    )
+    def test_column_sort_changed(self, old_fn, new_fn):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        old = Index("SomeIndex", old_fn("y"))
+        Table("order_change", m1, Column("y", Integer), old)
+
+        new = Index("SomeIndex", new_fn("y"))
+        Table("order_change", m2, Column("y", Integer), new)
+        diffs = self._fixture(m1, m2)
+        eq_(
+            diffs,
+            [
+                (
+                    "remove_index",
+                    schemacompare.CompareIndex(old, name_only=True),
+                ),
+                ("add_index", schemacompare.CompareIndex(new, name_only=True)),
+            ],
+        )
+
+    @config.requirements.reflects_indexes_column_sorting
+    @testing.combinations(
+        (asc, asc),
+        (desc, desc),
+        (nullslast, nullslast),
+        (nullsfirst, nullsfirst),
+        (lambda x: x, asc),
+        (lambda x: x, nullslast),
+        (desc, lambda x: nullsfirst(desc(x))),
+        (lambda x: nullslast(asc(x)), lambda x: x),
+    )
+    def test_column_sort_not_changed(self, old_fn, new_fn):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        old = Index("SomeIndex", old_fn("y"))
+        Table("order_change", m1, Column("y", Integer), old)
+
+        new = Index("SomeIndex", new_fn("y"))
+        Table("order_change", m2, Column("y", Integer), new)
+        diffs = self._fixture(m1, m2)
+        eq_(diffs, [])
+
+
+def _lots_of_indexes(flatten: bool = False):
+    diff_pairs = [
+        (
+            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", func.lower(t.c.x)),
+        ),
+        (
+            lambda CapT: Index("SomeIndex", "y", func.lower(CapT.c.XCol)),
+            lambda CapT: Index("SomeIndex", func.lower(CapT.c.XCol)),
+        ),
+        (
+            lambda t: Index(
+                "SomeIndex", "y", func.lower(column("x")), _table=t
+            ),
+            lambda t: Index("SomeIndex", func.lower(column("x")), _table=t),
+        ),
+        (
+            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.y),
+        ),
+        (
+            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.q)),
+        ),
+        (
+            lambda t: Index("SomeIndex", t.c.z, func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+        ),
+        (
+            lambda t: Index("SomeIndex", func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+        ),
+        (
+            lambda t: Index("SomeIndex", t.c.y, func.upper(t.c.x)),
+            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+        ),
+        (
+            lambda t: Index("SomeIndex", t.c.y, t.c.ff + 1),
+            lambda t: Index("SomeIndex", t.c.y, t.c.ff + 3),
+        ),
+        (
+            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x + t.c.q)),
+        ),
+        (
+            lambda t: Index("SomeIndex", t.c.y, t.c.z + 3),
+            lambda t: Index("SomeIndex", t.c.y, t.c.z * 3),
+        ),
+        (
+            lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.q + "42"),
+            lambda t: Index("SomeIndex", func.lower(t.c.q), t.c.x + "42"),
+        ),
+        (
+            lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.z + 42),
+            lambda t: Index("SomeIndex", t.c.z + 42, func.lower(t.c.q)),
+        ),
+        (
+            lambda t: Index("SomeIndex", t.c.ff + 42),
+            lambda t: Index("SomeIndex", 42 + t.c.ff),
+        ),
+    ]
+
+    with_sort = [
+        (
+            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", "y", desc(func.lower(t.c.x))),
+        ),
+        (
+            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", desc("y"), func.lower(t.c.x)),
+        ),
+        (
+            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", "y", nullsfirst(func.lower(t.c.x))),
+        ),
+        (
+            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", nullsfirst("y"), func.lower(t.c.x)),
+        ),
+        (
+            lambda t: Index("SomeIndex", asc(func.lower(t.c.x))),
+            lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+        ),
+        (
+            lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+            lambda t: Index("SomeIndex", asc(func.lower(t.c.x))),
+        ),
+        (
+            lambda t: Index("SomeIndex", nullslast(asc(func.lower(t.c.x)))),
+            lambda t: Index("SomeIndex", nullslast(desc(func.lower(t.c.x)))),
+        ),
+        (
+            lambda t: Index("SomeIndex", nullslast(desc(func.lower(t.c.x)))),
+            lambda t: Index("SomeIndex", nullsfirst(desc(func.lower(t.c.x)))),
+        ),
+        (
+            lambda t: Index("SomeIndex", nullsfirst(func.lower(t.c.x))),
+            lambda t: Index("SomeIndex", desc(func.lower(t.c.x))),
+        ),
+    ]
+
+    req = config.requirements.reflects_indexes_column_sorting
+
+    if flatten:
+
+        flat = list(itertools.chain.from_iterable(diff_pairs))
+        for f1, f2 in with_sort:
+            flat.extend([(f1, req), (f2, req)])
+        return flat
+    else:
+        return diff_pairs + [(f1, f2, req) for f1, f2 in with_sort]
+
+
+def _lost_of_equal_indexes(_lots_of_indexes):
+    equal_pairs = [
+        (fn, fn) if not isinstance(fn, tuple) else (fn[0], fn[0], fn[1])
+        for fn in _lots_of_indexes(flatten=True)
+    ]
+    equal_pairs += [
+        (
+            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", "y", asc(func.lower(t.c.x))),
+            config.requirements.reflects_indexes_column_sorting,
+        ),
+        (
+            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+            lambda t: Index("SomeIndex", "y", nullslast(func.lower(t.c.x))),
+            config.requirements.reflects_indexes_column_sorting,
+        ),
+        (
+            lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
+            lambda t: Index(
+                "SomeIndex", "y", nullslast(asc(func.lower(t.c.x)))
+            ),
+            config.requirements.reflects_indexes_column_sorting,
+        ),
+        (
+            lambda t: Index("SomeIndex", "y", desc(func.lower(t.c.x))),
+            lambda t: Index(
+                "SomeIndex", "y", nullsfirst(desc(func.lower(t.c.x)))
+            ),
+            config.requirements.reflects_indexes_column_sorting,
+        ),
+    ]
+    return equal_pairs
+
 
 class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase):
     """tests involving indexes with expression"""
@@ -1263,74 +1461,6 @@ class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase):
                 diffs = self._fixture(m1, m2)
             eq_(diffs, [])
 
-    def _lots_of_indexes(flatten: bool = False):
-        diff_pairs = [
-            (
-                lambda t: Index("SomeIndex", "y", func.lower(t.c.x)),
-                lambda t: Index("SomeIndex", func.lower(t.c.x)),
-            ),
-            (
-                lambda CapT: Index("SomeIndex", "y", func.lower(CapT.c.XCol)),
-                lambda CapT: Index("SomeIndex", func.lower(CapT.c.XCol)),
-            ),
-            (
-                lambda t: Index(
-                    "SomeIndex", "y", func.lower(column("x")), _table=t
-                ),
-                lambda t: Index(
-                    "SomeIndex", func.lower(column("x")), _table=t
-                ),
-            ),
-            (
-                lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
-                lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.y),
-            ),
-            (
-                lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
-                lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.q)),
-            ),
-            (
-                lambda t: Index("SomeIndex", t.c.z, func.lower(t.c.x)),
-                lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
-            ),
-            (
-                lambda t: Index("SomeIndex", func.lower(t.c.x)),
-                lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
-            ),
-            (
-                lambda t: Index("SomeIndex", t.c.y, func.upper(t.c.x)),
-                lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
-            ),
-            (
-                lambda t: Index("SomeIndex", t.c.y, t.c.ff + 1),
-                lambda t: Index("SomeIndex", t.c.y, t.c.ff + 3),
-            ),
-            (
-                lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x)),
-                lambda t: Index("SomeIndex", t.c.y, func.lower(t.c.x + t.c.q)),
-            ),
-            (
-                lambda t: Index("SomeIndex", t.c.y, t.c.z + 3),
-                lambda t: Index("SomeIndex", t.c.y, t.c.z * 3),
-            ),
-            (
-                lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.q + "42"),
-                lambda t: Index("SomeIndex", func.lower(t.c.q), t.c.x + "42"),
-            ),
-            (
-                lambda t: Index("SomeIndex", func.lower(t.c.x), t.c.z + 42),
-                lambda t: Index("SomeIndex", t.c.z + 42, func.lower(t.c.q)),
-            ),
-            (
-                lambda t: Index("SomeIndex", t.c.ff + 42),
-                lambda t: Index("SomeIndex", 42 + t.c.ff),
-            ),
-        ]
-        if flatten:
-            return list(itertools.chain.from_iterable(diff_pairs))
-        else:
-            return diff_pairs
-
     @testing.fixture
     def index_changed_tables(self):
         m1 = MetaData()
@@ -1412,12 +1542,16 @@ class AutogenerateExpressionIndexTest(AutogenFixtureTest, TestBase):
                 diffs = self._fixture(m1, m2)
             eq_(diffs, [])
 
-    @combinations(*_lots_of_indexes(flatten=True), argnames="fn")
-    def test_expression_indexes_no_change(self, index_changed_tables, fn):
+    @combinations(
+        *_lost_of_equal_indexes(_lots_of_indexes), argnames="fn1, fn2"
+    )
+    def test_expression_indexes_no_change(
+        self, index_changed_tables, fn1, fn2
+    ):
         m1, m2, old_fixture_tables, new_fixture_tables = index_changed_tables
 
-        resolve_lambda(fn, **old_fixture_tables)
-        resolve_lambda(fn, **new_fixture_tables)
+        resolve_lambda(fn1, **old_fixture_tables)
+        resolve_lambda(fn2, **new_fixture_tables)
 
         if self.has_reflection:
             ctx = nullcontext()