]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Reflect index's column operator class on PostgreSQL
authorDenis Laxalde <denis@laxalde.org>
Wed, 28 May 2025 19:37:36 +0000 (15:37 -0400)
committersqla-tester <sqla-tester@sqlalchemy.org>
Wed, 28 May 2025 19:37:36 +0000 (15:37 -0400)
Fill the `postgresql_ops` key of PostgreSQL's `dialect_options` returned by get_multi_indexes() with a mapping from column names to the operator class, if it's not the default for respective data type.

As we need to join on ``pg_catalog.pg_opclass``, the table definition is added to ``postgresql.pg_catalog``.

Fixes #8664.

Closes: #12504
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12504
Pull-request-sha: 8fdf93e1b27c371f52990d5fda8b2fdf79ec23eb

Change-Id: I8789c1e9d15f8cc9a7205f492ec730570f19bbcc

doc/build/changelog/unreleased_20/8664.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/pg_catalog.py
test/dialect/postgresql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_20/8664.rst b/doc/build/changelog/unreleased_20/8664.rst
new file mode 100644 (file)
index 0000000..8a17e43
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: usecase, postgresql
+    :tickets: 8664
+
+    Added ``postgresql_ops`` key to the ``dialect_options`` entry in reflected
+    dictionary. This maps names of columns used in the index to respective
+    operator class, if distinct from the default one for column's data type.
+    Pull request courtesy Denis Laxalde.
+
+    .. seealso::
+
+        :ref:`postgresql_operator_classes`
index 805b8d37201eca25388aff20b8efa69765ababed..ed45360d853a257af67a23e6a09c787ad8615704 100644 (file)
@@ -4519,6 +4519,9 @@ class PGDialect(default.DefaultDialect):
                 pg_catalog.pg_index.c.indexrelid,
                 pg_catalog.pg_index.c.indrelid,
                 sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"),
+                sql.func.unnest(pg_catalog.pg_index.c.indclass).label(
+                    "att_opclass"
+                ),
                 sql.func.generate_subscripts(
                     pg_catalog.pg_index.c.indkey, 1
                 ).label("ord"),
@@ -4550,6 +4553,8 @@ class PGDialect(default.DefaultDialect):
                     else_=pg_catalog.pg_attribute.c.attname.cast(TEXT),
                 ).label("element"),
                 (idx_sq.c.attnum == 0).label("is_expr"),
+                pg_catalog.pg_opclass.c.opcname,
+                pg_catalog.pg_opclass.c.opcdefault,
             )
             .select_from(idx_sq)
             .outerjoin(
@@ -4560,6 +4565,10 @@ class PGDialect(default.DefaultDialect):
                     pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid,
                 ),
             )
+            .outerjoin(
+                pg_catalog.pg_opclass,
+                pg_catalog.pg_opclass.c.oid == idx_sq.c.att_opclass,
+            )
             .where(idx_sq.c.indrelid.in_(bindparam("oids")))
             .subquery("idx_attr")
         )
@@ -4574,6 +4583,12 @@ class PGDialect(default.DefaultDialect):
                 sql.func.array_agg(
                     aggregate_order_by(attr_sq.c.is_expr, attr_sq.c.ord)
                 ).label("elements_is_expr"),
+                sql.func.array_agg(
+                    aggregate_order_by(attr_sq.c.opcname, attr_sq.c.ord)
+                ).label("elements_opclass"),
+                sql.func.array_agg(
+                    aggregate_order_by(attr_sq.c.opcdefault, attr_sq.c.ord)
+                ).label("elements_opdefault"),
             )
             .group_by(attr_sq.c.indexrelid)
             .subquery("idx_cols")
@@ -4616,6 +4631,8 @@ class PGDialect(default.DefaultDialect):
                 nulls_not_distinct,
                 cols_sq.c.elements,
                 cols_sq.c.elements_is_expr,
+                cols_sq.c.elements_opclass,
+                cols_sq.c.elements_opdefault,
             )
             .select_from(pg_catalog.pg_index)
             .where(
@@ -4688,6 +4705,8 @@ class PGDialect(default.DefaultDialect):
 
                     all_elements = row["elements"]
                     all_elements_is_expr = row["elements_is_expr"]
+                    all_elements_opclass = row["elements_opclass"]
+                    all_elements_opdefault = row["elements_opdefault"]
                     indnkeyatts = row["indnkeyatts"]
                     # "The number of key columns in the index, not counting any
                     # included columns, which are merely stored and do not
@@ -4707,10 +4726,18 @@ class PGDialect(default.DefaultDialect):
                             not is_expr
                             for is_expr in all_elements_is_expr[indnkeyatts:]
                         )
+                        idx_elements_opclass = all_elements_opclass[
+                            :indnkeyatts
+                        ]
+                        idx_elements_opdefault = all_elements_opdefault[
+                            :indnkeyatts
+                        ]
                     else:
                         idx_elements = all_elements
                         idx_elements_is_expr = all_elements_is_expr
                         inc_cols = []
+                        idx_elements_opclass = all_elements_opclass
+                        idx_elements_opdefault = all_elements_opdefault
 
                     index = {"name": index_name, "unique": row["indisunique"]}
                     if any(idx_elements_is_expr):
@@ -4724,6 +4751,19 @@ class PGDialect(default.DefaultDialect):
                     else:
                         index["column_names"] = idx_elements
 
+                    dialect_options = {}
+
+                    if not all(idx_elements_opdefault):
+                        dialect_options["postgresql_ops"] = {
+                            name: opclass
+                            for name, opclass, is_default in zip(
+                                idx_elements,
+                                idx_elements_opclass,
+                                idx_elements_opdefault,
+                            )
+                            if not is_default
+                        }
+
                     sorting = {}
                     for col_index, col_flags in enumerate(row["indoption"]):
                         col_sorting = ()
@@ -4743,7 +4783,6 @@ class PGDialect(default.DefaultDialect):
                     if row["has_constraint"]:
                         index["duplicates_constraint"] = index_name
 
-                    dialect_options = {}
                     if row["reloptions"]:
                         dialect_options["postgresql_with"] = dict(
                             [
index 4841056cf9d1660a47c074906c6a71a7d5afecde..9625ccf3347e735c19bb072426a7bfdea4121590 100644 (file)
@@ -310,3 +310,17 @@ pg_collation = Table(
     Column("collicurules", Text, info={"server_version": (16,)}),
     Column("collversion", Text, info={"server_version": (10,)}),
 )
+
+pg_opclass = Table(
+    "pg_opclass",
+    pg_catalog_meta,
+    Column("oid", OID, info={"server_version": (9, 3)}),
+    Column("opcmethod", NAME),
+    Column("opcname", NAME),
+    Column("opsnamespace", OID),
+    Column("opsowner", OID),
+    Column("opcfamily", OID),
+    Column("opcintype", OID),
+    Column("opcdefault", Boolean),
+    Column("opckeytype", OID),
+)
index f8030691744f9a25bfcb2b620970b39e7943d428..5dd8e00070d6a2397ae6766176816ccda33323ee 100644 (file)
@@ -27,6 +27,7 @@ from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.dialects.postgresql import base as postgresql
 from sqlalchemy.dialects.postgresql import DOMAIN
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
+from sqlalchemy.dialects.postgresql import INET
 from sqlalchemy.dialects.postgresql import INTEGER
 from sqlalchemy.dialects.postgresql import INTERVAL
 from sqlalchemy.dialects.postgresql import pg_catalog
@@ -1724,6 +1725,54 @@ class ReflectionTest(
             "gin",
         )
 
+    def test_index_reflection_with_operator_class(self, metadata, connection):
+        """reflect indexes with operator class on columns"""
+
+        Table(
+            "t",
+            metadata,
+            Column("id", Integer, nullable=False),
+            Column("name", String),
+            Column("alias", String),
+            Column("addr1", INET),
+            Column("addr2", INET),
+        )
+        metadata.create_all(connection)
+
+        # 'name' and 'addr1' use a non-default operator, 'addr2' uses the
+        # default one, and 'alias' uses no operator.
+        connection.exec_driver_sql(
+            "CREATE INDEX ix_t ON t USING btree"
+            " (name text_pattern_ops, alias, addr1 cidr_ops, addr2 inet_ops)"
+        )
+
+        ind = inspect(connection).get_indexes("t", None)
+        expected = [
+            {
+                "unique": False,
+                "column_names": ["name", "alias", "addr1", "addr2"],
+                "name": "ix_t",
+                "dialect_options": {
+                    "postgresql_ops": {
+                        "addr1": "cidr_ops",
+                        "name": "text_pattern_ops",
+                    },
+                },
+            }
+        ]
+        if connection.dialect.server_version_info >= (11, 0):
+            expected[0]["include_columns"] = []
+            expected[0]["dialect_options"]["postgresql_include"] = []
+        eq_(ind, expected)
+
+        m = MetaData()
+        t1 = Table("t", m, autoload_with=connection)
+        r_ind = list(t1.indexes)[0]
+        eq_(
+            r_ind.dialect_options["postgresql"]["ops"],
+            {"name": "text_pattern_ops", "addr1": "cidr_ops"},
+        )
+
     @testing.skip_if("postgresql < 15.0", "nullsnotdistinct not supported")
     def test_nullsnotdistinct(self, metadata, connection):
         Table(