]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Simplify postgresql index reflection query
authorFederico Caselli <cfederico87@gmail.com>
Wed, 28 May 2025 20:03:51 +0000 (22:03 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 3 Jun 2025 20:15:31 +0000 (20:15 +0000)
Match on python side the values of `pg_am` and `pg_opclass`
to avoid joining them in the main query.
Since both queries have a limited size and are generally
stable their value can be cached using the inspector
cache.

Change-Id: I7074e88dc9ffb8f9c53c3cc12f1a7b72eec7fe8c

lib/sqlalchemy/dialects/postgresql/base.py

index ed45360d853a257af67a23e6a09c787ad8615704..aa45d898916deba5e4b8564cc764b3c436da5dba 100644 (file)
@@ -4553,8 +4553,10 @@ class PGDialect(default.DefaultDialect):
                     else_=pg_catalog.pg_attribute.c.attname.cast(TEXT),
                 ).label("element"),
                 (idx_sq.c.attnum == 0).label("is_expr"),
-                pg_catalog.pg_opclass.c.opcname,
-                pg_catalog.pg_opclass.c.opcdefault,
+                # since it's converted to array cast it to bigint (oid are
+                # "unsigned four-byte integer") to make it earier for
+                # dialects to iterpret
+                idx_sq.c.att_opclass.cast(BIGINT),
             )
             .select_from(idx_sq)
             .outerjoin(
@@ -4565,10 +4567,6 @@ class PGDialect(default.DefaultDialect):
                     pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid,
                 ),
             )
-            .outerjoin(
-                pg_catalog.pg_opclass,
-                pg_catalog.pg_opclass.c.oid == idx_sq.c.att_opclass,
-            )
             .where(idx_sq.c.indrelid.in_(bindparam("oids")))
             .subquery("idx_attr")
         )
@@ -4584,11 +4582,8 @@ class PGDialect(default.DefaultDialect):
                     aggregate_order_by(attr_sq.c.is_expr, attr_sq.c.ord)
                 ).label("elements_is_expr"),
                 sql.func.array_agg(
-                    aggregate_order_by(attr_sq.c.opcname, attr_sq.c.ord)
+                    aggregate_order_by(attr_sq.c.att_opclass, attr_sq.c.ord)
                 ).label("elements_opclass"),
-                sql.func.array_agg(
-                    aggregate_order_by(attr_sq.c.opcdefault, attr_sq.c.ord)
-                ).label("elements_opdefault"),
             )
             .group_by(attr_sq.c.indexrelid)
             .subquery("idx_cols")
@@ -4614,7 +4609,8 @@ class PGDialect(default.DefaultDialect):
                 ),
                 pg_catalog.pg_index.c.indoption,
                 pg_catalog.pg_class.c.reloptions,
-                pg_catalog.pg_am.c.amname,
+                # will get the value using the pg_am cached dict
+                pg_catalog.pg_class.c.relam,
                 # NOTE: pg_get_expr is very fast so this case has almost no
                 # performance impact
                 sql.case(
@@ -4631,8 +4627,8 @@ class PGDialect(default.DefaultDialect):
                 nulls_not_distinct,
                 cols_sq.c.elements,
                 cols_sq.c.elements_is_expr,
+                # will get the value using the pg_opclass cached dict
                 cols_sq.c.elements_opclass,
-                cols_sq.c.elements_opdefault,
             )
             .select_from(pg_catalog.pg_index)
             .where(
@@ -4643,10 +4639,6 @@ class PGDialect(default.DefaultDialect):
                 pg_catalog.pg_class,
                 pg_catalog.pg_index.c.indexrelid == pg_catalog.pg_class.c.oid,
             )
-            .join(
-                pg_catalog.pg_am,
-                pg_catalog.pg_class.c.relam == pg_catalog.pg_am.c.oid,
-            )
             .outerjoin(
                 cols_sq,
                 pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid,
@@ -4674,6 +4666,11 @@ class PGDialect(default.DefaultDialect):
             connection, schema, filter_names, scope, kind, **kw
         )
 
+        pg_am_dict = self._load_pg_am_dict(connection, **kw)
+        pg_opclass_dict = self._load_pg_opclass_notdefault_dict(
+            connection, **kw
+        )
+
         indexes = defaultdict(list)
         default = ReflectionDefaults.indexes
 
@@ -4706,7 +4703,6 @@ class PGDialect(default.DefaultDialect):
                     all_elements = row["elements"]
                     all_elements_is_expr = row["elements_is_expr"]
                     all_elements_opclass = row["elements_opclass"]
-                    all_elements_opdefault = row["elements_opdefault"]
                     indnkeyatts = row["indnkeyatts"]
                     # "The number of key columns in the index, not counting any
                     # included columns, which are merely stored and do not
@@ -4729,15 +4725,11 @@ class PGDialect(default.DefaultDialect):
                         idx_elements_opclass = all_elements_opclass[
                             :indnkeyatts
                         ]
-                        idx_elements_opdefault = all_elements_opdefault[
-                            :indnkeyatts
-                        ]
                     else:
                         idx_elements = all_elements
                         idx_elements_is_expr = all_elements_is_expr
                         inc_cols = []
                         idx_elements_opclass = all_elements_opclass
-                        idx_elements_opdefault = all_elements_opdefault
 
                     index = {"name": index_name, "unique": row["indisunique"]}
                     if any(idx_elements_is_expr):
@@ -4753,16 +4745,17 @@ class PGDialect(default.DefaultDialect):
 
                     dialect_options = {}
 
-                    if not all(idx_elements_opdefault):
-                        dialect_options["postgresql_ops"] = {
-                            name: opclass
-                            for name, opclass, is_default in zip(
-                                idx_elements,
-                                idx_elements_opclass,
-                                idx_elements_opdefault,
-                            )
-                            if not is_default
-                        }
+                    postgresql_ops = {}
+                    for name, opclass in zip(
+                        idx_elements, idx_elements_opclass
+                    ):
+                        # is not in the dict if the opclass is the default one
+                        opclass_name = pg_opclass_dict.get(opclass)
+                        if opclass_name is not None:
+                            postgresql_ops[name] = opclass_name
+
+                    if postgresql_ops:
+                        dialect_options["postgresql_ops"] = postgresql_ops
 
                     sorting = {}
                     for col_index, col_flags in enumerate(row["indoption"]):
@@ -4794,9 +4787,9 @@ class PGDialect(default.DefaultDialect):
                     # reflection info.  But we don't want an Index object
                     # to have a ``postgresql_using`` in it that is just the
                     # default, so for the moment leaving this out.
-                    amname = row["amname"]
+                    amname = pg_am_dict[row["relam"]]
                     if amname != "btree":
-                        dialect_options["postgresql_using"] = row["amname"]
+                        dialect_options["postgresql_using"] = amname
                     if row["filter_definition"]:
                         dialect_options["postgresql_where"] = row[
                             "filter_definition"
@@ -5205,6 +5198,28 @@ class PGDialect(default.DefaultDialect):
 
         return domains
 
+    @util.memoized_property
+    def _pg_am_query(self):
+        return sql.select(pg_catalog.pg_am.c.oid, pg_catalog.pg_am.c.amname)
+
+    @reflection.cache
+    def _load_pg_am_dict(self, connection, **kw) -> dict[int, str]:
+        rows = connection.execute(self._pg_am_query)
+        return dict(rows.all())
+
+    @util.memoized_property
+    def _pg_opclass_notdefault_query(self):
+        return sql.select(
+            pg_catalog.pg_opclass.c.oid, pg_catalog.pg_opclass.c.opcname
+        ).where(~pg_catalog.pg_opclass.c.opcdefault)
+
+    @reflection.cache
+    def _load_pg_opclass_notdefault_dict(
+        self, connection, **kw
+    ) -> dict[int, str]:
+        rows = connection.execute(self._pg_opclass_notdefault_query)
+        return dict(rows.all())
+
     def _set_backslash_escapes(self, connection):
         # this method is provided as an override hook for descendant
         # dialects (e.g. Redshift), so removing it may break them