]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement get_table_options, get_multi_table_options in postgres dialect
authorThomas Stephenson <ovangle@gmail.com>
Mon, 12 May 2025 02:12:21 +0000 (12:12 +1000)
committerThomas Stephenson <ovangle@gmail.com>
Mon, 12 May 2025 03:16:57 +0000 (13:16 +1000)
Reflect dialect specific options passed to `Table` constructor

Table options reflected:
- postgresql_inherits
- postgresql_with
- postgresql_with_oids (in supported versions)
- postgresql_using
- postgresql_tablespace

Fixes #10909

lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_reflection.py

index b98c0d69630ea45ee8c38b7bfc5b1f4e9c83c2ad..1e800215e9e8dbb44b1cc192907bf09c2fd932e0 100644 (file)
@@ -3687,6 +3687,144 @@ class PGDialect(default.DefaultDialect):
             relkinds += pg_catalog.RELKINDS_MAT_VIEW
         return relkinds
 
+    @reflection.cache
+    def get_table_options(self, connection, table_name, schema=None, **kw):
+        data = self.get_multi_table_options(
+            connection,
+            schema=schema,
+            filter_names=[table_name],
+            scope=ObjectScope.ANY,
+            kind=ObjectKind.ANY,
+            **kw,
+        )
+        return self._value_or_raise(data, table_name, schema)
+
+    @util.memoized_property
+    def _table_options_query(self):
+        inherits_sq = (
+            select(
+                pg_catalog.pg_inherits.c.inhrelid,
+                pg_catalog.pg_inherits.c.inhseqno,
+                pg_catalog.pg_class.c.relname.label("parent_table"),
+            )
+            .select_from(pg_catalog.pg_inherits)
+            .join(
+                pg_catalog.pg_class,
+                pg_catalog.pg_inherits.c.inhparent
+                == pg_catalog.pg_class.c.oid,
+            )
+            .where(pg_catalog.pg_inherits.c.inhrelid.in_(bindparam("oids")))
+            .subquery("inherits")
+        )
+
+        return (
+            select(
+                pg_catalog.pg_class.c.oid,
+                pg_catalog.pg_class.c.relname,
+                pg_catalog.pg_class.c.reloptions,
+                sql.func.min(pg_catalog.pg_am.c.amname).label(
+                    "access_method_name"
+                ),
+                sql.func.min(pg_catalog.pg_tablespace.c.spcname).label(
+                    "tablespace_name"
+                ),
+                sql.func.array_remove(
+                    sql.func.array_agg(
+                        aggregate_order_by(
+                            inherits_sq.c.parent_table,
+                            inherits_sq.c.inhseqno.asc(),
+                        )
+                    ),
+                    None,
+                ).label("parent_table_names"),
+            )
+            .select_from(pg_catalog.pg_class)
+            .join(
+                pg_catalog.pg_am,
+                sql.and_(
+                    pg_catalog.pg_class.c.relam == pg_catalog.pg_am.c.oid,
+                    pg_catalog.pg_am.c.amtype == "t",
+                ),
+            )
+            .outerjoin(
+                inherits_sq,
+                pg_catalog.pg_class.c.oid == inherits_sq.c.inhrelid,
+            )
+            .outerjoin(
+                pg_catalog.pg_tablespace,
+                pg_catalog.pg_tablespace.c.oid
+                == pg_catalog.pg_class.c.reltablespace,
+            )
+            .group_by(pg_catalog.pg_class.c.oid)
+            .where(pg_catalog.pg_class.c.oid.in_(bindparam("oids")))
+        )
+
+    def get_multi_table_options(
+        self, connection, schema, filter_names, scope, kind, **kw
+    ):
+        table_oids = self._get_table_oids(
+            connection, schema, filter_names, scope, kind, **kw
+        )
+
+        table_options = {}
+        default = ReflectionDefaults.table_options
+
+        batches = list(table_oids)
+
+        while batches:
+            batch = batches[0:3000]
+            batches[0:3000] = []
+
+            print("fetching oids", batch)
+
+            result = connection.execute(
+                self._table_options_query, {"oids": [r[0] for r in batch]}
+            ).mappings()
+
+            result_by_oid = {}
+
+            for row_dict in result:
+                result_by_oid[row_dict["oid"]] = row_dict
+
+            print("result_by_oid", result_by_oid)
+
+            for oid, tablename in batch:
+                if oid not in result_by_oid:
+                    table_options[(schema, tablename)] = default()
+                    continue
+
+                result = result_by_oid[oid]
+                this_table_options: dict[str, Any] = {
+                    "postgresql_inherits": tuple(result["parent_table_names"]),
+                    "postgresql_using": result["access_method_name"],
+                }
+
+                if result["reloptions"]:
+                    # TODO: Compiler should (properly) support with.
+                    this_table_options["postgresql_with"] = dict(
+                        option.split("=", 1) for option in result["reloptions"]
+                    )
+
+                if self.server_version_info and self.server_version_info < (
+                    12,
+                ):
+                    if isinstance(this_table_options["postgresql_with"], dict):
+                        table_storage_params = this_table_options[
+                            "postgresql_with"
+                        ]
+                        result["postgresql_with_oids"] = (
+                            table_storage_params["OIDS"].lower() == "true"
+                        )
+
+                if result["tablespace_name"] is not None:
+                    table_options["postgresql_tablespace"] = result[
+                        "tablespace_name"
+                    ]
+
+                table_options[(schema, tablename)] = this_table_options
+
+        return table_options.items()
+
     @reflection.cache
     def get_columns(self, connection, table_name, schema=None, **kw):
         data = self.get_multi_columns(
index ebe751b5b348831ebba2fe992fbd6748ca8414d4..3c3794bbb2e584c10caaa4b12c6cd254de312fad 100644 (file)
@@ -2924,3 +2924,116 @@ class TestReflectDifficultColTypes(fixtures.TablesTest):
         is_true(len(rows) > 0)
         for row in rows:
             self.check_int_list(row, "conkey")
+
+
+class TestTableOptionsReflection(fixtures.TestBase):
+    __only_on__ = "postgresql"
+    __backend__ = True
+
+    def test_table_inherits(self, metadata, connection):
+        def assert_inherits_from(table_name, expect_base_tables):
+            table_options = inspect(connection).get_table_options(table_name)
+            eq_(table_options["postgresql_inherits"], expect_base_tables)
+
+        def assert_column_names(table_name, expect_columns):
+            columns = inspect(connection).get_columns(table_name)
+            print("columns", columns)
+            eq_([c["name"] for c in columns], expect_columns)
+
+        Table("base", metadata, Column("id", INTEGER, primary_key=True))
+        Table("name_mixin", metadata, Column("name", String(16)))
+        Table("single_inherits", metadata, postgresql_inherits="base")
+        Table(
+            "single_inherits_tuple_arg",
+            metadata,
+            postgresql_inherits=("base",),
+        )
+        Table(
+            "inherits_mixin",
+            metadata,
+            postgresql_inherits=("base", "name_mixin"),
+        )
+
+        metadata.create_all(connection)
+
+        assert_inherits_from("base", ())
+        assert_inherits_from("name_mixin", ())
+
+        assert_inherits_from("single_inherits", ("base",))
+        assert_column_names("single_inherits", ["id"])
+
+        assert_inherits_from("single_inherits_tuple_arg", ("base",))
+
+        assert_inherits_from("inherits_mixin", ("base", "name_mixin"))
+        assert_column_names("inherits_mixin", ["id", "name"])
+
+    def test_table_storage_params(self, metadata, connection):
+        def assert_has_storage_param(table_name, option_key, option_value):
+            table_options = inspect(connection).get_table_options(table_name)
+            storage_params = table_options["postgresql_with"]
+            assert isinstance(storage_params, dict)
+            eq_(storage_params[option_key], option_value)
+
+        Table("table_no_storage_params", metadata)
+        Table(
+            "table_with_fillfactor",
+            metadata,
+            postgresql_with={"fillfactor": 10},
+        )
+        Table(
+            "table_with_parallel_workers",
+            metadata,
+            postgresql_with={"parallel_workers": 15},
+        )
+
+        metadata.create_all(connection)
+
+        no_params_options = inspect(connection).get_table_options(
+            "table_no_storage_params"
+        )
+        assert "postgresql_with" in no_params_options
+
+        assert_has_storage_param("table_with_fillfactor", "fillfactor", "10")
+        assert_has_storage_param(
+            "table_with_parallel_workers", "parallel_workers", "15"
+        )
+
+    # @testing.skip_if("postgresql >= 12.0", "with_oids not supported")
+    # def test_table_with_oids(self, metadata, connection):
+    #     Table("table_with_oids", metadata, postgresql_with_oids=True)
+    #     Table("table_without_oids", metadata, postgresql_with_oids=False)
+    #     metadata.create_all(connection)
+
+    #     table_options = inspect(connection).get_table_options("table_with_oids")
+    #     eq_(table_options["postgresql_with_oids"], True)
+
+    #     table_options = inspect(connection).get_table_options("table_without_oids")
+    #     eq_(table_options["postgresql_with_oids"], False)
+
+    def test_table_using(self, metadata, connection):
+        Table("table_using_heap", metadata, postgresql_using="heap")
+        Table("heap_is_default", metadata)
+        metadata.create_all(connection)
+
+        table_options = inspect(connection).get_table_options(
+            "table_using_heap"
+        )
+        print("table_options", table_options)
+        eq_(table_options["postgresql_using"], "heap")
+
+        table_options = inspect(connection).get_table_options(
+            "heap_is_default"
+        )
+        eq_(table_options["postgresql_using"], "heap")
+
+        # TODO: Test custom access method.
+
+        # self.define_table(metadata, "table_using_btree", postgresql_using="btree")
+        # table_options = inspect(connection).get_table_options("table_using_btree")
+        # eq_(table_options["using"], "btree")
+
+    # def test_table_option_tablespace(self, metadata, connection):
+    #     self.define_simple_table(metadata, "table_sample_tablespace", postgresql_tablespace="sample_tablespace")
+
+    #     table_options = inspect(connection).get_table_options("table_sample_tablespace")
+    #     eq_(table_options["tablespace"], "sample_tablespace")