From: Denis Laxalde Date: Fri, 27 Jun 2025 08:22:54 +0000 (+0200) Subject: Cache pg_collation query to avoid joins X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=4c2ba7349824fdacc26c35d36f8b95cf3a35c99a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Cache pg_collation query to avoid joins Following a similar approach to 703a323329b420fefec2b8a0a5f5f87ea3dc49d0. The CAST(..., BIGINT) for array_agg()'s elements is needed for pg8000. --- diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 24311842ee..a024d7b468 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -3791,9 +3791,10 @@ class PGDialect(default.DefaultDialect): ).label("format_type"), default, pg_catalog.pg_attribute.c.attnotnull.label("not_null"), + pg_catalog.pg_attribute.c.atttypid.label("type"), + pg_catalog.pg_attribute.c.attcollation.label("collation"), pg_catalog.pg_class.c.relname.label("table_name"), pg_catalog.pg_description.c.description.label("comment"), - pg_catalog.pg_collation.c.collname.label("collation"), generated, identity, ) @@ -3819,19 +3820,6 @@ class PGDialect(default.DefaultDialect): == pg_catalog.pg_attribute.c.attnum, ), ) - .outerjoin( - pg_catalog.pg_type, - pg_catalog.pg_type.c.oid == pg_catalog.pg_attribute.c.atttypid, - ) - .outerjoin( - pg_catalog.pg_collation, - sql.and_( - pg_catalog.pg_attribute.c.attcollation - != pg_catalog.pg_type.c.typcollation, - pg_catalog.pg_collation.c.oid - == pg_catalog.pg_attribute.c.attcollation, - ), - ) .where(self._pg_class_relkind_condition(relkinds)) .order_by( pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum @@ -3873,7 +3861,11 @@ class PGDialect(default.DefaultDialect): ) ) - columns = self._get_columns_info(rows, domains, enums, schema) + collations = self._load_collation_dict(connection) + + columns = self._get_columns_info( + rows, domains, enums, collations, schema + ) return columns.items() @@ -4030,7 +4022,7 @@ class PGDialect(default.DefaultDialect): return data_type - def _get_columns_info(self, rows, domains, enums, schema): + def _get_columns_info(self, rows, domains, enums, collations, schema): columns = defaultdict(list) for row_dict in rows: # ensure that each table has an entry, even if it has no columns @@ -4041,12 +4033,27 @@ class PGDialect(default.DefaultDialect): continue table_cols = columns[(schema, row_dict["table_name"])] + try: + collation_name, default_collation_for_types = collations[ + row_dict["collation"] + ] + except KeyError: + collation = None + else: + # Only export the collation if distinct from type's default. + collation = ( + collation_name + if default_collation_for_types is not None + and row_dict["type"] not in default_collation_for_types + else None + ) + coltype = self._reflect_type( row_dict["format_type"], domains, enums, type_description="column '%s'" % row_dict["name"], - collation=row_dict["collation"], + collation=collation, ) default = row_dict["default"] @@ -5257,6 +5264,36 @@ class PGDialect(default.DefaultDialect): rows = connection.execute(self._pg_opclass_notdefault_query) return dict(rows.all()) + @util.memoized_property + def _pg_collation_query(self): + """Query collations and types using them as default.""" + return ( + sql.select( + pg_catalog.pg_collation.c.oid, + pg_catalog.pg_collation.c.collname, + # cast to bigint (oid are "unsigned four-byte integer") to make + # it easier for dialects to interpret + sql.func.array_agg( + pg_catalog.pg_type.c.oid.cast(BIGINT) + ).filter(pg_catalog.pg_type.c.oid.is_not(None)), + ) + .select_from(pg_catalog.pg_collation) + .outerjoin( + pg_catalog.pg_type, + pg_catalog.pg_type.c.typcollation + == pg_catalog.pg_collation.c.oid, + ) + ).group_by( + pg_catalog.pg_collation.c.oid, pg_catalog.pg_collation.c.collname + ) + + @reflection.cache + def _load_collation_dict( + self, connection + ) -> dict[int, Tuple[str, Optional[list[int]]]]: + rows = connection.execute(self._pg_collation_query) + return {oid: (name, types) for oid, name, types in rows} + def _set_backslash_escapes(self, connection): # this method is provided as an override hook for descendant # dialects (e.g. Redshift), so removing it may break them diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 74c925ef90..653bca515f 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -2731,13 +2731,14 @@ class CustomTypeReflectionTest(fixtures.TestBase): "format_type": sch, "default": None, "not_null": False, + "type": 123, + "collation": 456, "comment": None, - "collation": None, "generated": "", "identity_options": None, } column_info = dialect._get_columns_info( - [row_dict], {}, {}, "public" + [row_dict], {}, {}, {456: ("mycoll", [123])}, "public" ) assert ("public", "tblname") in column_info column_info = column_info[("public", "tblname")] @@ -2774,13 +2775,14 @@ class CustomTypeReflectionTest(fixtures.TestBase): "format_type": None, "default": None, "not_null": False, + "type": 987, + "collation": 0, "comment": None, - "collation": None, "generated": "", "identity_options": None, } column_info = dialect._get_columns_info( - [row_dict], {}, {}, "public" + [row_dict], {}, {}, {654: ("somecollation", [])}, "public" ) assert ("public", "tblname") in column_info column_info = column_info[("public", "tblname")]