]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
PG nulls not distinct support
authorpavelserchenia <pavel.serchenia@splitmetrics.com>
Fri, 26 May 2023 11:16:54 +0000 (07:16 -0400)
committermike bayer <mike_mp@zzzcomputing.com>
Tue, 6 Jun 2023 17:58:32 +0000 (17:58 +0000)
Added support for PostgreSQL 15 ``NULLS NOT DISTINCT`` feature of
unique indexes and unique constraint using the dialect option
``postgresql_nulls_not_distinct``.
Updated the reflection logic to also correctly take this option
into account.

Fixes: #8240
Closes: #9834
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9834
Pull-request-sha: 825a4ff13a1f428470e184944a167c9d4c57e604

Change-Id: I6585fbb05ad32a131cf9ba25a59f7b229bce5b52

doc/build/changelog/unreleased_20/8240.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/pg_catalog.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/sql/compiler.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_20/8240.rst b/doc/build/changelog/unreleased_20/8240.rst
new file mode 100644 (file)
index 0000000..15e1191
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 8240
+
+    Added support for PostgreSQL 15 ``NULLS NOT DISTINCT`` feature of
+    unique indexes and unique constraint using the dialect option
+    ``postgresql_nulls_not_distinct``.
+    Updated the reflection logic to also correctly take this option
+    into account.
+    Pull request courtesy of Pavel Siarchenia.
index ebf2ce2d3e565da0fa89b38d38bd8dd58d0a08e3..99c3abe7a4158bf77858ba33ff798118fca640a7 100644 (file)
@@ -2688,8 +2688,9 @@ class MSDDLCompiler(compiler.DDLCompiler):
             formatted_name = self.preparer.format_constraint(constraint)
             if formatted_name is not None:
                 text += "CONSTRAINT %s " % formatted_name
-        text += "UNIQUE "
-
+        text += "UNIQUE %s" % self.define_unique_constraint_distinct(
+            constraint, **kw
+        )
         clustered = constraint.dialect_options["mssql"]["clustered"]
         if clustered is not None:
             if clustered:
index 8ac04d6548f2439a750833b93e5af20e519dd0d0..00443d79bae8eab96e6f7864c368ad0e59745e96 100644 (file)
@@ -2226,6 +2226,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
         text = "CREATE "
         if index.unique:
             text += "UNIQUE "
+
         text += "INDEX "
 
         if self.dialect._supports_create_index_concurrently:
@@ -2279,6 +2280,14 @@ class PGDDLCompiler(compiler.DDLCompiler):
                 [preparer.quote(c.name) for c in inclusions]
             )
 
+        nulls_not_distinct = index.dialect_options["postgresql"][
+            "nulls_not_distinct"
+        ]
+        if nulls_not_distinct is True:
+            text += " NULLS NOT DISTINCT"
+        elif nulls_not_distinct is False:
+            text += " NULLS DISTINCT"
+
         withclause = index.dialect_options["postgresql"]["with"]
         if withclause:
             text += " WITH (%s)" % (
@@ -2307,6 +2316,18 @@ class PGDDLCompiler(compiler.DDLCompiler):
 
         return text
 
+    def define_unique_constraint_distinct(self, constraint, **kw):
+        nulls_not_distinct = constraint.dialect_options["postgresql"][
+            "nulls_not_distinct"
+        ]
+        if nulls_not_distinct is True:
+            nulls_not_distinct_param = "NULLS NOT DISTINCT "
+        elif nulls_not_distinct is False:
+            nulls_not_distinct_param = "NULLS DISTINCT "
+        else:
+            nulls_not_distinct_param = ""
+        return nulls_not_distinct_param
+
     def visit_drop_index(self, drop, **kw):
         index = drop.element
 
@@ -2969,6 +2990,7 @@ class PGDialect(default.DefaultDialect):
                 "concurrently": False,
                 "with": {},
                 "tablespace": None,
+                "nulls_not_distinct": None,
             },
         ),
         (
@@ -2994,6 +3016,10 @@ class PGDialect(default.DefaultDialect):
                 "not_valid": False,
             },
         ),
+        (
+            schema.UniqueConstraint,
+            {"nulls_not_distinct": None},
+        ),
     ]
 
     reflection_options = ("postgresql_ignore_search_path",)
@@ -3747,12 +3773,13 @@ class PGDialect(default.DefaultDialect):
         result = connection.execute(oid_q, params)
         return result.all()
 
-    @util.memoized_property
-    def _constraint_query(self):
+    @lru_cache()
+    def _constraint_query(self, is_unique):
         con_sq = (
             select(
                 pg_catalog.pg_constraint.c.conrelid,
                 pg_catalog.pg_constraint.c.conname,
+                pg_catalog.pg_constraint.c.conindid,
                 sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label(
                     "attnum"
                 ),
@@ -3777,6 +3804,7 @@ class PGDialect(default.DefaultDialect):
             select(
                 con_sq.c.conrelid,
                 con_sq.c.conname,
+                con_sq.c.conindid,
                 con_sq.c.description,
                 con_sq.c.ord,
                 pg_catalog.pg_attribute.c.attname,
@@ -3789,10 +3817,19 @@ class PGDialect(default.DefaultDialect):
                     pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid,
                 ),
             )
+            .where(
+                # NOTE: restate the condition here, since pg15 otherwise
+                # seems to get confused on psycopg2 sometimes, doing
+                # a sequential scan of pg_attribute.
+                # The condition in the con_sq subquery is not actually needed
+                # in pg15, but it may be needed in older versions. Keeping it
+                # does not seem to have any impact in any case.
+                con_sq.c.conrelid.in_(bindparam("oids"))
+            )
             .subquery("attr")
         )
 
-        return (
+        constraint_query = (
             select(
                 attr_sq.c.conrelid,
                 sql.func.array_agg(
@@ -3809,34 +3846,63 @@ class PGDialect(default.DefaultDialect):
             .order_by(attr_sq.c.conrelid, attr_sq.c.conname)
         )
 
+        if is_unique:
+            if self.server_version_info >= (15,):
+                constraint_query = constraint_query.join(
+                    pg_catalog.pg_index,
+                    attr_sq.c.conindid == pg_catalog.pg_index.c.indexrelid,
+                ).add_columns(
+                    sql.func.bool_and(
+                        pg_catalog.pg_index.c.indnullsnotdistinct
+                    ).label("indnullsnotdistinct")
+                )
+            else:
+                constraint_query = constraint_query.add_columns(
+                    sql.false().label("indnullsnotdistinct")
+                )
+        else:
+            constraint_query = constraint_query.add_columns(
+                sql.null().label("extra")
+            )
+        return constraint_query
+
     def _reflect_constraint(
         self, connection, contype, schema, filter_names, scope, kind, **kw
     ):
+        # used to reflect primary and unique constraint
         table_oids = self._get_table_oids(
             connection, schema, filter_names, scope, kind, **kw
         )
         batches = list(table_oids)
+        is_unique = contype == "u"
 
         while batches:
             batch = batches[0:3000]
             batches[0:3000] = []
 
             result = connection.execute(
-                self._constraint_query,
+                self._constraint_query(is_unique),
                 {"oids": [r[0] for r in batch], "contype": contype},
             )
 
             result_by_oid = defaultdict(list)
-            for oid, cols, constraint_name, comment in result:
-                result_by_oid[oid].append((cols, constraint_name, comment))
+            for oid, cols, constraint_name, comment, extra in result:
+                result_by_oid[oid].append(
+                    (cols, constraint_name, comment, extra)
+                )
 
             for oid, tablename in batch:
                 for_oid = result_by_oid.get(oid, ())
                 if for_oid:
-                    for cols, constraint, comment in for_oid:
-                        yield tablename, cols, constraint, comment
+                    for cols, constraint, comment, extra in for_oid:
+                        if is_unique:
+                            yield tablename, cols, constraint, comment, {
+                                "nullsnotdistinct": extra
+                            }
+                        else:
+                            yield tablename, cols, constraint, comment, None
                 else:
-                    yield tablename, None, None, None
+                    yield tablename, None, None, None, None
 
     @reflection.cache
     def get_pk_constraint(self, connection, table_name, schema=None, **kw):
@@ -3871,7 +3937,7 @@ class PGDialect(default.DefaultDialect):
                 if pk_name is not None
                 else default(),
             )
-            for table_name, cols, pk_name, comment in result
+            for table_name, cols, pk_name, comment, _ in result
         )
 
     @reflection.cache
@@ -4151,6 +4217,11 @@ class PGDialect(default.DefaultDialect):
         else:
             indnkeyatts = sql.null().label("indnkeyatts")
 
+        if self.server_version_info >= (15,):
+            nulls_not_distinct = pg_catalog.pg_index.c.indnullsnotdistinct
+        else:
+            nulls_not_distinct = sql.false().label("indnullsnotdistinct")
+
         return (
             select(
                 pg_catalog.pg_index.c.indrelid,
@@ -4175,6 +4246,7 @@ class PGDialect(default.DefaultDialect):
                     else_=None,
                 ).label("filter_definition"),
                 indnkeyatts,
+                nulls_not_distinct,
                 cols_sq.c.elements,
                 cols_sq.c.elements_is_expr,
             )
@@ -4318,11 +4390,17 @@ class PGDialect(default.DefaultDialect):
                         dialect_options["postgresql_where"] = row[
                             "filter_definition"
                         ]
-                    if self.server_version_info >= (11, 0):
+                    if self.server_version_info >= (11,):
                         # NOTE: this is legacy, this is part of
                         # dialect_options now as of #7382
                         index["include_columns"] = inc_cols
                         dialect_options["postgresql_include"] = inc_cols
+                    if row["indnullsnotdistinct"]:
+                        # the default is False, so ignore it.
+                        dialect_options["postgresql_nulls_not_distinct"] = row[
+                            "indnullsnotdistinct"
+                        ]
+
                     if dialect_options:
                         index["dialect_options"] = dialect_options
 
@@ -4359,20 +4437,27 @@ class PGDialect(default.DefaultDialect):
         # each table can have multiple unique constraints
         uniques = defaultdict(list)
         default = ReflectionDefaults.unique_constraints
-        for table_name, cols, con_name, comment in result:
+        for table_name, cols, con_name, comment, options in result:
             # ensure a list is created for each table. leave it empty if
             # the table has no unique constraint
             if con_name is None:
                 uniques[(schema, table_name)] = default()
                 continue
 
-            uniques[(schema, table_name)].append(
-                {
-                    "column_names": cols,
-                    "name": con_name,
-                    "comment": comment,
-                }
-            )
+            uc_dict = {
+                "column_names": cols,
+                "name": con_name,
+                "comment": comment,
+            }
+            if options:
+                if options["nullsnotdistinct"]:
+                    uc_dict["dialect_options"] = {
+                        "postgresql_nulls_not_distinct": options[
+                            "nullsnotdistinct"
+                        ]
+                    }
+
+            uniques[(schema, table_name)].append(uc_dict)
         return uniques.items()
 
     @reflection.cache
index ed8926a26e87321122864474706758b189e0681b..fa4b30f03f41d90adf0a7780ada45ae7abf6be4f 100644 (file)
@@ -166,6 +166,7 @@ pg_index = Table(
     Column("indnatts", SmallInteger),
     Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}),
     Column("indisunique", Boolean),
+    Column("indnullsnotdistinct", Boolean, info={"server_version": (15,)}),
     Column("indisprimary", Boolean),
     Column("indisexclusion", Boolean, info={"server_version": (9, 1)}),
     Column("indimmediate", Boolean),
index 775abc4332044b0dd01e94ce9ae28e7f3ee922c0..4035901aed68bde0051371f095168379a26d3376 100644 (file)
@@ -1893,6 +1893,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
             columns = const_d["column_names"]
             comment = const_d.get("comment")
             duplicates = const_d.get("duplicates_index")
+            dialect_options = const_d.get("dialect_options", {})
             if include_columns and not set(columns).issubset(include_columns):
                 continue
             if duplicates:
@@ -1916,7 +1917,10 @@ class Inspector(inspection.Inspectable["Inspector"]):
                     constrained_cols.append(constrained_col)
             table.append_constraint(
                 sa_schema.UniqueConstraint(
-                    *constrained_cols, name=conname, comment=comment
+                    *constrained_cols,
+                    name=conname,
+                    comment=comment,
+                    **dialect_options,
                 )
             )
 
index f12de9763222a6f84e8fdd3a4843d122bf98f4ec..198bba3584fe2567b0ac45ed3aa9345eebe410a2 100644 (file)
@@ -6858,12 +6858,16 @@ class DDLCompiler(Compiled):
             formatted_name = self.preparer.format_constraint(constraint)
             if formatted_name is not None:
                 text += "CONSTRAINT %s " % formatted_name
-        text += "UNIQUE (%s)" % (
-            ", ".join(self.preparer.quote(c.name) for c in constraint)
+        text += "UNIQUE %s(%s)" % (
+            self.define_unique_constraint_distinct(constraint, **kw),
+            ", ".join(self.preparer.quote(c.name) for c in constraint),
         )
         text += self.define_constraint_deferrability(constraint)
         return text
 
+    def define_unique_constraint_distinct(self, constraint, **kw):
+        return ""
+
     def define_constraint_cascades(self, constraint):
         text = ""
         if constraint.ondelete is not None:
index 2f57c69315875d4af3b215ef4234ac637af90534..bd944849de07cd1424933d0fac81706178406c64 100644 (file)
@@ -676,6 +676,102 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=postgresql.dialect(),
         )
 
+    @testing.combinations(
+        (
+            lambda tbl: schema.CreateIndex(
+                Index(
+                    "test_idx1",
+                    tbl.c.data,
+                    unique=True,
+                    postgresql_nulls_not_distinct=True,
+                )
+            ),
+            "CREATE UNIQUE INDEX test_idx1 ON test_tbl "
+            "(data) NULLS NOT DISTINCT",
+        ),
+        (
+            lambda tbl: schema.CreateIndex(
+                Index(
+                    "test_idx2",
+                    tbl.c.data2,
+                    unique=True,
+                    postgresql_nulls_not_distinct=False,
+                )
+            ),
+            "CREATE UNIQUE INDEX test_idx2 ON test_tbl "
+            "(data2) NULLS DISTINCT",
+        ),
+        (
+            lambda tbl: schema.CreateIndex(
+                Index(
+                    "test_idx3",
+                    tbl.c.data3,
+                    unique=True,
+                )
+            ),
+            "CREATE UNIQUE INDEX test_idx3 ON test_tbl " "(data3)",
+        ),
+        (
+            lambda tbl: schema.CreateIndex(
+                Index(
+                    "test_idx3_complex",
+                    tbl.c.data3,
+                    postgresql_nulls_not_distinct=True,
+                    postgresql_include=["data2"],
+                    postgresql_where=and_(tbl.c.data3 > 5),
+                    postgresql_with={"fillfactor": 50},
+                )
+            ),
+            "CREATE INDEX test_idx3_complex ON test_tbl "
+            "(data3) INCLUDE (data2) NULLS NOT DISTINCT WITH "
+            "(fillfactor = 50) WHERE data3 > 5",
+        ),
+        (
+            lambda tbl: schema.AddConstraint(
+                schema.UniqueConstraint(
+                    tbl.c.data,
+                    name="uq_data1",
+                    postgresql_nulls_not_distinct=True,
+                )
+            ),
+            "ALTER TABLE test_tbl ADD CONSTRAINT uq_data1 UNIQUE "
+            "NULLS NOT DISTINCT (data)",
+        ),
+        (
+            lambda tbl: schema.AddConstraint(
+                schema.UniqueConstraint(
+                    tbl.c.data2,
+                    name="uq_data2",
+                    postgresql_nulls_not_distinct=False,
+                )
+            ),
+            "ALTER TABLE test_tbl ADD CONSTRAINT uq_data2 UNIQUE "
+            "NULLS DISTINCT (data2)",
+        ),
+        (
+            lambda tbl: schema.AddConstraint(
+                schema.UniqueConstraint(
+                    tbl.c.data3,
+                    name="uq_data3",
+                )
+            ),
+            "ALTER TABLE test_tbl ADD CONSTRAINT uq_data3 UNIQUE (data3)",
+        ),
+    )
+    def test_nulls_not_distinct(self, expr_fn, expected):
+        dd = PGDialect()
+        m = MetaData()
+        tbl = Table(
+            "test_tbl",
+            m,
+            Column("data", String),
+            Column("data2", Integer),
+            Column("data3", Integer),
+        )
+
+        expr = testing.resolve_lambda(expr_fn, tbl=tbl)
+        self.assert_compile(expr, expected, dialect=dd)
+
     def test_create_index_with_labeled_ops(self):
         m = MetaData()
         tbl = Table(
index 49838ec6abed77e532463508f5af0ba494e4fbbc..f7f86a79c33dd12765ebb6a5abe24945fa2c1c06 100644 (file)
@@ -1179,6 +1179,15 @@ class ReflectionTest(
                 where name != 'foo'
             """
         )
+        version = connection.dialect.server_version_info
+        if version >= (15,):
+            connection.exec_driver_sql(
+                """
+                create unique index zz_idx5 on party
+                    (name desc, upper(other))
+                    nulls not distinct
+                """
+            )
 
         expected = [
             {
@@ -1238,7 +1247,23 @@ class ReflectionTest(
                 "dialect_options": {"postgresql_include": []},
             },
         ]
-        if connection.dialect.server_version_info < (11,):
+        if version >= (15,):
+            expected.append(
+                {
+                    "name": "zz_idx5",
+                    "column_names": ["name", None],
+                    "expressions": ["name", "upper(other::text)"],
+                    "unique": True,
+                    "include_columns": [],
+                    "dialect_options": {
+                        "postgresql_include": [],
+                        "postgresql_nulls_not_distinct": True,
+                    },
+                    "column_sorting": {"name": ("desc",)},
+                },
+            )
+
+        if version < (11,):
             for index in expected:
                 index.pop("include_columns")
                 index["dialect_options"].pop("postgresql_include")
@@ -1462,6 +1487,72 @@ class ReflectionTest(
             "gin",
         )
 
+    @testing.skip_if("postgresql < 15.0", "nullsnotdistinct not supported")
+    def test_nullsnotdistinct(self, metadata, connection):
+        Table(
+            "t",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("x", ARRAY(Integer)),
+            Column("y", ARRAY(Integer)),
+            Index(
+                "idx1", "x", unique=True, postgresql_nulls_not_distinct=True
+            ),
+            UniqueConstraint(
+                "y", name="unq1", postgresql_nulls_not_distinct=True
+            ),
+        )
+        metadata.create_all(connection)
+
+        ind = inspect(connection).get_indexes("t", None)
+        expected_ind = [
+            {
+                "unique": True,
+                "column_names": ["x"],
+                "name": "idx1",
+                "dialect_options": {
+                    "postgresql_nulls_not_distinct": True,
+                    "postgresql_include": [],
+                },
+                "include_columns": [],
+            },
+            {
+                "unique": True,
+                "column_names": ["y"],
+                "name": "unq1",
+                "dialect_options": {
+                    "postgresql_nulls_not_distinct": True,
+                    "postgresql_include": [],
+                },
+                "include_columns": [],
+                "duplicates_constraint": "unq1",
+            },
+        ]
+        eq_(ind, expected_ind)
+
+        unq = inspect(connection).get_unique_constraints("t", None)
+        expected_unq = [
+            {
+                "column_names": ["y"],
+                "name": "unq1",
+                "dialect_options": {
+                    "postgresql_nulls_not_distinct": True,
+                },
+                "comment": None,
+            }
+        ]
+        eq_(unq, expected_unq)
+
+        m = MetaData()
+        t1 = Table("t", m, autoload_with=connection)
+        eq_(len(t1.indexes), 1)
+        idx_options = list(t1.indexes)[0].dialect_options["postgresql"]
+        eq_(idx_options["nulls_not_distinct"], True)
+
+        cst = {c.name: c for c in t1.constraints}
+        cst_options = cst["unq1"].dialect_options["postgresql"]
+        eq_(cst_options["nulls_not_distinct"], True)
+
     @testing.skip_if("postgresql < 11.0", "indnkeyatts not supported")
     def test_index_reflection_with_include(self, metadata, connection):
         """reflect indexes with include set"""