From: Denis Laxalde
Date: Tue, 1 Apr 2025 17:30:48 +0000 (-0400)
Subject: Support postgresql_include in UniqueConstraint and PrimaryKeyConstraint
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3b7725dd1243134341cf1bfb331ed4501fc882e8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git

Support postgresql_include in UniqueConstraint and PrimaryKeyConstraint

This is supported both for schema definition and reflection.

Fixes #10665.

Closes: #12485
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12485
Pull-request-sha: 1aabea7b55ece9fc0c6e069b777d4404ac01f964

Change-Id: I81d23966f84390dd1b03f0d13284ce6d883ee24e
---

diff --git a/doc/build/changelog/unreleased_20/10665.rst b/doc/build/changelog/unreleased_20/10665.rst
new file mode 100644
index 0000000000..967dda14b1
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/10665.rst
@@ -0,0 +1,11 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 10665
+
+    Added support for ``postgresql_include`` keyword argument to
+    :class:`_schema.UniqueConstraint` and :class:`_schema.PrimaryKeyConstraint`.
+    Pull request courtesy Denis Laxalde.
+
+    .. seealso::
+
+        :ref:`postgresql_constraint_options`
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index b9bb796e2a..53a477b1a6 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -978,6 +978,8 @@ PostgreSQL-Specific Index Options
 Several extensions to the :class:`.Index` construct are available, specific
 to the PostgreSQL dialect.
 
+.. _postgresql_covering_indexes:
+
 Covering Indexes
 ^^^^^^^^^^^^^^^^
 
@@ -990,6 +992,10 @@ would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)``
 
 Note that this feature requires PostgreSQL 11 or later.
 
+.. seealso::
+
+    :ref:`postgresql_constraint_options`
+
 .. versionadded:: 1.4
 
 .. _postgresql_partial_indexes:
@@ -1258,6 +1264,42 @@ with selected constraint constructs:
   `_ -
   in the PostgreSQL documentation.
 
+* ``INCLUDE``: This option adds one or more columns as a "payload" to the
+  unique index created automatically by PostgreSQL for the constraint.
+  For example, the following table definition::
+
+      Table(
+          "mytable",
+          metadata,
+          Column("id", Integer, nullable=False),
+          Column("value", Integer, nullable=False),
+          UniqueConstraint("id", postgresql_include=["value"]),
+      )
+
+  would produce the DDL statement
+
+  .. sourcecode:: sql
+
+      CREATE TABLE mytable (
+          id INTEGER NOT NULL,
+          value INTEGER NOT NULL,
+          UNIQUE (id) INCLUDE (value)
+      )
+
+  Note that this feature requires PostgreSQL 11 or later.
+
+  .. versionadded:: 2.0.41
+
+  .. seealso::
+
+      :ref:`postgresql_covering_indexes`
+
+  .. seealso::
+
+      `PostgreSQL CREATE TABLE options
+      `_ -
+      in the PostgreSQL documentation.
+
 * Column list with foreign key ``ON DELETE SET`` actions: This applies to
   :class:`.ForeignKey` and :class:`.ForeignKeyConstraint`, the
   :paramref:`.ForeignKey.ondelete` parameter will accept on the PostgreSQL
   backend only a string list of column
@@ -2263,6 +2305,18 @@ class PGDDLCompiler(compiler.DDLCompiler):
         not_valid = constraint.dialect_options["postgresql"]["not_valid"]
         return " NOT VALID" if not_valid else ""
 
+    def _define_include(self, obj):
+        includeclause = obj.dialect_options["postgresql"]["include"]
+        if not includeclause:
+            return ""
+        inclusions = [
+            obj.table.c[col] if isinstance(col, str) else col
+            for col in includeclause
+        ]
+        return " INCLUDE (%s)" % ", ".join(
+            [self.preparer.quote(c.name) for c in inclusions]
+        )
+
     def visit_check_constraint(self, constraint, **kw):
         if constraint._type_bound:
             typ = list(constraint.columns)[0].type
@@ -2286,6 +2340,16 @@ class PGDDLCompiler(compiler.DDLCompiler):
         text += self._define_constraint_validity(constraint)
         return text
 
+    def visit_primary_key_constraint(self, constraint, **kw):
+        text = super().visit_primary_key_constraint(constraint)
+        text += self._define_include(constraint)
+        return text
+
+    def visit_unique_constraint(self, constraint, **kw):
+        text = super().visit_unique_constraint(constraint)
+        text += self._define_include(constraint)
+        return text
+
     @util.memoized_property
     def _fk_ondelete_pattern(self):
         return re.compile(
@@ -2400,15 +2464,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
                 )
             )
 
-        includeclause = index.dialect_options["postgresql"]["include"]
-        if includeclause:
-            inclusions = [
-                index.table.c[col] if isinstance(col, str) else col
-                for col in includeclause
-            ]
-            text += " INCLUDE (%s)" % ", ".join(
-                [preparer.quote(c.name) for c in inclusions]
-            )
+        text += self._define_include(index)
 
         nulls_not_distinct = index.dialect_options["postgresql"][
             "nulls_not_distinct"
         ]
@@ -3156,9 +3212,16 @@ class PGDialect(default.DefaultDialect):
                 "not_valid": False,
             },
         ),
+        (
+            schema.PrimaryKeyConstraint,
+            {"include": None},
+        ),
         (
             schema.UniqueConstraint,
-            {"nulls_not_distinct": None},
+            {
+                "include": None,
+                "nulls_not_distinct": None,
+            },
         ),
     ]
 
@@ -4040,21 +4103,35 @@ class PGDialect(default.DefaultDialect):
         result = connection.execute(oid_q, params)
         return result.all()
 
-    @lru_cache()
-    def _constraint_query(self, is_unique):
+    @util.memoized_property
+    def _constraint_query(self):
+        if self.server_version_info >= (11, 0):
+            indnkeyatts = pg_catalog.pg_index.c.indnkeyatts
+        else:
+            indnkeyatts = sql.null().label("indnkeyatts")
+
+        if self.server_version_info >= (15,):
+            indnullsnotdistinct = pg_catalog.pg_index.c.indnullsnotdistinct
+        else:
+            indnullsnotdistinct = sql.false().label("indnullsnotdistinct")
+
         con_sq = (
             select(
                 pg_catalog.pg_constraint.c.conrelid,
                 pg_catalog.pg_constraint.c.conname,
-                pg_catalog.pg_constraint.c.conindid,
-                sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label(
-                    "attnum"
-                ),
+                sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"),
                 sql.func.generate_subscripts(
-                    pg_catalog.pg_constraint.c.conkey, 1
+                    pg_catalog.pg_index.c.indkey, 1
                 ).label("ord"),
+                indnkeyatts,
+                indnullsnotdistinct,
                 pg_catalog.pg_description.c.description,
             )
+            .join(
+                pg_catalog.pg_index,
+                pg_catalog.pg_constraint.c.conindid
+                == pg_catalog.pg_index.c.indexrelid,
+            )
             .outerjoin(
                 pg_catalog.pg_description,
                 pg_catalog.pg_description.c.objoid
@@ -4063,6 +4140,9 @@ class PGDialect(default.DefaultDialect):
             .where(
                 pg_catalog.pg_constraint.c.contype == bindparam("contype"),
                 pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")),
+                # NOTE: filtering also on pg_index.indrelid for oids does
+                # not seem to have a performance effect, but it may be an
+                # option if perf problems are reported
             )
             .subquery("con")
         )
@@ -4071,9 +4151,10 @@ class PGDialect(default.DefaultDialect):
             select(
                 con_sq.c.conrelid,
                 con_sq.c.conname,
-                con_sq.c.conindid,
                 con_sq.c.description,
                 con_sq.c.ord,
+                con_sq.c.indnkeyatts,
+                con_sq.c.indnullsnotdistinct,
                 pg_catalog.pg_attribute.c.attname,
             )
             .select_from(pg_catalog.pg_attribute)
@@ -4096,7 +4177,7 @@ class PGDialect(default.DefaultDialect):
             .subquery("attr")
         )
 
-        constraint_query = (
+        return (
             select(
                 attr_sq.c.conrelid,
                 sql.func.array_agg(
@@ -4108,31 +4189,15 @@ class PGDialect(default.DefaultDialect):
                 ).label("cols"),
                 attr_sq.c.conname,
                 sql.func.min(attr_sq.c.description).label("description"),
+                sql.func.min(attr_sq.c.indnkeyatts).label("indnkeyatts"),
+                sql.func.bool_and(attr_sq.c.indnullsnotdistinct).label(
+                    "indnullsnotdistinct"
+                ),
             )
             .group_by(attr_sq.c.conrelid, attr_sq.c.conname)
             .order_by(attr_sq.c.conrelid, attr_sq.c.conname)
         )
-        if is_unique:
-            if self.server_version_info >= (15,):
-                constraint_query = constraint_query.join(
-                    pg_catalog.pg_index,
-                    attr_sq.c.conindid == pg_catalog.pg_index.c.indexrelid,
-                ).add_columns(
-                    sql.func.bool_and(
-                        pg_catalog.pg_index.c.indnullsnotdistinct
-                    ).label("indnullsnotdistinct")
-                )
-            else:
-                constraint_query = constraint_query.add_columns(
-                    sql.false().label("indnullsnotdistinct")
-                )
-        else:
-            constraint_query = constraint_query.add_columns(
-                sql.null().label("extra")
-            )
-        return constraint_query
-
     def _reflect_constraint(
         self, connection, contype, schema, filter_names, scope, kind, **kw
     ):
@@ -4148,26 +4213,45 @@ class PGDialect(default.DefaultDialect):
                 batches[0:3000] = []
 
                 result = connection.execute(
-                    self._constraint_query(is_unique),
+                    self._constraint_query,
                     {"oids": [r[0] for r in batch], "contype": contype},
-                )
+                ).mappings()
 
                 result_by_oid = defaultdict(list)
-                for oid, cols, constraint_name, comment, extra in result:
-                    result_by_oid[oid].append(
-                        (cols, constraint_name, comment, extra)
-                    )
+                for row_dict in result:
+                    result_by_oid[row_dict["conrelid"]].append(row_dict)
 
                 for oid, tablename in batch:
                     for_oid = result_by_oid.get(oid, ())
                     if for_oid:
-                        for cols, constraint, comment, extra in for_oid:
-                            if is_unique:
-                                yield tablename, cols, constraint, comment, {
-                                    "nullsnotdistinct": extra
-                                }
+                        for row in for_oid:
+                            # See note in get_multi_indexes
+                            all_cols = row["cols"]
+                            indnkeyatts = row["indnkeyatts"]
+                            if (
+                                indnkeyatts is not None
+                                and len(all_cols) > indnkeyatts
+                            ):
+                                inc_cols = all_cols[indnkeyatts:]
+                                cst_cols = all_cols[:indnkeyatts]
                             else:
-                                yield tablename, cols, constraint, comment, None
+                                inc_cols = []
+                                cst_cols = all_cols
+
+                            opts = {}
+                            if self.server_version_info >= (11,):
+                                opts["postgresql_include"] = inc_cols
+                            if is_unique:
+                                opts["postgresql_nulls_not_distinct"] = row[
+                                    "indnullsnotdistinct"
+                                ]
+                            yield (
+                                tablename,
+                                cst_cols,
+                                row["conname"],
+                                row["description"],
+                                opts,
+                            )
                     else:
                         yield tablename, None, None, None, None
@@ -4193,20 +4277,27 @@ class PGDialect(default.DefaultDialect):
         # only a single pk can be present for each table. Return an entry
         # even if a table has no primary key
         default = ReflectionDefaults.pk_constraint
+
+        def pk_constraint(pk_name, cols, comment, opts):
+            info = {
+                "constrained_columns": cols,
+                "name": pk_name,
+                "comment": comment,
+            }
+            if opts:
+                info["dialect_options"] = opts
+            return info
+
         return (
             (
                 (schema, table_name),
                 (
-                    {
-                        "constrained_columns": [] if cols is None else cols,
-                        "name": pk_name,
-                        "comment": comment,
-                    }
+                    pk_constraint(pk_name, cols, comment, opts)
                     if pk_name is not None
                     else default()
                 ),
             )
-            for table_name, cols, pk_name, comment, _ in result
+            for table_name, cols, pk_name, comment, opts in result
         )
 
     @reflection.cache
@@ -4597,7 +4688,10 @@ class PGDialect(default.DefaultDialect):
                     # "The number of key columns in the index, not counting any
                     # included columns, which are merely stored and do not
                     # participate in the index semantics"
-                    if indnkeyatts and len(all_elements) > indnkeyatts:
+                    if (
+                        indnkeyatts is not None
+                        and len(all_elements) > indnkeyatts
+                    ):
                         # this is a "covering index" which has INCLUDE columns
                         # as well as regular index columns
                         inc_cols = all_elements[indnkeyatts:]
@@ -4727,12 +4821,7 @@ class PGDialect(default.DefaultDialect):
                     "comment": comment,
                 }
                 if options:
-                    if options["nullsnotdistinct"]:
-                        uc_dict["dialect_options"] = {
-                            "postgresql_nulls_not_distinct": options[
-                                "nullsnotdistinct"
-                            ]
-                        }
+                    uc_dict["dialect_options"] = options
 
                 uniques[(schema, table_name)].append(uc_dict)
         return uniques.items()
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
index 9b68358385..d063cd7c9f 100644
--- a/lib/sqlalchemy/engine/reflection.py
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -1712,9 +1712,12 @@ class Inspector(inspection.Inspectable["Inspector"]):
             if pk in cols_by_orig_name and pk not in exclude_columns
         ]
 
-        # update pk constraint name and comment
+        # update pk constraint name, comment and dialect_kwargs
         table.primary_key.name = pk_cons.get("name")
         table.primary_key.comment = pk_cons.get("comment", None)
+        dialect_options = pk_cons.get("dialect_options")
+        if dialect_options:
+            table.primary_key.dialect_kwargs.update(dialect_options)
 
         # tell the PKConstraint to re-initialize
         # its column collection
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 6be86cde10..faafe7dc57 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -1955,6 +1955,8 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest):
                 if dupe:
                     names_that_duplicate_index.add(dupe)
                 eq_(refl.pop("comment", None), None)
+                # ignore dialect_options
+                refl.pop("dialect_options", None)
                 eq_(orig, refl)
 
         reflected_metadata = MetaData()
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 370981e19d..eda9f96662 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -23,6 +23,7 @@ from sqlalchemy import Integer
 from sqlalchemy import literal
 from sqlalchemy import MetaData
 from sqlalchemy import null
+from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import schema
 from sqlalchemy import select
 from sqlalchemy import Sequence
@@ -796,6 +797,40 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
         expr = testing.resolve_lambda(expr_fn, tbl=tbl)
         self.assert_compile(expr, expected, dialect=dd)
 
+    @testing.combinations(
+        (
+            lambda tbl: schema.AddConstraint(
+                UniqueConstraint(tbl.c.id, postgresql_include=[tbl.c.value])
+            ),
+            "ALTER TABLE foo ADD UNIQUE (id) INCLUDE (value)",
+        ),
+        (
+            lambda tbl: schema.AddConstraint(
+                PrimaryKeyConstraint(
+                    tbl.c.id, postgresql_include=[tbl.c.value, "misc"]
+                )
+            ),
+            "ALTER TABLE foo ADD PRIMARY KEY (id) INCLUDE (value, misc)",
+        ),
+        (
+            lambda tbl: schema.CreateIndex(
+                Index("idx", tbl.c.id, postgresql_include=[tbl.c.value])
+            ),
+            "CREATE INDEX idx ON foo (id) INCLUDE (value)",
+        ),
+    )
+    def test_include(self, expr_fn, expected):
+        m = MetaData()
+        tbl = Table(
+            "foo",
+            m,
+            Column("id", Integer, nullable=False),
+            Column("value", Integer, nullable=False),
+            Column("misc", String),
+        )
+        expr = testing.resolve_lambda(expr_fn, tbl=tbl)
+        self.assert_compile(expr, expected)
+
     def test_create_index_with_labeled_ops(self):
         m = MetaData()
         tbl = Table(
diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py
index 20844a0eae..ebe751b5b3 100644
--- a/test/dialect/postgresql/test_reflection.py
+++ b/test/dialect/postgresql/test_reflection.py
@@ -1770,6 +1770,7 @@ class ReflectionTest(
                     "column_names": ["y"],
                     "name": "unq1",
                     "dialect_options": {
+                        "postgresql_include": [],
                         "postgresql_nulls_not_distinct": True,
                     },
                     "comment": None,
@@ -2602,6 +2603,51 @@ class ReflectionTest(
             connection.execute(sa_ddl.DropConstraintComment(cst))
             all_none()
 
+    @testing.skip_if("postgresql < 11.0", "not supported")
+    def test_reflection_constraints_with_include(self, connection, metadata):
+        Table(
+            "foo",
+            metadata,
+            Column("id", Integer, nullable=False),
+            Column("value", Integer, nullable=False),
+            Column("foo", String),
+            Column("arr", ARRAY(Integer)),
+            Column("bar", SmallInteger),
+        )
+        metadata.create_all(connection)
+        connection.exec_driver_sql(
+            "ALTER TABLE foo ADD UNIQUE (id) INCLUDE (value)"
+        )
+        connection.exec_driver_sql(
+            "ALTER TABLE foo "
+            "ADD PRIMARY KEY (id) INCLUDE (arr, foo, bar, value)"
+        )
+
+        unq = inspect(connection).get_unique_constraints("foo")
+        expected_unq = [
+            {
+                "column_names": ["id"],
+                "name": "foo_id_value_key",
+                "dialect_options": {
+                    "postgresql_nulls_not_distinct": False,
+                    "postgresql_include": ["value"],
+                },
+                "comment": None,
+            }
+        ]
+        eq_(unq, expected_unq)
+
+        pk = inspect(connection).get_pk_constraint("foo")
+        expected_pk = {
+            "comment": None,
+            "constrained_columns": ["id"],
+            "dialect_options": {
+                "postgresql_include": ["arr", "foo", "bar", "value"]
+            },
+            "name": "foo_pkey",
+        }
+        eq_(pk, expected_pk)
+
 
 class CustomTypeReflectionTest(fixtures.TestBase):
     class CustomType:
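
Usage sketch (not part of the commit): the snippet below illustrates the ``postgresql_include`` option that this patch adds, for both DDL emission and reflection. It assumes PostgreSQL 11 or later; the table name follows the doc example above, and the connection URL is a placeholder::

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint
    from sqlalchemy import create_engine, inspect

    metadata = MetaData()

    # on the PostgreSQL dialect, CREATE TABLE will render
    # PRIMARY KEY (id) INCLUDE (value) and UNIQUE (id) INCLUDE (value)
    mytable = Table(
        "mytable",
        metadata,
        Column("id", Integer, nullable=False),
        Column("value", Integer, nullable=False),
        PrimaryKeyConstraint("id", postgresql_include=["value"]),
        UniqueConstraint("id", postgresql_include=["value"]),
    )

    # placeholder DSN; any PostgreSQL 11+ database works
    engine = create_engine("postgresql+psycopg2://user:pass@localhost/test")
    metadata.create_all(engine)

    # reflection reports the INCLUDE columns under dialect_options
    insp = inspect(engine)
    print(insp.get_pk_constraint("mytable")["dialect_options"])
    for uc in insp.get_unique_constraints("mytable"):
        print(uc["name"], uc["dialect_options"]["postgresql_include"])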