From a92942a5313c323afc027f69ed3a92cfe818cf76 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 31 Mar 2021 11:48:22 -0400 Subject: [PATCH] Repair PGInspector Fixed issue where the PostgreSQL :class:`.PGInspector`, when generated against an :class:`_engine.Engine`, would fail for ``.get_enums()``, ``.get_view_names()``, ``.get_foreign_table_names()`` and ``.get_table_oid()`` when used against a "future" style engine and not the connection directly. Fixes: #6170 Change-Id: I8c3abdfb758305c2f7a96002d3644729f29c998b --- doc/build/changelog/unreleased_14/6170.rst | 9 ++++ lib/sqlalchemy/dialects/postgresql/base.py | 26 +++++---- test/dialect/postgresql/test_reflection.py | 63 ++++++++++++++++++---- 3 files changed, 79 insertions(+), 19 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6170.rst diff --git a/doc/build/changelog/unreleased_14/6170.rst b/doc/build/changelog/unreleased_14/6170.rst new file mode 100644 index 0000000000..991eb16350 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6170.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, postgresql + :tickets: 6170 + + Fixed issue where the PostgreSQL :class:`.PGInspector`, when generated + against an :class:`_engine.Engine`, would fail for ``.get_enums()``, + ``.get_view_names()``, ``.get_foreign_table_names()`` and + ``.get_table_oid()`` when used against a "future" style engine and not the + connection directly. diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index ba77745907..0854214d02 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2902,9 +2902,10 @@ class PGInspector(reflection.Inspector): def get_table_oid(self, table_name, schema=None): """Return the OID for the given table name.""" - return self.dialect.get_table_oid( - self.bind, table_name, schema, info_cache=self.info_cache - ) + with self._operation_context() as conn: + return self.dialect.get_table_oid( + conn, table_name, schema, info_cache=self.info_cache + ) def get_enums(self, schema=None): """Return a list of ENUM objects. @@ -2925,7 +2926,8 @@ class PGInspector(reflection.Inspector): """ schema = schema or self.default_schema_name - return self.dialect._load_enums(self.bind, schema) + with self._operation_context() as conn: + return self.dialect._load_enums(conn, schema) def get_foreign_table_names(self, schema=None): """Return a list of FOREIGN TABLE names. @@ -2939,7 +2941,8 @@ class PGInspector(reflection.Inspector): """ schema = schema or self.default_schema_name - return self.dialect._get_foreign_table_names(self.bind, schema) + with self._operation_context() as conn: + return self.dialect._get_foreign_table_names(conn, schema) def get_view_names(self, schema=None, include=("plain", "materialized")): """Return all view names in `schema`. @@ -2955,9 +2958,10 @@ class PGInspector(reflection.Inspector): """ - return self.dialect.get_view_names( - self.bind, schema, info_cache=self.info_cache, include=include - ) + with self._operation_context() as conn: + return self.dialect.get_view_names( + conn, schema, info_cache=self.info_cache, include=include + ) class CreateEnumType(schema._CreateDropBase): @@ -3481,7 +3485,11 @@ class PGDialect(default.DefaultDialect): "JOIN pg_namespace n ON n.oid = c.relnamespace " "WHERE n.nspname = :schema AND c.relkind = 'f'" ).columns(relname=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name, + dict( + schema=schema + if schema is not None + else self.default_schema_name + ), ) return [name for name, in result] diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index a7876a766a..4b6d927b37 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -40,7 +40,33 @@ from sqlalchemy.testing.assertions import is_ from sqlalchemy.testing.assertions import is_true -class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults): +class ReflectionFixtures(object): + @testing.fixture( + params=[ + ("engine", True), + ("connection", True), + ("engine", False), + ("connection", False), + ] + ) + def inspect_fixture(self, request, metadata, testing_engine): + engine, future = request.param + + eng = testing_engine(future=future) + + conn = eng.connect() + + if engine == "connection": + yield inspect(eng), conn + else: + yield inspect(conn), conn + + conn.close() + + +class ForeignTableReflectionTest( + ReflectionFixtures, fixtures.TablesTest, AssertsExecutionResults +): """Test reflection on foreign tables""" __requires__ = ("postgresql_test_dblink",) @@ -90,8 +116,9 @@ class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults): "Columns of reflected foreign table didn't equal expected columns", ) - def test_get_foreign_table_names(self, connection): - inspector = inspect(connection) + def test_get_foreign_table_names(self, inspect_fixture): + inspector, conn = inspect_fixture + ft_names = inspector.get_foreign_table_names() eq_(ft_names, ["test_foreigntable"]) @@ -179,7 +206,7 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults): class MaterializedViewReflectionTest( - fixtures.TablesTest, AssertsExecutionResults + ReflectionFixtures, fixtures.TablesTest, AssertsExecutionResults ): """Test reflection on materialized views""" @@ -233,8 +260,8 @@ class MaterializedViewReflectionTest( table = Table("test_mview", metadata, autoload_with=connection) eq_(connection.execute(table.select()).fetchall(), [(89, "d1")]) - def test_get_view_names(self, connection): - insp = inspect(connection) + def test_get_view_names(self, inspect_fixture): + insp, conn = inspect_fixture eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) def test_get_view_names_plain(self, connection): @@ -471,7 +498,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): base.PGDialect.ischema_names = ischema_names -class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): +class ReflectionTest( + ReflectionFixtures, AssertsCompiledSQL, fixtures.TestBase +): __only_on__ = "postgresql" __backend__ = True @@ -1302,12 +1331,17 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) - def test_inspect_enums(self, metadata, connection): + def test_inspect_enums(self, metadata, inspect_fixture): + + inspector, conn = inspect_fixture + enum_type = postgresql.ENUM( "cat", "dog", "rat", name="pet", metadata=metadata ) - enum_type.create(connection) - inspector = inspect(connection) + + with conn.begin(): + enum_type.create(conn) + eq_( inspector.get_enums(), [ @@ -1320,6 +1354,15 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase): ], ) + def test_get_table_oid(self, metadata, inspect_fixture): + + inspector, conn = inspect_fixture + + with conn.begin(): + Table("some_table", metadata, Column("q", Integer)).create(conn) + + assert inspector.get_table_oid("some_table") is not None + def test_inspect_enums_case_sensitive(self, metadata, connection): sa.event.listen( metadata, -- 2.47.2