From 5b49b59ab7b131ee3a0804222910190e4014263c Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Sat, 11 Oct 2025 00:31:00 +0200 Subject: [PATCH] Improve PostgreSQL reflection Some improvements to the reflection of PostgreSQL that avoid issuing queries whose results are not needed. Fixes: #12908 Change-Id: Ic3da50ee43670f26d3159f5ec9a235a9a1963b8a --- lib/sqlalchemy/dialects/postgresql/base.py | 247 +++++++++++++-------- test/dialect/postgresql/test_reflection.py | 9 +- 2 files changed, 157 insertions(+), 99 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index b3932ea90f..747dea6f3b 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -3266,6 +3266,7 @@ class PGDialect(default.DefaultDialect): _supports_create_index_concurrently = True _supports_drop_index_concurrently = True _supports_jsonb_subscripting = True + _pg_am_btree_oid = -1 def __init__( self, @@ -3729,76 +3730,85 @@ class PGDialect(default.DefaultDialect): ) if self.server_version_info >= (10,): # join lateral performs worse (~2x slower) than a scalar_subquery - identity = ( - select( - sql.func.json_build_object( - "always", - pg_catalog.pg_attribute.c.attidentity == "a", - "start", - pg_catalog.pg_sequence.c.seqstart, - "increment", - pg_catalog.pg_sequence.c.seqincrement, - "minvalue", - pg_catalog.pg_sequence.c.seqmin, - "maxvalue", - pg_catalog.pg_sequence.c.seqmax, - "cache", - pg_catalog.pg_sequence.c.seqcache, - "cycle", - pg_catalog.pg_sequence.c.seqcycle, - type_=sqltypes.JSON(), - ) - ) - .select_from(pg_catalog.pg_sequence) - .where( - # attidentity != '' is required or it will reflect also + # also the subquery can be run only if the column is an identity + identity = sql.case( + ( # attidentity != '' is required or it will reflect also # serial columns as identity. 
pg_catalog.pg_attribute.c.attidentity != "", - pg_catalog.pg_sequence.c.seqrelid - == sql.cast( - sql.cast( - pg_catalog.pg_get_serial_sequence( - sql.cast( + select( + sql.func.json_build_object( + "always", + pg_catalog.pg_attribute.c.attidentity == "a", + "start", + pg_catalog.pg_sequence.c.seqstart, + "increment", + pg_catalog.pg_sequence.c.seqincrement, + "minvalue", + pg_catalog.pg_sequence.c.seqmin, + "maxvalue", + pg_catalog.pg_sequence.c.seqmax, + "cache", + pg_catalog.pg_sequence.c.seqcache, + "cycle", + pg_catalog.pg_sequence.c.seqcycle, + type_=sqltypes.JSON(), + ) + ) + .select_from(pg_catalog.pg_sequence) + .where( + # not needed but pg seems to like it + pg_catalog.pg_attribute.c.attidentity != "", + pg_catalog.pg_sequence.c.seqrelid + == sql.cast( + sql.cast( + pg_catalog.pg_get_serial_sequence( sql.cast( - pg_catalog.pg_attribute.c.attrelid, - REGCLASS, + sql.cast( + pg_catalog.pg_attribute.c.attrelid, + REGCLASS, + ), + TEXT, ), - TEXT, + pg_catalog.pg_attribute.c.attname, ), - pg_catalog.pg_attribute.c.attname, + REGCLASS, ), - REGCLASS, + OID, ), - OID, - ), - ) - .correlate(pg_catalog.pg_attribute) - .scalar_subquery() - .label("identity_options") - ) + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery(), + ), + else_=sql.null(), + ).label("identity_options") else: identity = sql.null().label("identity_options") - # join lateral performs the same as scalar_subquery here - default = ( - select( - pg_catalog.pg_get_expr( - pg_catalog.pg_attrdef.c.adbin, - pg_catalog.pg_attrdef.c.adrelid, - ) - ) - .select_from(pg_catalog.pg_attrdef) - .where( - pg_catalog.pg_attrdef.c.adrelid - == pg_catalog.pg_attribute.c.attrelid, - pg_catalog.pg_attrdef.c.adnum - == pg_catalog.pg_attribute.c.attnum, + # join lateral performs the same as scalar_subquery here, also + # the subquery can be run only if the column has a default + default = sql.case( + ( pg_catalog.pg_attribute.c.atthasdef, - ) - .correlate(pg_catalog.pg_attribute) - .scalar_subquery() - 
.label("default") - ) + select( + pg_catalog.pg_get_expr( + pg_catalog.pg_attrdef.c.adbin, + pg_catalog.pg_attrdef.c.adrelid, + ) + ) + .select_from(pg_catalog.pg_attrdef) + .where( + # not needed but pg seems to like it + pg_catalog.pg_attribute.c.atthasdef, + pg_catalog.pg_attrdef.c.adrelid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attrdef.c.adnum + == pg_catalog.pg_attribute.c.attnum, + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery(), + ), + else_=sql.null(), + ).label("default") # get the name of the collate when it's different from the default one collate = sql.case( @@ -3882,29 +3892,8 @@ class PGDialect(default.DefaultDialect): query = self._columns_query(schema, has_filter_names, scope, kind) rows = connection.execute(query, params).mappings() - # dictionary with (name, ) if default search path or (schema, name) - # as keys - domains = { - ((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d - for d in self._load_domains( - connection, schema="*", info_cache=kw.get("info_cache") - ) - } - - # dictionary with (name, ) if default search path or (schema, name) - # as keys - enums = dict( - ( - ((rec["name"],), rec) - if rec["visible"] - else ((rec["schema"], rec["name"]), rec) - ) - for rec in self._load_enums( - connection, schema="*", info_cache=kw.get("info_cache") - ) - ) - - columns = self._get_columns_info(rows, domains, enums, schema) + named_type_loader = _NamedTypeLoader(self, connection, kw) + columns = self._get_columns_info(rows, named_type_loader, schema) return columns.items() @@ -3915,8 +3904,7 @@ class PGDialect(default.DefaultDialect): def _reflect_type( self, format_type: Optional[str], - domains: Dict[str, ReflectedDomain], - enums: Dict[str, ReflectedEnum], + named_type_loader: _NamedTypeLoader, type_description: str, collation: Optional[str], ) -> sqltypes.TypeEngine[Any]: @@ -4001,23 +3989,28 @@ class PGDialect(default.DefaultDialect): else: enum_or_domain_key = 
tuple(util.quoted_token_parser(attype)) - if enum_or_domain_key in enums: + if ( + schema_type is None + and enum_or_domain_key in named_type_loader.enums + ): schema_type = ENUM - enum = enums[enum_or_domain_key] + enum = named_type_loader.enums[enum_or_domain_key] kwargs["name"] = enum["name"] if not enum["visible"]: kwargs["schema"] = enum["schema"] args = tuple(enum["labels"]) - elif enum_or_domain_key in domains: + elif ( + schema_type is None + and enum_or_domain_key in named_type_loader.domains + ): schema_type = DOMAIN - domain = domains[enum_or_domain_key] + domain = named_type_loader.domains[enum_or_domain_key] data_type = self._reflect_type( domain["type"], - domains, - enums, + named_type_loader, type_description="DOMAIN '%s'" % domain["name"], collation=domain["collation"], ) @@ -4062,7 +4055,7 @@ class PGDialect(default.DefaultDialect): return data_type - def _get_columns_info(self, rows, domains, enums, schema): + def _get_columns_info(self, rows, named_type_loader, schema): columns = defaultdict(list) for row_dict in rows: # ensure that each table has an entry, even if it has no columns @@ -4077,8 +4070,7 @@ class PGDialect(default.DefaultDialect): coltype = self._reflect_type( row_dict["format_type"], - domains, - enums, + named_type_loader, type_description="column '%s'" % row_dict["name"], collation=collation, ) @@ -4737,7 +4729,10 @@ class PGDialect(default.DefaultDialect): connection, schema, filter_names, scope, kind, **kw ) - pg_am_dict = self._load_pg_am_dict(connection, **kw) + pg_am_btree_oid = self._load_pg_am_btree_oid(connection) + # lazy load only if needed, the assumption is that most indexes + # will use btree so it may not be needed at all + pg_am_dict = None pg_opclass_dict = self._load_pg_opclass_notdefault_dict( connection, **kw ) @@ -4858,9 +4853,14 @@ class PGDialect(default.DefaultDialect): # reflection info. 
But we don't want an Index object # to have a ``postgresql_using`` in it that is just the # default, so for the moment leaving this out. - amname = pg_am_dict[row["relam"]] - if amname != "btree": - dialect_options["postgresql_using"] = amname + if row["relam"] != pg_am_btree_oid: + if pg_am_dict is None: + pg_am_dict = self._load_pg_am_dict( + connection, **kw + ) + dialect_options["postgresql_using"] = pg_am_dict[ + row["relam"] + ] if row["filter_definition"]: dialect_options["postgresql_where"] = row[ "filter_definition" @@ -5278,6 +5278,14 @@ class PGDialect(default.DefaultDialect): rows = connection.execute(self._pg_am_query) return dict(rows.all()) + def _load_pg_am_btree_oid(self, connection): + # this oid is assumed to be stable + if self._pg_am_btree_oid == -1: + self._pg_am_btree_oid = connection.scalar( + self._pg_am_query.where(pg_catalog.pg_am.c.amname == "btree") + ) + return self._pg_am_btree_oid + @util.memoized_property def _pg_opclass_notdefault_query(self): return sql.select( @@ -5298,3 +5306,48 @@ class PGDialect(default.DefaultDialect): "show standard_conforming_strings" ).scalar() self._backslash_escapes = std_string == "off" + + +class _NamedTypeLoader: + """Helper class used for deferred loading of named types (enums, domains) + only when needed. 
+ """ + + def __init__( + self, dialect: PGDialect, connection, kw: Dict[str, Any] + ) -> None: + self.dialect = dialect + self.connection = connection + self.kw = kw + + @util.memoized_property + def enums(self) -> Dict[Tuple[str] | Tuple[str, str], ReflectedEnum]: + # dictionary with (name, ) if default search path or (schema, name) + # as keys + enums = dict( + ( + ((rec["name"],), rec) + if rec["visible"] + else ((rec["schema"], rec["name"]), rec) + ) + for rec in self.dialect._load_enums( + self.connection, + schema="*", + info_cache=self.kw.get("info_cache"), + ) + ) + return enums + + @util.memoized_property + def domains(self) -> Dict[Tuple[str] | Tuple[str, str], ReflectedDomain]: + # dictionary with (name, ) if default search path or (schema, name) + # as keys + domains = { + ((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d + for d in self.dialect._load_domains( + self.connection, + schema="*", + info_cache=self.kw.get("info_cache"), + ) + } + return domains diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 534c31a860..a8b933f74d 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -2713,6 +2713,11 @@ class ReflectionTest( class CustomTypeReflectionTest(fixtures.TestBase): + class NTL: + def __init__(self, enums, domains): + self.enums = enums + self.domains = domains + class CustomType: def __init__(self, arg1=None, arg2=None, collation=None): self.arg1 = arg1 @@ -2749,7 +2754,7 @@ class CustomTypeReflectionTest(fixtures.TestBase): "identity_options": None, } column_info = dialect._get_columns_info( - [row_dict], {}, {}, "public" + [row_dict], self.NTL({}, {}), "public" ) assert ("public", "tblname") in column_info column_info = column_info[("public", "tblname")] @@ -2796,7 +2801,7 @@ class CustomTypeReflectionTest(fixtures.TestBase): "identity_options": None, } column_info = dialect._get_columns_info( - [row_dict], {}, {}, 
"public" + [row_dict], self.NTL({}, {}), "public" ) assert ("public", "tblname") in column_info column_info = column_info[("public", "tblname")] -- 2.47.3