Support postgresql_include in UniqueConstraint and PrimaryKeyConstraint
author     Denis Laxalde <denis@laxalde.org>
           Tue, 1 Apr 2025 17:30:48 +0000 (13:30 -0400)
committer  Federico Caselli <cfederico87@gmail.com>
           Tue, 1 Apr 2025 21:20:32 +0000 (23:20 +0200)
This is supported both for schema definition and reflection.

Fixes #10665.

Closes: #12485
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12485
Pull-request-sha: 1aabea7b55ece9fc0c6e069b777d4404ac01f964

Change-Id: I81d23966f84390dd1b03f0d13284ce6d883ee24e
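
A minimal usage sketch of what this change enables, covering both DDL emission and reflection; the connection URL, table, and column names here are illustrative assumptions, not part of the change:

    from sqlalchemy import Column, Integer, MetaData, Table, UniqueConstraint
    from sqlalchemy import create_engine, inspect

    # hypothetical URL; any PostgreSQL 11+ database works
    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

    metadata = MetaData()
    Table(
        "mytable",
        metadata,
        Column("id", Integer, nullable=False),
        Column("value", Integer, nullable=False),
        # renders: UNIQUE (id) INCLUDE (value)
        UniqueConstraint("id", postgresql_include=["value"]),
    )
    metadata.create_all(engine)

    # reflection reports the INCLUDE columns as a dialect option, e.g.
    # [{'column_names': ['id'], 'name': 'mytable_id_value_key',
    #   'dialect_options': {'postgresql_nulls_not_distinct': False,
    #                       'postgresql_include': ['value']},
    #   'comment': None}]
    print(inspect(engine).get_unique_constraints("mytable"))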

doc/build/changelog/unreleased_20/10665.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/reflection.py
lib/sqlalchemy/testing/suite/test_reflection.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_20/10665.rst b/doc/build/changelog/unreleased_20/10665.rst
new file mode 100644
index 0000000..967dda1
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/10665.rst
@@ -0,0 +1,11 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 10665
+
+    Added support for the ``postgresql_include`` keyword argument to
+    :class:`_schema.UniqueConstraint` and :class:`_schema.PrimaryKeyConstraint`.
+    Pull request courtesy Denis Laxalde.
+
+    .. seealso::
+
+        :ref:`postgresql_constraint_options`
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index b9bb796e2ad6842d2bd6e8dd26697d5d7aa4330c..53a477b1a6895f4c7d1a64d48c7c87f53a4a96de 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -978,6 +978,8 @@ PostgreSQL-Specific Index Options
 Several extensions to the :class:`.Index` construct are available, specific
 to the PostgreSQL dialect.
 
+.. _postgresql_covering_indexes:
+
 Covering Indexes
 ^^^^^^^^^^^^^^^^
 
@@ -990,6 +992,10 @@ would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)``
 
 Note that this feature requires PostgreSQL 11 or later.
 
+.. seealso::
+
+  :ref:`postgresql_constraint_options`
+
 .. versionadded:: 1.4
 
 .. _postgresql_partial_indexes:
@@ -1258,6 +1264,42 @@ with selected constraint constructs:
       <https://www.postgresql.org/docs/current/static/sql-altertable.html>`_ -
       in the PostgreSQL documentation.
 
+* ``INCLUDE``:  This option adds one or more columns as a "payload" to the
+  unique index created automatically by PostgreSQL for the constraint.
+  For example, the following table definition::
+
+      Table(
+          "mytable",
+          metadata,
+          Column("id", Integer, nullable=False),
+          Column("value", Integer, nullable=False),
+          UniqueConstraint("id", postgresql_include=["value"]),
+      )
+
+  would produce the DDL statement
+
+  .. sourcecode:: sql
+
+       CREATE TABLE mytable (
+           id INTEGER NOT NULL,
+           value INTEGER NOT NULL,
+           UNIQUE (id) INCLUDE (value)
+       )
+
+  Note that this feature requires PostgreSQL 11 or later.
+
+  .. versionadded:: 2.0.41
+
+  .. seealso::
+
+      :ref:`postgresql_covering_indexes`
+
+      `PostgreSQL CREATE TABLE options
+      <https://www.postgresql.org/docs/current/static/sql-createtable.html>`_ -
+      in the PostgreSQL documentation.
+
 * Column list with foreign key ``ON DELETE SET`` actions:  This applies to
   :class:`.ForeignKey` and :class:`.ForeignKeyConstraint`, the :paramref:`.ForeignKey.ondelete`
   parameter will accept on the PostgreSQL backend only a string list of column
@@ -2263,6 +2305,18 @@ class PGDDLCompiler(compiler.DDLCompiler):
         not_valid = constraint.dialect_options["postgresql"]["not_valid"]
         return " NOT VALID" if not_valid else ""
 
+    def _define_include(self, obj):
+        includeclause = obj.dialect_options["postgresql"]["include"]
+        if not includeclause:
+            return ""
+        inclusions = [
+            obj.table.c[col] if isinstance(col, str) else col
+            for col in includeclause
+        ]
+        return " INCLUDE (%s)" % ", ".join(
+            [self.preparer.quote(c.name) for c in inclusions]
+        )
+
     def visit_check_constraint(self, constraint, **kw):
         if constraint._type_bound:
             typ = list(constraint.columns)[0].type
@@ -2286,6 +2340,16 @@ class PGDDLCompiler(compiler.DDLCompiler):
         text += self._define_constraint_validity(constraint)
         return text
 
+    def visit_primary_key_constraint(self, constraint, **kw):
+        text = super().visit_primary_key_constraint(constraint)
+        text += self._define_include(constraint)
+        return text
+
+    def visit_unique_constraint(self, constraint, **kw):
+        text = super().visit_unique_constraint(constraint)
+        text += self._define_include(constraint)
+        return text
+
     @util.memoized_property
     def _fk_ondelete_pattern(self):
         return re.compile(
@@ -2400,15 +2464,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
             )
         )
 
-        includeclause = index.dialect_options["postgresql"]["include"]
-        if includeclause:
-            inclusions = [
-                index.table.c[col] if isinstance(col, str) else col
-                for col in includeclause
-            ]
-            text += " INCLUDE (%s)" % ", ".join(
-                [preparer.quote(c.name) for c in inclusions]
-            )
+        text += self._define_include(index)
 
         nulls_not_distinct = index.dialect_options["postgresql"][
             "nulls_not_distinct"
@@ -3156,9 +3212,16 @@ class PGDialect(default.DefaultDialect):
                 "not_valid": False,
             },
         ),
+        (
+            schema.PrimaryKeyConstraint,
+            {"include": None},
+        ),
         (
             schema.UniqueConstraint,
-            {"nulls_not_distinct": None},
+            {
+                "include": None,
+                "nulls_not_distinct": None,
+            },
         ),
     ]
 
@@ -4040,21 +4103,35 @@ class PGDialect(default.DefaultDialect):
         result = connection.execute(oid_q, params)
         return result.all()
 
-    @lru_cache()
-    def _constraint_query(self, is_unique):
+    @util.memoized_property
+    def _constraint_query(self):
+        if self.server_version_info >= (11, 0):
+            indnkeyatts = pg_catalog.pg_index.c.indnkeyatts
+        else:
+            indnkeyatts = sql.null().label("indnkeyatts")
+
+        if self.server_version_info >= (15,):
+            indnullsnotdistinct = pg_catalog.pg_index.c.indnullsnotdistinct
+        else:
+            indnullsnotdistinct = sql.false().label("indnullsnotdistinct")
+
         con_sq = (
             select(
                 pg_catalog.pg_constraint.c.conrelid,
                 pg_catalog.pg_constraint.c.conname,
-                pg_catalog.pg_constraint.c.conindid,
-                sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label(
-                    "attnum"
-                ),
+                sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"),
                 sql.func.generate_subscripts(
-                    pg_catalog.pg_constraint.c.conkey, 1
+                    pg_catalog.pg_index.c.indkey, 1
                 ).label("ord"),
+                indnkeyatts,
+                indnullsnotdistinct,
                 pg_catalog.pg_description.c.description,
             )
+            .join(
+                pg_catalog.pg_index,
+                pg_catalog.pg_constraint.c.conindid
+                == pg_catalog.pg_index.c.indexrelid,
+            )
             .outerjoin(
                 pg_catalog.pg_description,
                 pg_catalog.pg_description.c.objoid
@@ -4063,6 +4140,9 @@ class PGDialect(default.DefaultDialect):
             .where(
                 pg_catalog.pg_constraint.c.contype == bindparam("contype"),
                 pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")),
+                # NOTE: filtering also on pg_index.indrelid for oids does
+                # not seem to have a performance effect, but it may be an
+                # option if perf problems are reported
             )
             .subquery("con")
         )
@@ -4071,9 +4151,10 @@ class PGDialect(default.DefaultDialect):
             select(
                 con_sq.c.conrelid,
                 con_sq.c.conname,
-                con_sq.c.conindid,
                 con_sq.c.description,
                 con_sq.c.ord,
+                con_sq.c.indnkeyatts,
+                con_sq.c.indnullsnotdistinct,
                 pg_catalog.pg_attribute.c.attname,
             )
             .select_from(pg_catalog.pg_attribute)
@@ -4096,7 +4177,7 @@ class PGDialect(default.DefaultDialect):
             .subquery("attr")
         )
 
-        constraint_query = (
+        return (
             select(
                 attr_sq.c.conrelid,
                 sql.func.array_agg(
@@ -4108,31 +4189,15 @@ class PGDialect(default.DefaultDialect):
                 ).label("cols"),
                 attr_sq.c.conname,
                 sql.func.min(attr_sq.c.description).label("description"),
+                sql.func.min(attr_sq.c.indnkeyatts).label("indnkeyatts"),
+                sql.func.bool_and(attr_sq.c.indnullsnotdistinct).label(
+                    "indnullsnotdistinct"
+                ),
             )
             .group_by(attr_sq.c.conrelid, attr_sq.c.conname)
             .order_by(attr_sq.c.conrelid, attr_sq.c.conname)
         )
 
-        if is_unique:
-            if self.server_version_info >= (15,):
-                constraint_query = constraint_query.join(
-                    pg_catalog.pg_index,
-                    attr_sq.c.conindid == pg_catalog.pg_index.c.indexrelid,
-                ).add_columns(
-                    sql.func.bool_and(
-                        pg_catalog.pg_index.c.indnullsnotdistinct
-                    ).label("indnullsnotdistinct")
-                )
-            else:
-                constraint_query = constraint_query.add_columns(
-                    sql.false().label("indnullsnotdistinct")
-                )
-        else:
-            constraint_query = constraint_query.add_columns(
-                sql.null().label("extra")
-            )
-        return constraint_query
-
     def _reflect_constraint(
         self, connection, contype, schema, filter_names, scope, kind, **kw
     ):
@@ -4148,26 +4213,45 @@ class PGDialect(default.DefaultDialect):
             batches[0:3000] = []
 
             result = connection.execute(
-                self._constraint_query(is_unique),
+                self._constraint_query,
                 {"oids": [r[0] for r in batch], "contype": contype},
-            )
+            ).mappings()
 
             result_by_oid = defaultdict(list)
-            for oid, cols, constraint_name, comment, extra in result:
-                result_by_oid[oid].append(
-                    (cols, constraint_name, comment, extra)
-                )
+            for row_dict in result:
+                result_by_oid[row_dict["conrelid"]].append(row_dict)
 
             for oid, tablename in batch:
                 for_oid = result_by_oid.get(oid, ())
                 if for_oid:
-                    for cols, constraint, comment, extra in for_oid:
-                        if is_unique:
-                            yield tablename, cols, constraint, comment, {
-                                "nullsnotdistinct": extra
-                            }
+                    for row in for_oid:
+                        # See note in get_multi_indexes
+                        all_cols = row["cols"]
+                        indnkeyatts = row["indnkeyatts"]
+                        if (
+                            indnkeyatts is not None
+                            and len(all_cols) > indnkeyatts
+                        ):
+                            inc_cols = all_cols[indnkeyatts:]
+                            cst_cols = all_cols[:indnkeyatts]
                         else:
-                            yield tablename, cols, constraint, comment, None
+                            inc_cols = []
+                            cst_cols = all_cols
+
+                        opts = {}
+                        if self.server_version_info >= (11,):
+                            opts["postgresql_include"] = inc_cols
+                        if is_unique:
+                            opts["postgresql_nulls_not_distinct"] = row[
+                                "indnullsnotdistinct"
+                            ]
+                        yield (
+                            tablename,
+                            cst_cols,
+                            row["conname"],
+                            row["description"],
+                            opts,
+                        )
                 else:
                     yield tablename, None, None, None, None
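
As a small worked example of the column split above (the values are hypothetical): a constraint created as PRIMARY KEY (id) INCLUDE (value, misc) reports all three columns in index order, while pg_index.indnkeyatts counts only the key columns, so slicing recovers the constrained columns and the INCLUDE payload separately:

    # hypothetical reflected row for PRIMARY KEY (id) INCLUDE (value, misc)
    all_cols = ["id", "value", "misc"]
    indnkeyatts = 1  # key columns only, excluding INCLUDE columns

    cst_cols = all_cols[:indnkeyatts]  # ["id"] -> constrained_columns
    inc_cols = all_cols[indnkeyatts:]  # ["value", "misc"] -> postgresql_include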
 
@@ -4193,20 +4277,27 @@ class PGDialect(default.DefaultDialect):
         # only a single pk can be present for each table. Return an entry
         # even if a table has no primary key
         default = ReflectionDefaults.pk_constraint
+
+        def pk_constraint(pk_name, cols, comment, opts):
+            info = {
+                "constrained_columns": cols,
+                "name": pk_name,
+                "comment": comment,
+            }
+            if opts:
+                info["dialect_options"] = opts
+            return info
+
         return (
             (
                 (schema, table_name),
                 (
-                    {
-                        "constrained_columns": [] if cols is None else cols,
-                        "name": pk_name,
-                        "comment": comment,
-                    }
+                    pk_constraint(pk_name, cols, comment, opts)
                     if pk_name is not None
                     else default()
                 ),
             )
-            for table_name, cols, pk_name, comment, _ in result
+            for table_name, cols, pk_name, comment, opts in result
         )
 
     @reflection.cache
@@ -4597,7 +4688,10 @@ class PGDialect(default.DefaultDialect):
                     # "The number of key columns in the index, not counting any
                     # included columns, which are merely stored and do not
                     # participate in the index semantics"
-                    if indnkeyatts and len(all_elements) > indnkeyatts:
+                    if (
+                        indnkeyatts is not None
+                        and len(all_elements) > indnkeyatts
+                    ):
                         # this is a "covering index" which has INCLUDE columns
                         # as well as regular index columns
                         inc_cols = all_elements[indnkeyatts:]
@@ -4727,12 +4821,7 @@ class PGDialect(default.DefaultDialect):
                 "comment": comment,
             }
             if options:
-                if options["nullsnotdistinct"]:
-                    uc_dict["dialect_options"] = {
-                        "postgresql_nulls_not_distinct": options[
-                            "nullsnotdistinct"
-                        ]
-                    }
+                uc_dict["dialect_options"] = options
 
             uniques[(schema, table_name)].append(uc_dict)
         return uniques.items()
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
index 9b6835838577e35140604574ea178e938a1f03dc..d063cd7c9f3ae63f5cefef12987894220ee58a56 100644
--- a/lib/sqlalchemy/engine/reflection.py
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -1712,9 +1712,12 @@ class Inspector(inspection.Inspectable["Inspector"]):
                 if pk in cols_by_orig_name and pk not in exclude_columns
             ]
 
-            # update pk constraint name and comment
+            # update pk constraint name, comment and dialect_kwargs
             table.primary_key.name = pk_cons.get("name")
             table.primary_key.comment = pk_cons.get("comment", None)
+            dialect_options = pk_cons.get("dialect_options")
+            if dialect_options:
+                table.primary_key.dialect_kwargs.update(dialect_options)
 
             # tell the PKConstraint to re-initialize
             # its column collection
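
For illustration, a Table reflected with autoload_with now carries the INCLUDE columns through the primary key's dialect kwargs; this is a sketch, and the engine URL and table name are assumptions:

    from sqlalchemy import MetaData, Table, create_engine

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    reflected = Table("mytable", MetaData(), autoload_with=engine)

    # INCLUDE columns reflected from the index backing the primary key
    print(reflected.primary_key.dialect_kwargs.get("postgresql_include"))
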
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 6be86cde106e1f597cd4abc3f301df94f8ca34d9..faafe7dc578ac54a0cdd4f2320772291410fe678 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -1955,6 +1955,8 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest):
             if dupe:
                 names_that_duplicate_index.add(dupe)
             eq_(refl.pop("comment", None), None)
+            # ignore dialect_options
+            refl.pop("dialect_options", None)
             eq_(orig, refl)
 
         reflected_metadata = MetaData()
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 370981e19db4f2f377e52120ce297e32fc91b65a..eda9f96662e1d90d8f533c565901d3eb94501766 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -23,6 +23,7 @@ from sqlalchemy import Integer
 from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import null
+from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import schema
 from sqlalchemy import select
 from sqlalchemy import Sequence
@@ -796,6 +797,40 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         expr = testing.resolve_lambda(expr_fn, tbl=tbl)
         self.assert_compile(expr, expected, dialect=dd)
 
+    @testing.combinations(
+        (
+            lambda tbl: schema.AddConstraint(
+                UniqueConstraint(tbl.c.id, postgresql_include=[tbl.c.value])
+            ),
+            "ALTER TABLE foo ADD UNIQUE (id) INCLUDE (value)",
+        ),
+        (
+            lambda tbl: schema.AddConstraint(
+                PrimaryKeyConstraint(
+                    tbl.c.id, postgresql_include=[tbl.c.value, "misc"]
+                )
+            ),
+            "ALTER TABLE foo ADD PRIMARY KEY (id) INCLUDE (value, misc)",
+        ),
+        (
+            lambda tbl: schema.CreateIndex(
+                Index("idx", tbl.c.id, postgresql_include=[tbl.c.value])
+            ),
+            "CREATE INDEX idx ON foo (id) INCLUDE (value)",
+        ),
+    )
+    def test_include(self, expr_fn, expected):
+        m = MetaData()
+        tbl = Table(
+            "foo",
+            m,
+            Column("id", Integer, nullable=False),
+            Column("value", Integer, nullable=False),
+            Column("misc", String),
+        )
+        expr = testing.resolve_lambda(expr_fn, tbl=tbl)
+        self.assert_compile(expr, expected)
+
     def test_create_index_with_labeled_ops(self):
         m = MetaData()
         tbl = Table(
diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py
index 20844a0eaea76629921253c7dc1b8e9058bf652c..ebe751b5b348831ebba2fe992fbd6748ca8414d4 100644
--- a/test/dialect/postgresql/test_reflection.py
+++ b/test/dialect/postgresql/test_reflection.py
@@ -1770,6 +1770,7 @@ class ReflectionTest(
                 "column_names": ["y"],
                 "name": "unq1",
                 "dialect_options": {
+                    "postgresql_include": [],
                     "postgresql_nulls_not_distinct": True,
                 },
                 "comment": None,
@@ -2602,6 +2603,51 @@ class ReflectionTest(
             connection.execute(sa_ddl.DropConstraintComment(cst))
         all_none()
 
+    @testing.skip_if("postgresql < 11.0", "not supported")
+    def test_reflection_constraints_with_include(self, connection, metadata):
+        Table(
+            "foo",
+            metadata,
+            Column("id", Integer, nullable=False),
+            Column("value", Integer, nullable=False),
+            Column("foo", String),
+            Column("arr", ARRAY(Integer)),
+            Column("bar", SmallInteger),
+        )
+        metadata.create_all(connection)
+        connection.exec_driver_sql(
+            "ALTER TABLE foo ADD UNIQUE (id) INCLUDE (value)"
+        )
+        connection.exec_driver_sql(
+            "ALTER TABLE foo "
+            "ADD PRIMARY KEY (id) INCLUDE (arr, foo, bar, value)"
+        )
+
+        unq = inspect(connection).get_unique_constraints("foo")
+        expected_unq = [
+            {
+                "column_names": ["id"],
+                "name": "foo_id_value_key",
+                "dialect_options": {
+                    "postgresql_nulls_not_distinct": False,
+                    "postgresql_include": ["value"],
+                },
+                "comment": None,
+            }
+        ]
+        eq_(unq, expected_unq)
+
+        pk = inspect(connection).get_pk_constraint("foo")
+        expected_pk = {
+            "comment": None,
+            "constrained_columns": ["id"],
+            "dialect_options": {
+                "postgresql_include": ["arr", "foo", "bar", "value"]
+            },
+            "name": "foo_pkey",
+        }
+        eq_(pk, expected_pk)
+
 
 class CustomTypeReflectionTest(fixtures.TestBase):
     class CustomType: