]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support postgresql_include in UniqueConstraint and PrimaryKeyConstraint
authorDenis Laxalde <denis@laxalde.org>
Fri, 28 Mar 2025 13:54:20 +0000 (14:54 +0100)
committerDenis Laxalde <denis@laxalde.org>
Tue, 1 Apr 2025 09:08:35 +0000 (11:08 +0200)
This is supported both for schema definition and reflection.

Fix #10665.

doc/build/changelog/unreleased_20/10665.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/reflection.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_20/10665.rst b/doc/build/changelog/unreleased_20/10665.rst
new file mode 100644 (file)
index 0000000..967dda1
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 10665
+
+    Added support for ``postgresql_include`` keyword argument to
+    :class:`_schema.UniqueConstraint` and :class:`_schema.PrimaryKeyConstraint`.
+    Pull request courtesy Denis Laxalde.
+
+    .. seealso::
+
+        :ref:`postgresql_constraint_options`
index b9bb796e2ad6842d2bd6e8dd26697d5d7aa4330c..9a85402652796917561423e12a7b8b1531eb5c81 100644 (file)
@@ -978,6 +978,8 @@ PostgreSQL-Specific Index Options
 Several extensions to the :class:`.Index` construct are available, specific
 to the PostgreSQL dialect.
 
+.. _postgresql_covering_indexes:
+
 Covering Indexes
 ^^^^^^^^^^^^^^^^
 
@@ -990,6 +992,10 @@ would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)``
 
 Note that this feature requires PostgreSQL 11 or later.
 
+.. seealso::
+
+  :ref:`postgresql_constraint_options`
+
 .. versionadded:: 1.4
 
 .. _postgresql_partial_indexes:
@@ -1258,6 +1264,42 @@ with selected constraint constructs:
       <https://www.postgresql.org/docs/current/static/sql-altertable.html>`_ -
       in the PostgreSQL documentation.
 
+* ``INCLUDE``:  This option adds one or more columns as a "payload" to the
+  unique index created automatically by PostgreSQL for the constraint.
+  For example, the following table definition::
+
+      Table(
+          "mytable",
+          metadata,
+          Column("id", Integer, nullable=False),
+          Column("value", Integer, nullable=False),
+          UniqueConstraint("id", postgresql_include=["value"]),
+      )
+
+  would produce the DDL statement
+
+  .. sourcecode:: sql
+
+       CREATE TABLE mytable (
+           id INTEGER NOT NULL,
+           value INTEGER NOT NULL,
+           UNIQUE (id) INCLUDE (value)
+        )
+
+  Note that this feature requires PostgreSQL 11 or later.
+
+  .. versionadded:: 2.0.41
+
+  .. seealso::
+
+      :ref:`postgresql_covering_indexes`
+
+  .. seealso::
+
+      `PostgreSQL CREATE TABLE options
+      <https://www.postgresql.org/docs/current/static/sql-createtable.html>`_ -
+      in the PostgreSQL documentation.
+
 * Column list with foreign key ``ON DELETE SET`` actions:  This applies to
   :class:`.ForeignKey` and :class:`.ForeignKeyConstraint`, the :paramref:`.ForeignKey.ondelete`
   parameter will accept on the PostgreSQL backend only a string list of column
@@ -2263,6 +2305,18 @@ class PGDDLCompiler(compiler.DDLCompiler):
         not_valid = constraint.dialect_options["postgresql"]["not_valid"]
         return " NOT VALID" if not_valid else ""
 
+    def _define_include(self, obj):
+        includeclause = obj.dialect_options["postgresql"]["include"]
+        if not includeclause:
+            return ""
+        inclusions = [
+            obj.table.c[col] if isinstance(col, str) else col
+            for col in includeclause
+        ]
+        return " INCLUDE (%s)" % ", ".join(
+            [self.preparer.quote(c.name) for c in inclusions]
+        )
+
     def visit_check_constraint(self, constraint, **kw):
         if constraint._type_bound:
             typ = list(constraint.columns)[0].type
@@ -2286,6 +2340,16 @@ class PGDDLCompiler(compiler.DDLCompiler):
         text += self._define_constraint_validity(constraint)
         return text
 
+    def visit_primary_key_constraint(self, constraint, **kw):
+        text = super().visit_primary_key_constraint(constraint)
+        text += self._define_include(constraint)
+        return text
+
+    def visit_unique_constraint(self, constraint, **kw):
+        text = super().visit_unique_constraint(constraint)
+        text += self._define_include(constraint)
+        return text
+
     @util.memoized_property
     def _fk_ondelete_pattern(self):
         return re.compile(
@@ -2400,15 +2464,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
             )
         )
 
-        includeclause = index.dialect_options["postgresql"]["include"]
-        if includeclause:
-            inclusions = [
-                index.table.c[col] if isinstance(col, str) else col
-                for col in includeclause
-            ]
-            text += " INCLUDE (%s)" % ", ".join(
-                [preparer.quote(c.name) for c in inclusions]
-            )
+        text += self._define_include(index)
 
         nulls_not_distinct = index.dialect_options["postgresql"][
             "nulls_not_distinct"
@@ -3156,9 +3212,16 @@ class PGDialect(default.DefaultDialect):
                 "not_valid": False,
             },
         ),
+        (
+            schema.PrimaryKeyConstraint,
+            {"include": None},
+        ),
         (
             schema.UniqueConstraint,
-            {"nulls_not_distinct": None},
+            {
+                "include": None,
+                "nulls_not_distinct": None,
+            },
         ),
     ]
 
@@ -4054,6 +4117,9 @@ class PGDialect(default.DefaultDialect):
                     pg_catalog.pg_constraint.c.conkey, 1
                 ).label("ord"),
                 pg_catalog.pg_description.c.description,
+                pg_catalog.pg_get_constraintdef(
+                    pg_catalog.pg_constraint.c.oid, True
+                ).label("condef"),
             )
             .outerjoin(
                 pg_catalog.pg_description,
@@ -4074,6 +4140,7 @@ class PGDialect(default.DefaultDialect):
                 con_sq.c.conindid,
                 con_sq.c.description,
                 con_sq.c.ord,
+                con_sq.c.condef,
                 pg_catalog.pg_attribute.c.attname,
             )
             .select_from(pg_catalog.pg_attribute)
@@ -4108,8 +4175,9 @@ class PGDialect(default.DefaultDialect):
                 ).label("cols"),
                 attr_sq.c.conname,
                 sql.func.min(attr_sq.c.description).label("description"),
+                attr_sq.c.condef,
             )
-            .group_by(attr_sq.c.conrelid, attr_sq.c.conname)
+            .group_by(attr_sq.c.conrelid, attr_sq.c.conname, attr_sq.c.condef)
             .order_by(attr_sq.c.conrelid, attr_sq.c.conname)
         )
 
@@ -4133,6 +4201,10 @@ class PGDialect(default.DefaultDialect):
             )
         return constraint_query
 
+    @util.memoized_property
+    def _include_regex_pattern(self):
+        return re.compile(r"INCLUDE \((.+)\)")
+
     def _reflect_constraint(
         self, connection, contype, schema, filter_names, scope, kind, **kw
     ):
@@ -4143,6 +4215,8 @@ class PGDialect(default.DefaultDialect):
         batches = list(table_oids)
         is_unique = contype == "u"
 
+        INCLUDE_REGEX = self._include_regex_pattern
+
         while batches:
             batch = batches[0:3000]
             batches[0:3000] = []
@@ -4153,21 +4227,26 @@ class PGDialect(default.DefaultDialect):
             )
 
             result_by_oid = defaultdict(list)
-            for oid, cols, constraint_name, comment, extra in result:
+            for oid, cols, constraint_name, comment, condef, extra in result:
                 result_by_oid[oid].append(
-                    (cols, constraint_name, comment, extra)
+                    (cols, constraint_name, comment, condef, extra)
                 )
 
             for oid, tablename in batch:
                 for_oid = result_by_oid.get(oid, ())
                 if for_oid:
-                    for cols, constraint, comment, extra in for_oid:
+                    for cols, constraint, comment, condef, extra in for_oid:
+                        opts = {}
                         if is_unique:
-                            yield tablename, cols, constraint, comment, {
-                                "nullsnotdistinct": extra
-                            }
-                        else:
-                            yield tablename, cols, constraint, comment, None
+                            opts["nullsnotdistinct"] = extra
+                        m = INCLUDE_REGEX.search(condef)
+                        if m:
+                            opts["include"] = [
+                                v.strip() for v in m.group(1).split(", ")
+                            ]
+                        if not opts:
+                            opts = None
+                        yield tablename, cols, constraint, comment, opts
                 else:
                     yield tablename, None, None, None, None
 
@@ -4193,20 +4272,29 @@ class PGDialect(default.DefaultDialect):
         # only a single pk can be present for each table. Return an entry
         # even if a table has no primary key
         default = ReflectionDefaults.pk_constraint
+
+        def pk_constraint(pk_name, cols, comment, opts):
+            info = {
+                "constrained_columns": ([] if cols is None else cols),
+                "name": pk_name,
+                "comment": comment,
+            }
+            if opts and "include" in opts:
+                info["dialect_options"] = {
+                    "postgresql_include": opts["include"]
+                }
+            return info
+
         return (
             (
                 (schema, table_name),
                 (
-                    {
-                        "constrained_columns": [] if cols is None else cols,
-                        "name": pk_name,
-                        "comment": comment,
-                    }
+                    pk_constraint(pk_name, cols, comment, opts)
                     if pk_name is not None
                     else default()
                 ),
             )
-            for table_name, cols, pk_name, comment, _ in result
+            for table_name, cols, pk_name, comment, opts in result
         )
 
     @reflection.cache
@@ -4728,11 +4816,13 @@ class PGDialect(default.DefaultDialect):
             }
             if options:
                 if options["nullsnotdistinct"]:
-                    uc_dict["dialect_options"] = {
-                        "postgresql_nulls_not_distinct": options[
-                            "nullsnotdistinct"
-                        ]
-                    }
+                    uc_dict.setdefault("dialect_options", {})[
+                        "postgresql_nulls_not_distinct"
+                    ] = options["nullsnotdistinct"]
+                if "include" in options:
+                    uc_dict.setdefault("dialect_options", {})[
+                        "postgresql_include"
+                    ] = options["include"]
 
             uniques[(schema, table_name)].append(uc_dict)
         return uniques.items()
index 9b6835838577e35140604574ea178e938a1f03dc..658acdca690183faaf8fe46b065dc9f805228faf 100644 (file)
@@ -1712,9 +1712,12 @@ class Inspector(inspection.Inspectable["Inspector"]):
                 if pk in cols_by_orig_name and pk not in exclude_columns
             ]
 
-            # update pk constraint name and comment
+            # update pk constraint name, comment and dialect_kwargs
             table.primary_key.name = pk_cons.get("name")
             table.primary_key.comment = pk_cons.get("comment", None)
+            table.primary_key.dialect_kwargs.update(
+                pk_cons.get("dialect_options", {})
+            )
 
             # tell the PKConstraint to re-initialize
             # its column collection
index 370981e19db4f2f377e52120ce297e32fc91b65a..cc963528a95eb2901ae4bf01167fa47f74a5e082 100644 (file)
@@ -23,6 +23,7 @@ from sqlalchemy import Integer
 from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import null
+from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import schema
 from sqlalchemy import select
 from sqlalchemy import Sequence
@@ -796,6 +797,41 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         expr = testing.resolve_lambda(expr_fn, tbl=tbl)
         self.assert_compile(expr, expected, dialect=dd)
 
+    @testing.combinations(
+        (
+            lambda tbl: schema.AddConstraint(
+                UniqueConstraint(tbl.c.id, postgresql_include=[tbl.c.value])
+            ),
+            "ALTER TABLE foo ADD UNIQUE (id) INCLUDE (value)",
+        ),
+        (
+            lambda tbl: schema.AddConstraint(
+                PrimaryKeyConstraint(
+                    tbl.c.id, postgresql_include=[tbl.c.value, "misc"]
+                )
+            ),
+            "ALTER TABLE foo ADD PRIMARY KEY (id) INCLUDE (value, misc)",
+        ),
+        (
+            lambda tbl: schema.CreateIndex(
+                Index("idx", tbl.c.id, postgresql_include=[tbl.c.value])
+            ),
+            "CREATE INDEX idx ON foo (id) INCLUDE (value)",
+        ),
+    )
+    def test_include(self, expr_fn, expected):
+        dd = PGDialect()
+        m = MetaData()
+        tbl = Table(
+            "foo",
+            m,
+            Column("id", Integer, nullable=False),
+            Column("value", Integer, nullable=False),
+            Column("misc", String),
+        )
+        expr = testing.resolve_lambda(expr_fn, tbl=tbl)
+        self.assert_compile(expr, expected, dialect=dd)
+
     def test_create_index_with_labeled_ops(self):
         m = MetaData()
         tbl = Table(
index 20844a0eaea76629921253c7dc1b8e9058bf652c..4f609e160167bdd4713714114c33ad7db05fcc21 100644 (file)
@@ -2602,6 +2602,45 @@ class ReflectionTest(
             connection.execute(sa_ddl.DropConstraintComment(cst))
         all_none()
 
+    @testing.skip_if("postgresql < 11.0", "not supported")
+    def test_reflection_constraints_with_include(self, connection, metadata):
+        Table(
+            "foo",
+            metadata,
+            Column("id", Integer, nullable=False),
+            Column("value", Integer, nullable=False),
+            Column("misc", String),
+        )
+        metadata.create_all(connection)
+        connection.exec_driver_sql(
+            "ALTER TABLE foo ADD UNIQUE (id) INCLUDE (value)"
+        )
+        connection.exec_driver_sql(
+            "ALTER TABLE foo ADD PRIMARY KEY (id) INCLUDE (value, misc)"
+        )
+
+        unq = inspect(connection).get_unique_constraints("foo")
+        expected_unq = [
+            {
+                "column_names": ["id"],
+                "name": "foo_id_value_key",
+                "dialect_options": {
+                    "postgresql_include": ["value"],
+                },
+                "comment": None,
+            }
+        ]
+        eq_(unq, expected_unq)
+
+        pk = inspect(connection).get_pk_constraint("foo")
+        expected_pk = {
+            "comment": None,
+            "constrained_columns": ["id"],
+            "dialect_options": {"postgresql_include": ["value", "misc"]},
+            "name": "foo_pkey",
+        }
+        eq_(pk, expected_pk)
+
 
 class CustomTypeReflectionTest(fixtures.TestBase):
     class CustomType: