Improve postgresql reflection
author    Federico Caselli <cfederico87@gmail.com>
          Fri, 10 Oct 2025 22:31:00 +0000 (00:31 +0200)
committer Michael Bayer <mike_mp@zzzcomputing.com>
          Tue, 14 Oct 2025 13:27:18 +0000 (13:27 +0000)
Some improvements to PostgreSQL reflection to avoid issuing
unnecessary queries.

Fixes: #12908
Change-Id: Ic3da50ee43670f26d3159f5ec9a235a9a1963b8a

lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_reflection.py

lib/sqlalchemy/dialects/postgresql/base.py
index b3932ea90f490ef782c1ea2fd9b3871fd6014feb..747dea6f3b890ff911ff8f739905eab6e0c8cd44 100644 (file)
@@ -3266,6 +3266,7 @@ class PGDialect(default.DefaultDialect):
     _supports_create_index_concurrently = True
     _supports_drop_index_concurrently = True
     _supports_jsonb_subscripting = True
+    _pg_am_btree_oid = -1
 
     def __init__(
         self,
@@ -3729,76 +3730,85 @@ class PGDialect(default.DefaultDialect):
         )
         if self.server_version_info >= (10,):
             # join lateral performs worse (~2x slower) than a scalar_subquery
-            identity = (
-                select(
-                    sql.func.json_build_object(
-                        "always",
-                        pg_catalog.pg_attribute.c.attidentity == "a",
-                        "start",
-                        pg_catalog.pg_sequence.c.seqstart,
-                        "increment",
-                        pg_catalog.pg_sequence.c.seqincrement,
-                        "minvalue",
-                        pg_catalog.pg_sequence.c.seqmin,
-                        "maxvalue",
-                        pg_catalog.pg_sequence.c.seqmax,
-                        "cache",
-                        pg_catalog.pg_sequence.c.seqcache,
-                        "cycle",
-                        pg_catalog.pg_sequence.c.seqcycle,
-                        type_=sqltypes.JSON(),
-                    )
-                )
-                .select_from(pg_catalog.pg_sequence)
-                .where(
-                    # attidentity != '' is required or it will reflect also
+            # also the subquery can be run only if the column is an identity
+            identity = sql.case(
+                (  # attidentity != '' is required or it will reflect also
                     # serial columns as identity.
                     pg_catalog.pg_attribute.c.attidentity != "",
-                    pg_catalog.pg_sequence.c.seqrelid
-                    == sql.cast(
-                        sql.cast(
-                            pg_catalog.pg_get_serial_sequence(
-                                sql.cast(
+                    select(
+                        sql.func.json_build_object(
+                            "always",
+                            pg_catalog.pg_attribute.c.attidentity == "a",
+                            "start",
+                            pg_catalog.pg_sequence.c.seqstart,
+                            "increment",
+                            pg_catalog.pg_sequence.c.seqincrement,
+                            "minvalue",
+                            pg_catalog.pg_sequence.c.seqmin,
+                            "maxvalue",
+                            pg_catalog.pg_sequence.c.seqmax,
+                            "cache",
+                            pg_catalog.pg_sequence.c.seqcache,
+                            "cycle",
+                            pg_catalog.pg_sequence.c.seqcycle,
+                            type_=sqltypes.JSON(),
+                        )
+                    )
+                    .select_from(pg_catalog.pg_sequence)
+                    .where(
+                        # not needed but pg seems to like it
+                        pg_catalog.pg_attribute.c.attidentity != "",
+                        pg_catalog.pg_sequence.c.seqrelid
+                        == sql.cast(
+                            sql.cast(
+                                pg_catalog.pg_get_serial_sequence(
                                     sql.cast(
-                                        pg_catalog.pg_attribute.c.attrelid,
-                                        REGCLASS,
+                                        sql.cast(
+                                            pg_catalog.pg_attribute.c.attrelid,
+                                            REGCLASS,
+                                        ),
+                                        TEXT,
                                     ),
-                                    TEXT,
+                                    pg_catalog.pg_attribute.c.attname,
                                 ),
-                                pg_catalog.pg_attribute.c.attname,
+                                REGCLASS,
                             ),
-                            REGCLASS,
+                            OID,
                         ),
-                        OID,
-                    ),
-                )
-                .correlate(pg_catalog.pg_attribute)
-                .scalar_subquery()
-                .label("identity_options")
-            )
+                    )
+                    .correlate(pg_catalog.pg_attribute)
+                    .scalar_subquery(),
+                ),
+                else_=sql.null(),
+            ).label("identity_options")
         else:
             identity = sql.null().label("identity_options")
 
-        # join lateral performs the same as scalar_subquery here
-        default = (
-            select(
-                pg_catalog.pg_get_expr(
-                    pg_catalog.pg_attrdef.c.adbin,
-                    pg_catalog.pg_attrdef.c.adrelid,
-                )
-            )
-            .select_from(pg_catalog.pg_attrdef)
-            .where(
-                pg_catalog.pg_attrdef.c.adrelid
-                == pg_catalog.pg_attribute.c.attrelid,
-                pg_catalog.pg_attrdef.c.adnum
-                == pg_catalog.pg_attribute.c.attnum,
+        # join lateral performs the same as scalar_subquery here, also
+        # the subquery can be run only if the column has a default
+        default = sql.case(
+            (
                 pg_catalog.pg_attribute.c.atthasdef,
-            )
-            .correlate(pg_catalog.pg_attribute)
-            .scalar_subquery()
-            .label("default")
-        )
+                select(
+                    pg_catalog.pg_get_expr(
+                        pg_catalog.pg_attrdef.c.adbin,
+                        pg_catalog.pg_attrdef.c.adrelid,
+                    )
+                )
+                .select_from(pg_catalog.pg_attrdef)
+                .where(
+                    # not needed but pg seems to like it
+                    pg_catalog.pg_attribute.c.atthasdef,
+                    pg_catalog.pg_attrdef.c.adrelid
+                    == pg_catalog.pg_attribute.c.attrelid,
+                    pg_catalog.pg_attrdef.c.adnum
+                    == pg_catalog.pg_attribute.c.attnum,
+                )
+                .correlate(pg_catalog.pg_attribute)
+                .scalar_subquery(),
+            ),
+            else_=sql.null(),
+        ).label("default")
 
         # get the name of the collate when it's different from the default one
         collate = sql.case(
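
The two rewrites above share one pattern: a correlated scalar subquery that
previously ran for every pg_attribute row is now wrapped in a CASE whose
cheap predicate (attidentity != '' / atthasdef) lets PostgreSQL skip the
subquery entirely for rows that fail it. A minimal, self-contained sketch of
that pattern, using hypothetical table and column names rather than the
dialect's actual catalog query:

    from sqlalchemy import (
        Column, Integer, MetaData, String, Table, case, null, select,
    )

    metadata = MetaData()
    parent = Table(
        "parent", metadata,
        Column("id", Integer),
        Column("has_extra", Integer),
    )
    extra = Table(
        "extra", metadata,
        Column("parent_id", Integer),
        Column("payload", String),
    )

    # the expensive correlated lookup runs only for rows that pass the
    # cheap predicate; all other rows short-circuit to NULL without
    # touching "extra" at all
    payload = case(
        (
            parent.c.has_extra != 0,
            select(extra.c.payload)
            .where(extra.c.parent_id == parent.c.id)
            .correlate(parent)
            .scalar_subquery(),
        ),
        else_=null(),
    ).label("payload")

    print(select(parent.c.id, payload))
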
@@ -3882,29 +3892,8 @@ class PGDialect(default.DefaultDialect):
         query = self._columns_query(schema, has_filter_names, scope, kind)
         rows = connection.execute(query, params).mappings()
 
-        # dictionary with (name, ) if default search path or (schema, name)
-        # as keys
-        domains = {
-            ((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d
-            for d in self._load_domains(
-                connection, schema="*", info_cache=kw.get("info_cache")
-            )
-        }
-
-        # dictionary with (name, ) if default search path or (schema, name)
-        # as keys
-        enums = dict(
-            (
-                ((rec["name"],), rec)
-                if rec["visible"]
-                else ((rec["schema"], rec["name"]), rec)
-            )
-            for rec in self._load_enums(
-                connection, schema="*", info_cache=kw.get("info_cache")
-            )
-        )
-
-        columns = self._get_columns_info(rows, domains, enums, schema)
+        named_type_loader = _NamedTypeLoader(self, connection, kw)
+        columns = self._get_columns_info(rows, named_type_loader, schema)
 
         return columns.items()
 
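
With the loader in place, the two catalog queries behind domains and enums
run only if a column actually references a named type. The lookup keys keep
the previous convention: (name,) for types visible on the default
search_path, (schema, name) otherwise, matching how _reflect_type tokenizes
the reflected type string. A hedged sketch of that keying, with made-up
records (the parser is the real sqlalchemy.util helper):

    from sqlalchemy import util

    recs = [
        {"name": "mood", "schema": "public", "visible": True},
        {"name": "mood", "schema": "other", "visible": False},
    ]
    enums = {
        ((r["name"],) if r["visible"] else (r["schema"], r["name"])): r
        for r in recs
    }

    # _reflect_type does: tuple(util.quoted_token_parser(attype))
    print(tuple(util.quoted_token_parser("mood")))           # ('mood',)
    print(tuple(util.quoted_token_parser('other."mood"')))   # ('other', 'mood')
    print(tuple(util.quoted_token_parser('other."mood"')) in enums)  # True
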
@@ -3915,8 +3904,7 @@ class PGDialect(default.DefaultDialect):
     def _reflect_type(
         self,
         format_type: Optional[str],
-        domains: Dict[str, ReflectedDomain],
-        enums: Dict[str, ReflectedEnum],
+        named_type_loader: _NamedTypeLoader,
         type_description: str,
         collation: Optional[str],
     ) -> sqltypes.TypeEngine[Any]:
@@ -4001,23 +3989,28 @@ class PGDialect(default.DefaultDialect):
         else:
             enum_or_domain_key = tuple(util.quoted_token_parser(attype))
 
-            if enum_or_domain_key in enums:
+            if (
+                schema_type is None
+                and enum_or_domain_key in named_type_loader.enums
+            ):
                 schema_type = ENUM
-                enum = enums[enum_or_domain_key]
+                enum = named_type_loader.enums[enum_or_domain_key]
 
                 kwargs["name"] = enum["name"]
 
                 if not enum["visible"]:
                     kwargs["schema"] = enum["schema"]
                 args = tuple(enum["labels"])
-            elif enum_or_domain_key in domains:
+            elif (
+                schema_type is None
+                and enum_or_domain_key in named_type_loader.domains
+            ):
                 schema_type = DOMAIN
-                domain = domains[enum_or_domain_key]
+                domain = named_type_loader.domains[enum_or_domain_key]
 
                 data_type = self._reflect_type(
                     domain["type"],
-                    domains,
-                    enums,
+                    named_type_loader,
                     type_description="DOMAIN '%s'" % domain["name"],
                     collation=domain["collation"],
                 )
@@ -4062,7 +4055,7 @@ class PGDialect(default.DefaultDialect):
 
         return data_type
 
-    def _get_columns_info(self, rows, domains, enums, schema):
+    def _get_columns_info(self, rows, named_type_loader, schema):
         columns = defaultdict(list)
         for row_dict in rows:
             # ensure that each table has an entry, even if it has no columns
@@ -4077,8 +4070,7 @@ class PGDialect(default.DefaultDialect):
 
             coltype = self._reflect_type(
                 row_dict["format_type"],
-                domains,
-                enums,
+                named_type_loader,
                 type_description="column '%s'" % row_dict["name"],
                 collation=collation,
             )
@@ -4737,7 +4729,10 @@ class PGDialect(default.DefaultDialect):
             connection, schema, filter_names, scope, kind, **kw
         )
 
-        pg_am_dict = self._load_pg_am_dict(connection, **kw)
+        pg_am_btree_oid = self._load_pg_am_btree_oid(connection)
+        # lazy load only if needed, the assumption is that most indexes
+        # will use btree so it may not be needed at all
+        pg_am_dict = None
         pg_opclass_dict = self._load_pg_opclass_notdefault_dict(
             connection, **kw
         )
@@ -4858,9 +4853,14 @@ class PGDialect(default.DefaultDialect):
                     # reflection info.  But we don't want an Index object
                     # to have a ``postgresql_using`` in it that is just the
                     # default, so for the moment leaving this out.
-                    amname = pg_am_dict[row["relam"]]
-                    if amname != "btree":
-                        dialect_options["postgresql_using"] = amname
+                    if row["relam"] != pg_am_btree_oid:
+                        if pg_am_dict is None:
+                            pg_am_dict = self._load_pg_am_dict(
+                                connection, **kw
+                            )
+                        dialect_options["postgresql_using"] = pg_am_dict[
+                            row["relam"]
+                        ]
                     if row["filter_definition"]:
                         dialect_options["postgresql_where"] = row[
                             "filter_definition"
@@ -5278,6 +5278,14 @@ class PGDialect(default.DefaultDialect):
         rows = connection.execute(self._pg_am_query)
         return dict(rows.all())
 
+    def _load_pg_am_btree_oid(self, connection):
+        # this oid is assumed to be stable
+        if self._pg_am_btree_oid == -1:
+            self._pg_am_btree_oid = connection.scalar(
+                self._pg_am_query.where(pg_catalog.pg_am.c.amname == "btree")
+            )
+        return self._pg_am_btree_oid
+
     @util.memoized_property
     def _pg_opclass_notdefault_query(self):
         return sql.select(
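
The class-level -1 on PGDialect acts as a "not yet fetched" sentinel: the
first call stores the real OID as an instance attribute, shadowing the class
default, so the scalar query runs once per dialect instance (the comment
above notes the OID is assumed stable). A self-contained sketch of the
pattern, with fetch() standing in for the connection.scalar() call:

    class DialectDemo:
        _btree_oid = -1  # class-level sentinel shared by fresh instances

        def load_btree_oid(self, fetch):
            if self._btree_oid == -1:
                # instance attribute now shadows the class default
                self._btree_oid = fetch()
            return self._btree_oid

    d = DialectDemo()
    print(d.load_btree_oid(lambda: 403))  # fetches once
    print(d.load_btree_oid(lambda: 999))  # cached; the lambda never runs
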
@@ -5298,3 +5306,48 @@ class PGDialect(default.DefaultDialect):
             "show standard_conforming_strings"
         ).scalar()
         self._backslash_escapes = std_string == "off"
+
+
+class _NamedTypeLoader:
+    """Helper class used for deferred loading of named types (enums, domains)
+    only when needed.
+    """
+
+    def __init__(
+        self, dialect: PGDialect, connection, kw: Dict[str, Any]
+    ) -> None:
+        self.dialect = dialect
+        self.connection = connection
+        self.kw = kw
+
+    @util.memoized_property
+    def enums(self) -> Dict[Tuple[str] | Tuple[str, str], ReflectedEnum]:
+        # dictionary with (name, ) if default search path or (schema, name)
+        # as keys
+        enums = dict(
+            (
+                ((rec["name"],), rec)
+                if rec["visible"]
+                else ((rec["schema"], rec["name"]), rec)
+            )
+            for rec in self.dialect._load_enums(
+                self.connection,
+                schema="*",
+                info_cache=self.kw.get("info_cache"),
+            )
+        )
+        return enums
+
+    @util.memoized_property
+    def domains(self) -> Dict[Tuple[str] | Tuple[str, str], ReflectedDomain]:
+        # dictionary with (name, ) if default search path or (schema, name)
+        # as keys
+        domains = {
+            ((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d
+            for d in self.dialect._load_domains(
+                self.connection,
+                schema="*",
+                info_cache=self.kw.get("info_cache"),
+            )
+        }
+        return domains
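
The deferral above rides on util.memoized_property: the decorated function
runs on first attribute access, and its result then replaces the descriptor
in the instance's __dict__, so later accesses are plain attribute reads with
no query at all. A minimal demonstration of that behavior (the Demo class is
made up; the decorator is the real SQLAlchemy utility):

    from sqlalchemy import util

    class Demo:
        calls = 0

        @util.memoized_property
        def expensive(self):
            Demo.calls += 1   # stands in for the catalog query
            return {"loaded": True}

    d = Demo()
    assert Demo.calls == 0    # constructing the object issues no query
    d.expensive
    d.expensive
    assert Demo.calls == 1    # the body ran exactly once
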
test/dialect/postgresql/test_reflection.py
index 534c31a860dd40a638ee522139f43716a2fd5187..a8b933f74d13bdda4262e9dcb196d1bb893228ec 100644 (file)
@@ -2713,6 +2713,11 @@ class ReflectionTest(
 
 
 class CustomTypeReflectionTest(fixtures.TestBase):
+    class NTL:
+        def __init__(self, enums, domains):
+            self.enums = enums
+            self.domains = domains
+
     class CustomType:
         def __init__(self, arg1=None, arg2=None, collation=None):
             self.arg1 = arg1
@@ -2749,7 +2754,7 @@ class CustomTypeReflectionTest(fixtures.TestBase):
                 "identity_options": None,
             }
             column_info = dialect._get_columns_info(
-                [row_dict], {}, {}, "public"
+                [row_dict], self.NTL({}, {}), "public"
             )
             assert ("public", "tblname") in column_info
             column_info = column_info[("public", "tblname")]
@@ -2796,7 +2801,7 @@ class CustomTypeReflectionTest(fixtures.TestBase):
                 "identity_options": None,
             }
             column_info = dialect._get_columns_info(
-                [row_dict], {}, {}, "public"
+                [row_dict], self.NTL({}, {}), "public"
             )
             assert ("public", "tblname") in column_info
             column_info = column_info[("public", "tblname")]