]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Nulls not distinct support in postgresql
authorFederico Caselli <cfederico87@gmail.com>
Thu, 8 Jun 2023 19:32:11 +0000 (15:32 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 20 Jun 2023 18:02:39 +0000 (20:02 +0200)
Added support in autogenerate for NULLS NOT DISTINCT in
the PostgreSQL dialect.

Closes: #1249
Pull-request: https://github.com/sqlalchemy/alembic/pull/1249
Pull-request-sha: e4a7ffed54677d5aba9ab0251026a8a2a0e71278

Change-Id: I299a24fa7af4ae9387d6b48ce49fb516dfb84518

alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/ddl/impl.py
alembic/ddl/postgresql.py
alembic/operations/base.py
docs/build/unreleased/1249.rst [new file with mode: 0644]
tests/requirements.py
tests/test_postgresql.py

index db32a6a4f368d9275b3394a5982ebd543aa19052..c441a200243c0b855b0a46887debb9620abd4d0c 100644 (file)
@@ -444,11 +444,11 @@ class _uq_constraint_sig(_constraint_sig):
     is_index = False
     is_unique = True
 
-    def __init__(self, const: UniqueConstraint) -> None:
+    def __init__(self, const: UniqueConstraint, impl: DefaultImpl) -> None:
         self.const = const
         self.name = const.name
-        self.sig = ("UNIQUE_CONSTRAINT",) + tuple(
-            sorted([col.name for col in const.columns])
+        self.sig = ("UNIQUE_CONSTRAINT",) + impl.create_unique_constraint_sig(
+            const
         )
 
     @property
@@ -616,6 +616,7 @@ def _compare_indexes_and_uniques(
     # 2a. if the dialect dupes unique indexes as unique constraints
     # (mysql and oracle), correct for that
 
+    impl = autogen_context.migration_context.impl
     if unique_constraints_duplicate_unique_indexes:
         _correct_for_uq_duplicates_uix(
             conn_uniques,
@@ -623,6 +624,7 @@ def _compare_indexes_and_uniques(
             metadata_unique_constraints,
             metadata_indexes,
             autogen_context.dialect,
+            impl,
         )
 
     # 3. give the dialect a chance to omit indexes and constraints that
@@ -640,15 +642,16 @@ def _compare_indexes_and_uniques(
     # Index and UniqueConstraint so we can easily work with them
     # interchangeably
     metadata_unique_constraints_sig = {
-        _uq_constraint_sig(uq) for uq in metadata_unique_constraints
+        _uq_constraint_sig(uq, impl) for uq in metadata_unique_constraints
     }
 
-    impl = autogen_context.migration_context.impl
     metadata_indexes_sig = {
         _ix_constraint_sig(ix, impl) for ix in metadata_indexes
     }
 
-    conn_unique_constraints = {_uq_constraint_sig(uq) for uq in conn_uniques}
+    conn_unique_constraints = {
+        _uq_constraint_sig(uq, impl) for uq in conn_uniques
+    }
 
     conn_indexes_sig = {_ix_constraint_sig(ix, impl) for ix in conn_indexes}
 
@@ -858,6 +861,7 @@ def _correct_for_uq_duplicates_uix(
     metadata_unique_constraints,
     metadata_indexes,
     dialect,
+    impl,
 ):
     # dedupe unique indexes vs. constraints, since MySQL / Oracle
     # doesn't really have unique constraints as a separate construct.
@@ -880,7 +884,7 @@ def _correct_for_uq_duplicates_uix(
     }
 
     unnamed_metadata_uqs = {
-        _uq_constraint_sig(cons).sig
+        _uq_constraint_sig(cons, impl).sig
         for name, cons in metadata_cons_names
         if name is None
     }
@@ -904,7 +908,7 @@ def _correct_for_uq_duplicates_uix(
     for overlap in uqs_dupe_indexes:
         if overlap not in metadata_uq_names:
             if (
-                _uq_constraint_sig(uqs_dupe_indexes[overlap]).sig
+                _uq_constraint_sig(uqs_dupe_indexes[overlap], impl).sig
                 not in unnamed_metadata_uqs
             ):
                 conn_unique_constraints.discard(uqs_dupe_indexes[overlap])
index 215af8ce5e53504c1d8f1d5cda0cdc5cae546bd6..3dfb5e9e3b3e4b68a35c74ed74d86dd8b200160d 100644 (file)
@@ -26,6 +26,7 @@ from ..util import sqla_compat
 if TYPE_CHECKING:
     from typing import Literal
 
+    from sqlalchemy.sql.base import DialectKWArgs
     from sqlalchemy.sql.elements import ColumnElement
     from sqlalchemy.sql.elements import TextClause
     from sqlalchemy.sql.schema import CheckConstraint
@@ -268,6 +269,15 @@ def _drop_table(autogen_context: AutogenContext, op: ops.DropTableOp) -> str:
     return text
 
 
+def _render_dialect_kwargs_items(
+    autogen_context: AutogenContext, item: DialectKWArgs
+) -> list[str]:
+    return [
+        f"{key}={_render_potential_expr(val, autogen_context)}"
+        for key, val in item.dialect_kwargs.items()
+    ]
+
+
 @renderers.dispatch_for(ops.CreateIndexOp)
 def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str:
     index = op.to_index()
@@ -286,6 +296,8 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str:
         )
 
     assert index.table is not None
+
+    opts = _render_dialect_kwargs_items(autogen_context, index)
     text = tmpl % {
         "prefix": _alembic_autogenerate_prefix(autogen_context),
         "name": _render_gen_name(autogen_context, index.name),
@@ -297,18 +309,7 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str:
         "schema": (", schema=%r" % _ident(index.table.schema))
         if index.table.schema
         else "",
-        "kwargs": (
-            ", "
-            + ", ".join(
-                [
-                    "%s=%s"
-                    % (key, _render_potential_expr(val, autogen_context))
-                    for key, val in index.kwargs.items()
-                ]
-            )
-        )
-        if len(index.kwargs)
-        else "",
+        "kwargs": ", " + ", ".join(opts) if opts else "",
     }
     return text
 
@@ -326,24 +327,13 @@ def _drop_index(autogen_context: AutogenContext, op: ops.DropIndexOp) -> str:
             "%(prefix)sdrop_index(%(name)r, "
             "table_name=%(table_name)r%(schema)s%(kwargs)s)"
         )
-
+    opts = _render_dialect_kwargs_items(autogen_context, index)
     text = tmpl % {
         "prefix": _alembic_autogenerate_prefix(autogen_context),
         "name": _render_gen_name(autogen_context, op.index_name),
         "table_name": _ident(op.table_name),
         "schema": ((", schema=%r" % _ident(op.schema)) if op.schema else ""),
-        "kwargs": (
-            ", "
-            + ", ".join(
-                [
-                    "%s=%s"
-                    % (key, _render_potential_expr(val, autogen_context))
-                    for key, val in index.kwargs.items()
-                ]
-            )
-        )
-        if len(index.kwargs)
-        else "",
+        "kwargs": ", " + ", ".join(opts) if opts else "",
     }
     return text
 
@@ -604,6 +594,7 @@ def _uq_constraint(
         opts.append(
             ("name", _render_gen_name(autogen_context, constraint.name))
         )
+    dialect_options = _render_dialect_kwargs_items(autogen_context, constraint)
 
     if alter:
         args = [repr(_render_gen_name(autogen_context, constraint.name))]
@@ -611,6 +602,7 @@ def _uq_constraint(
             args += [repr(_ident(constraint.table.name))]
         args.append(repr([_ident(col.name) for col in constraint.columns]))
         args.extend(["%s=%r" % (k, v) for k, v in opts])
+        args.extend(dialect_options)
         return "%(prefix)screate_unique_constraint(%(args)s)" % {
             "prefix": _alembic_autogenerate_prefix(autogen_context),
             "args": ", ".join(args),
@@ -618,6 +610,7 @@ def _uq_constraint(
     else:
         args = [repr(_ident(col.name)) for col in constraint.columns]
         args.extend(["%s=%r" % (k, v) for k, v in opts])
+        args.extend(dialect_options)
         return "%(prefix)sUniqueConstraint(%(args)s)" % {
             "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
             "args": ", ".join(args),
index 726f16867b8e14f2a33963524657e7f31937c043..31667ef8c694b68a91ddfbc7a782bad110f83142 100644 (file)
@@ -668,6 +668,12 @@ class DefaultImpl(metaclass=ImplMeta):
         # order of col matters in an index
         return tuple(col.name for col in index.columns)
 
+    def create_unique_constraint_sig(
+        self, const: UniqueConstraint
+    ) -> Tuple[Any, ...]:
+        # order of col does not matters in an unique constraint
+        return tuple(sorted([col.name for col in const.columns]))
+
     def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
         conn_indexes_by_name = {c.name: c for c in conn_indexes}
 
index c2d31062156423b1a71839b13003b1774f5fa331..afabd6c010bd11edca40d36b6165e0d6bbad3548 100644 (file)
@@ -12,7 +12,6 @@ from typing import TYPE_CHECKING
 from typing import Union
 
 from sqlalchemy import Column
-from sqlalchemy import Index
 from sqlalchemy import literal_column
 from sqlalchemy import Numeric
 from sqlalchemy import text
@@ -50,6 +49,8 @@ from ..util import sqla_compat
 if TYPE_CHECKING:
     from typing import Literal
 
+    from sqlalchemy import Index
+    from sqlalchemy import UniqueConstraint
     from sqlalchemy.dialects.postgresql.array import ARRAY
     from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
     from sqlalchemy.dialects.postgresql.hstore import HSTORE
@@ -305,6 +306,21 @@ class PostgresqlImpl(DefaultImpl):
                 break
         return to_remove
 
+    def _dialect_sig(
+        self, item: Union[Index, UniqueConstraint]
+    ) -> Tuple[Any, ...]:
+        if (
+            item.dialect_kwargs.get("postgresql_nulls_not_distinct")
+            is not None
+        ):
+            return (
+                (
+                    "nulls_not_distinct",
+                    item.dialect_kwargs["postgresql_nulls_not_distinct"],
+                ),
+            )
+        return ()
+
     def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
         return tuple(
             self._cleanup_index_expr(
@@ -316,7 +332,14 @@ class PostgresqlImpl(DefaultImpl):
                 ),
             )
             for e in index.expressions
-        )
+        ) + self._dialect_sig(index)
+
+    def create_unique_constraint_sig(
+        self, const: UniqueConstraint
+    ) -> Tuple[Any, ...]:
+        return tuple(
+            sorted([col.name for col in const.columns])
+        ) + self._dialect_sig(const)
 
     def _compile_element(self, element: ClauseElement) -> str:
         return element.compile(
index a2acafef4e06707c07d6b107805da699a56c0d4f..e2c1fd23061556fb71e51a2c52a0a505b58103e1 100644 (file)
@@ -86,7 +86,7 @@ class AbstractOperations(util.ModuleClsProxy):
     @classmethod
     def register_operation(
         cls, name: str, sourcename: Optional[str] = None
-    ) -> Callable[..., Any]:
+    ) -> Callable[[_T], _T]:
         """Register a new operation for this class.
 
         This method is normally used to add new operations
diff --git a/docs/build/unreleased/1249.rst b/docs/build/unreleased/1249.rst
new file mode 100644 (file)
index 0000000..b3740cb
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: usecase, autogenerate
+    :tickets: 1248
+
+    Added support in autogenerate for NULLS NOT DISTINCT in
+    the PostgreSQL dialect.
index dbbb88a5676037ad4ca03efda985b8ded267ec7a..d67a847987df45a0a55ab81ff674eb1fd760cbfe 100644 (file)
@@ -1,4 +1,5 @@
 from sqlalchemy import exc as sqla_exc
+from sqlalchemy import Index
 from sqlalchemy import text
 
 from alembic.testing import exclusions
@@ -430,3 +431,23 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def indexes_with_expressions(self):
         return exclusions.only_on(["postgresql", "sqlite>=3.9.0"])
+
+    @property
+    def nulls_not_distinct_sa(self):
+        def _has_nulls_not_distinct():
+            try:
+                Index("foo", "bar", postgresql_nulls_not_distinct=True)
+                return True
+            except sqla_exc.ArgumentError:
+                return False
+
+        return exclusions.only_if(
+            _has_nulls_not_distinct,
+            "sqlalchemy with nulls not distinct support needed",
+        )
+
+    @property
+    def nulls_not_distinct_db(self):
+        return self.nulls_not_distinct_sa + exclusions.only_on(
+            ["postgresql>=15"]
+        )
index 7b7afdc06e28aaebc028a5713036ddb786172841..8984437b713613ce10ee37f5a6c6b5226f27adee 100644 (file)
@@ -1258,6 +1258,45 @@ class PostgresqlAutogenRenderTest(TestBase):
             "postgresql.JSONB(astext_type=sa.Text())",
         )
 
+    @config.requirements.nulls_not_distinct_sa
+    def test_render_unique_nulls_not_distinct_constraint(self):
+        m = MetaData()
+        t = Table("tbl", m, Column("c", Integer))
+        uc = UniqueConstraint(
+            t.c.c,
+            name="uq_1",
+            deferrable="XYZ",
+            postgresql_nulls_not_distinct=True,
+        )
+        eq_ignore_whitespace(
+            autogenerate.render.render_op_text(
+                self.autogen_context,
+                ops.AddConstraintOp.from_constraint(uc),
+            ),
+            "op.create_unique_constraint('uq_1', 'tbl', ['c'], "
+            "deferrable='XYZ', postgresql_nulls_not_distinct=True)",
+        )
+        eq_ignore_whitespace(
+            autogenerate.render._render_unique_constraint(
+                uc, self.autogen_context, None
+            ),
+            "sa.UniqueConstraint('c', deferrable='XYZ', name='uq_1', "
+            "postgresql_nulls_not_distinct=True)",
+        )
+
+    @config.requirements.nulls_not_distinct_sa
+    def test_render_index_nulls_not_distinct_constraint(self):
+        m = MetaData()
+        t = Table("tbl", m, Column("c", Integer))
+        idx = Index("ix_42", t.c.c, postgresql_nulls_not_distinct=False)
+        eq_ignore_whitespace(
+            autogenerate.render.render_op_text(
+                self.autogen_context, ops.CreateIndexOp.from_index(idx)
+            ),
+            "op.create_index('ix_42', 'tbl', ['c'], unique=False, "
+            "postgresql_nulls_not_distinct=False)",
+        )
+
 
 class PGUniqueIndexAutogenerateTest(AutogenFixtureTest, TestBase):
     __only_on__ = "postgresql"
@@ -1394,3 +1433,103 @@ class PGUniqueIndexAutogenerateTest(AutogenFixtureTest, TestBase):
         eq_(diffs[0][0], "remove_constraint")
         eq_(diffs[0][1].name, "uq_name")
         eq_(len(diffs), 1)
+
+
+case = combinations(False, True, None, argnames="case", id_="s")
+name_type = combinations(
+    (
+        "index",
+        lambda value: Index(
+            "nnd_obj", "name", unique=True, postgresql_nulls_not_distinct=value
+        ),
+    ),
+    (
+        "constraint",
+        lambda value: UniqueConstraint(
+            "id", "name", name="nnd_obj", postgresql_nulls_not_distinct=value
+        ),
+    ),
+    argnames="name,type_",
+    id_="sa",
+)
+
+
+class PGNullsNotDistinctAutogenerateTest(AutogenFixtureTest, TestBase):
+    __requires__ = ("nulls_not_distinct_db",)
+    __only_on__ = "postgresql"
+    __backend__ = True
+
+    @case
+    @name_type
+    def test_add(self, case, name, type_):
+        m1 = MetaData()
+        m2 = MetaData()
+        Table(
+            "tbl",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+        )
+        Table(
+            "tbl",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+            type_(case),
+        )
+        diffs = self._fixture(m1, m2)
+        eq_(len(diffs), 1)
+        eq_(diffs[0][0], f"add_{name}")
+        added = diffs[0][1]
+        eq_(added.name, "nnd_obj")
+        eq_(added.dialect_kwargs["postgresql_nulls_not_distinct"], case)
+
+    @case
+    @name_type
+    def test_remove(self, case, name, type_):
+        m1 = MetaData()
+        m2 = MetaData()
+        Table(
+            "tbl",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+            type_(case),
+        )
+        Table(
+            "tbl",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+        )
+        diffs = self._fixture(m1, m2)
+        eq_(len(diffs), 1)
+        eq_(diffs[0][0], f"remove_{name}")
+        eq_(diffs[0][1].name, "nnd_obj")
+
+    @case
+    @name_type
+    def test_toggle_not_distinct(self, case, name, type_):
+        m1 = MetaData()
+        m2 = MetaData()
+        to = not case
+        Table(
+            "tbl",
+            m1,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+            type_(case),
+        )
+        Table(
+            "tbl",
+            m2,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+            type_(to),
+        )
+        diffs = self._fixture(m1, m2)
+        eq_(len(diffs), 2)
+        eq_(diffs[0][0], f"remove_{name}")
+        eq_(diffs[1][0], f"add_{name}")
+        eq_(diffs[1][1].name, "nnd_obj")
+        eq_(diffs[1][1].dialect_kwargs["postgresql_nulls_not_distinct"], to)