git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Cache pg_collation query to avoid joins
author Denis Laxalde <denis@laxalde.org>
Fri, 27 Jun 2025 08:22:54 +0000 (10:22 +0200)
committer Denis Laxalde <denis@laxalde.org>
Tue, 15 Jul 2025 13:15:08 +0000 (15:15 +0200)
This follows an approach similar to commit 703a323329b420fefec2b8a0a5f5f87ea3dc49d0.

The CAST(..., BIGINT) for array_agg()'s elements is needed for pg8000.

lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_reflection.py

index 24311842ee0973c110b20e1e2e4359b0587d6b39..a024d7b46881b2ffefd1bc9eb49a3465382e5da7 100644 (file)
@@ -3791,9 +3791,10 @@ class PGDialect(default.DefaultDialect):
                 ).label("format_type"),
                 default,
                 pg_catalog.pg_attribute.c.attnotnull.label("not_null"),
+                pg_catalog.pg_attribute.c.atttypid.label("type"),
+                pg_catalog.pg_attribute.c.attcollation.label("collation"),
                 pg_catalog.pg_class.c.relname.label("table_name"),
                 pg_catalog.pg_description.c.description.label("comment"),
-                pg_catalog.pg_collation.c.collname.label("collation"),
                 generated,
                 identity,
             )
@@ -3819,19 +3820,6 @@ class PGDialect(default.DefaultDialect):
                     == pg_catalog.pg_attribute.c.attnum,
                 ),
             )
-            .outerjoin(
-                pg_catalog.pg_type,
-                pg_catalog.pg_type.c.oid == pg_catalog.pg_attribute.c.atttypid,
-            )
-            .outerjoin(
-                pg_catalog.pg_collation,
-                sql.and_(
-                    pg_catalog.pg_attribute.c.attcollation
-                    != pg_catalog.pg_type.c.typcollation,
-                    pg_catalog.pg_collation.c.oid
-                    == pg_catalog.pg_attribute.c.attcollation,
-                ),
-            )
             .where(self._pg_class_relkind_condition(relkinds))
             .order_by(
                 pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum
@@ -3873,7 +3861,11 @@ class PGDialect(default.DefaultDialect):
             )
         )
 
-        columns = self._get_columns_info(rows, domains, enums, schema)
+        collations = self._load_collation_dict(connection)
+
+        columns = self._get_columns_info(
+            rows, domains, enums, collations, schema
+        )
 
         return columns.items()
 
@@ -4030,7 +4022,7 @@ class PGDialect(default.DefaultDialect):
 
         return data_type
 
-    def _get_columns_info(self, rows, domains, enums, schema):
+    def _get_columns_info(self, rows, domains, enums, collations, schema):
         columns = defaultdict(list)
         for row_dict in rows:
             # ensure that each table has an entry, even if it has no columns
@@ -4041,12 +4033,27 @@ class PGDialect(default.DefaultDialect):
                 continue
             table_cols = columns[(schema, row_dict["table_name"])]
 
+            try:
+                collation_name, default_collation_for_types = collations[
+                    row_dict["collation"]
+                ]
+            except KeyError:
+                collation = None
+            else:
+                # Only export the collation if distinct from type's default.
+                collation = (
+                    collation_name
+                    if default_collation_for_types is not None
+                    and row_dict["type"] not in default_collation_for_types
+                    else None
+                )
+
             coltype = self._reflect_type(
                 row_dict["format_type"],
                 domains,
                 enums,
                 type_description="column '%s'" % row_dict["name"],
-                collation=row_dict["collation"],
+                collation=collation,
             )
 
             default = row_dict["default"]
@@ -5257,6 +5264,36 @@ class PGDialect(default.DefaultDialect):
         rows = connection.execute(self._pg_opclass_notdefault_query)
         return dict(rows.all())
 
+    @util.memoized_property
+    def _pg_collation_query(self):
+        """Query collations and types using them as default."""
+        return (
+            sql.select(
+                pg_catalog.pg_collation.c.oid,
+                pg_catalog.pg_collation.c.collname,
+                # cast to bigint (oid are "unsigned four-byte integer") to make
+                # it easier for dialects to interpret
+                sql.func.array_agg(
+                    pg_catalog.pg_type.c.oid.cast(BIGINT)
+                ).filter(pg_catalog.pg_type.c.oid.is_not(None)),
+            )
+            .select_from(pg_catalog.pg_collation)
+            .outerjoin(
+                pg_catalog.pg_type,
+                pg_catalog.pg_type.c.typcollation
+                == pg_catalog.pg_collation.c.oid,
+            )
+        ).group_by(
+            pg_catalog.pg_collation.c.oid, pg_catalog.pg_collation.c.collname
+        )
+
+    @reflection.cache
+    def _load_collation_dict(
+        self, connection
+    ) -> dict[int, Tuple[str, Optional[list[int]]]]:
+        rows = connection.execute(self._pg_collation_query)
+        return {oid: (name, types) for oid, name, types in rows}
+
     def _set_backslash_escapes(self, connection):
         # this method is provided as an override hook for descendant
         # dialects (e.g. Redshift), so removing it may break them
index 74c925ef901ab049714bceb330f09993e9319e80..653bca515fda1ec9ffc624d2cb1da7b6337304c8 100644 (file)
@@ -2731,13 +2731,14 @@ class CustomTypeReflectionTest(fixtures.TestBase):
                 "format_type": sch,
                 "default": None,
                 "not_null": False,
+                "type": 123,
+                "collation": 456,
                 "comment": None,
-                "collation": None,
                 "generated": "",
                 "identity_options": None,
             }
             column_info = dialect._get_columns_info(
-                [row_dict], {}, {}, "public"
+                [row_dict], {}, {}, {456: ("mycoll", [123])}, "public"
             )
             assert ("public", "tblname") in column_info
             column_info = column_info[("public", "tblname")]
@@ -2774,13 +2775,14 @@ class CustomTypeReflectionTest(fixtures.TestBase):
                 "format_type": None,
                 "default": None,
                 "not_null": False,
+                "type": 987,
+                "collation": 0,
                 "comment": None,
-                "collation": None,
                 "generated": "",
                 "identity_options": None,
             }
             column_info = dialect._get_columns_info(
-                [row_dict], {}, {}, "public"
+                [row_dict], {}, {}, {654: ("somecollation", [])}, "public"
             )
             assert ("public", "tblname") in column_info
             column_info = column_info[("public", "tblname")]