]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
WIP: Retrieve "included" columns in unique/primary key constraint from the index
authorDenis Laxalde <denis@laxalde.org>
Tue, 1 Apr 2025 10:33:43 +0000 (12:33 +0200)
committerDenis Laxalde <denis@laxalde.org>
Tue, 1 Apr 2025 16:55:01 +0000 (18:55 +0200)
Replace the previous approach relying on a regex on the result of
pg_get_constraintdef() with a subquery on pg_attribute joined with
pg_index using the 'indkey' column (and excluding 'indnkeyatts').

In tests, in order to make sure the order is okay, we add more columns
to the reflected table and create the constraint with an order different
from table attributes numbering.

TODO:
- test/dialect/postgresql/test_reflection.py::ReflectionTest::test_nullsnotdistinct
  fails, getting an unexpected 'postgresql_include' in 'dialect_options'
- add condition on server_version for the 'include' subquery

lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_reflection.py

index 9a85402652796917561423e12a7b8b1531eb5c81..11da8ea7921422a263bf43f00b25d65c4d1239c0 100644 (file)
@@ -4117,9 +4117,6 @@ class PGDialect(default.DefaultDialect):
                     pg_catalog.pg_constraint.c.conkey, 1
                 ).label("ord"),
                 pg_catalog.pg_description.c.description,
-                pg_catalog.pg_get_constraintdef(
-                    pg_catalog.pg_constraint.c.oid, True
-                ).label("condef"),
             )
             .outerjoin(
                 pg_catalog.pg_description,
@@ -4133,6 +4130,38 @@ class PGDialect(default.DefaultDialect):
             .subquery("con")
         )
 
+        include_sq = (
+            select(
+                pg_catalog.pg_attribute.c.attrelid,
+                pg_catalog.pg_index.c.indexrelid,
+                sql.func.array_agg(
+                    aggregate_order_by(
+                        pg_catalog.pg_attribute.c.attname,
+                        sql.func.array_position(
+                            pg_catalog.pg_index.c.indkey,
+                            pg_catalog.pg_attribute.c.attnum,
+                        ),
+                    )
+                ).label("include"),
+            )
+            .select_from(pg_catalog.pg_attribute)
+            .join(
+                pg_catalog.pg_index,
+                pg_catalog.pg_index.c.indrelid
+                == pg_catalog.pg_attribute.c.attrelid,
+            )
+            .where(
+                pg_catalog.pg_attribute.c.attnum
+                == sql.any_(pg_catalog.pg_index.c.indkey),
+                pg_catalog.pg_attribute.c.attnum
+                != pg_catalog.pg_index.c.indnkeyatts,
+            )
+            .group_by(
+                pg_catalog.pg_attribute.c.attrelid,
+                pg_catalog.pg_index.c.indexrelid,
+            )
+        ).subquery("include")
+
         attr_sq = (
             select(
                 con_sq.c.conrelid,
@@ -4140,8 +4169,8 @@ class PGDialect(default.DefaultDialect):
                 con_sq.c.conindid,
                 con_sq.c.description,
                 con_sq.c.ord,
-                con_sq.c.condef,
                 pg_catalog.pg_attribute.c.attname,
+                include_sq.c.include,
             )
             .select_from(pg_catalog.pg_attribute)
             .join(
@@ -4151,6 +4180,14 @@ class PGDialect(default.DefaultDialect):
                     pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid,
                 ),
             )
+            .outerjoin(
+                include_sq,
+                sql.and_(
+                    include_sq.c.attrelid
+                    == pg_catalog.pg_attribute.c.attrelid,
+                    include_sq.c.indexrelid == con_sq.c.conindid,
+                ),
+            )
             .where(
                 # NOTE: restate the condition here, since pg15 otherwise
                 # seems to get confused on pscopg2 sometimes, doing
@@ -4175,9 +4212,9 @@ class PGDialect(default.DefaultDialect):
                 ).label("cols"),
                 attr_sq.c.conname,
                 sql.func.min(attr_sq.c.description).label("description"),
-                attr_sq.c.condef,
+                attr_sq.c.include,
             )
-            .group_by(attr_sq.c.conrelid, attr_sq.c.conname, attr_sq.c.condef)
+            .group_by(attr_sq.c.conrelid, attr_sq.c.conname, attr_sq.c.include)
             .order_by(attr_sq.c.conrelid, attr_sq.c.conname)
         )
 
@@ -4201,10 +4238,6 @@ class PGDialect(default.DefaultDialect):
             )
         return constraint_query
 
-    @util.memoized_property
-    def _include_regex_pattern(self):
-        return re.compile(r"INCLUDE \((.+)\)")
-
     def _reflect_constraint(
         self, connection, contype, schema, filter_names, scope, kind, **kw
     ):
@@ -4215,8 +4248,6 @@ class PGDialect(default.DefaultDialect):
         batches = list(table_oids)
         is_unique = contype == "u"
 
-        INCLUDE_REGEX = self._include_regex_pattern
-
         while batches:
             batch = batches[0:3000]
             batches[0:3000] = []
@@ -4227,23 +4258,20 @@ class PGDialect(default.DefaultDialect):
             )
 
             result_by_oid = defaultdict(list)
-            for oid, cols, constraint_name, comment, condef, extra in result:
+            for oid, cols, constraint_name, comment, include, extra in result:
                 result_by_oid[oid].append(
-                    (cols, constraint_name, comment, condef, extra)
+                    (cols, constraint_name, comment, include, extra)
                 )
 
             for oid, tablename in batch:
                 for_oid = result_by_oid.get(oid, ())
                 if for_oid:
-                    for cols, constraint, comment, condef, extra in for_oid:
+                    for cols, constraint, comment, include, extra in for_oid:
                         opts = {}
                         if is_unique:
                             opts["nullsnotdistinct"] = extra
-                        m = INCLUDE_REGEX.search(condef)
-                        if m:
-                            opts["include"] = [
-                                v.strip() for v in m.group(1).split(", ")
-                            ]
+                        if include:
+                            opts["include"] = include
                         if not opts:
                             opts = None
                         yield tablename, cols, constraint, comment, opts
index 4f609e160167bdd4713714114c33ad7db05fcc21..4fea8e8c183d190a771830f6eb84cb071e843f08 100644 (file)
@@ -2609,14 +2609,17 @@ class ReflectionTest(
             metadata,
             Column("id", Integer, nullable=False),
             Column("value", Integer, nullable=False),
-            Column("misc", String),
+            Column("foo", String),
+            Column("arr", ARRAY(Integer)),
+            Column("bar", SmallInteger),
         )
         metadata.create_all(connection)
         connection.exec_driver_sql(
             "ALTER TABLE foo ADD UNIQUE (id) INCLUDE (value)"
         )
         connection.exec_driver_sql(
-            "ALTER TABLE foo ADD PRIMARY KEY (id) INCLUDE (value, misc)"
+            "ALTER TABLE foo "
+            "ADD PRIMARY KEY (id) INCLUDE (arr, foo, bar, value)"
         )
 
         unq = inspect(connection).get_unique_constraints("foo")
@@ -2636,7 +2639,14 @@ class ReflectionTest(
         expected_pk = {
             "comment": None,
             "constrained_columns": ["id"],
-            "dialect_options": {"postgresql_include": ["value", "misc"]},
+            "dialect_options": {
+                "postgresql_include": [
+                    "arr",
+                    "foo",
+                    "bar",
+                    "value",
+                ]
+            },
             "name": "foo_pkey",
         }
         eq_(pk, expected_pk)