]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Repair PGInspector
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Mar 2021 15:48:22 +0000 (11:48 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 31 Mar 2021 15:50:26 +0000 (11:50 -0400)
Fixed issue where the PostgreSQL :class:`.PGInspector`, when generated
against an :class:`_engine.Engine`, would fail for ``.get_enums()``,
``.get_view_names()``, ``.get_foreign_table_names()`` and
``.get_table_oid()`` when used against a "future" style engine and not the
connection directly.

Fixes: #6170
Change-Id: I8c3abdfb758305c2f7a96002d3644729f29c998b

doc/build/changelog/unreleased_14/6170.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_reflection.py

diff --git a/doc/build/changelog/unreleased_14/6170.rst b/doc/build/changelog/unreleased_14/6170.rst
new file mode 100644 (file)
index 0000000..991eb16
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 6170
+
+    Fixed issue where the PostgreSQL :class:`.PGInspector`, when generated
+    against an :class:`_engine.Engine`, would fail for ``.get_enums()``,
+    ``.get_view_names()``, ``.get_foreign_table_names()`` and
+    ``.get_table_oid()`` when used against a "future" style engine and not the
+    connection directly.
index ba777459073afaf0a87c2ec09bd12b88d93cd242..0854214d029ac639c375a6290007179455f275d1 100644 (file)
@@ -2902,9 +2902,10 @@ class PGInspector(reflection.Inspector):
     def get_table_oid(self, table_name, schema=None):
         """Return the OID for the given table name."""
 
-        return self.dialect.get_table_oid(
-            self.bind, table_name, schema, info_cache=self.info_cache
-        )
+        with self._operation_context() as conn:
+            return self.dialect.get_table_oid(
+                conn, table_name, schema, info_cache=self.info_cache
+            )
 
     def get_enums(self, schema=None):
         """Return a list of ENUM objects.
@@ -2925,7 +2926,8 @@ class PGInspector(reflection.Inspector):
 
         """
         schema = schema or self.default_schema_name
-        return self.dialect._load_enums(self.bind, schema)
+        with self._operation_context() as conn:
+            return self.dialect._load_enums(conn, schema)
 
     def get_foreign_table_names(self, schema=None):
         """Return a list of FOREIGN TABLE names.
@@ -2939,7 +2941,8 @@ class PGInspector(reflection.Inspector):
 
         """
         schema = schema or self.default_schema_name
-        return self.dialect._get_foreign_table_names(self.bind, schema)
+        with self._operation_context() as conn:
+            return self.dialect._get_foreign_table_names(conn, schema)
 
     def get_view_names(self, schema=None, include=("plain", "materialized")):
         """Return all view names in `schema`.
@@ -2955,9 +2958,10 @@ class PGInspector(reflection.Inspector):
 
         """
 
-        return self.dialect.get_view_names(
-            self.bind, schema, info_cache=self.info_cache, include=include
-        )
+        with self._operation_context() as conn:
+            return self.dialect.get_view_names(
+                conn, schema, info_cache=self.info_cache, include=include
+            )
 
 
 class CreateEnumType(schema._CreateDropBase):
@@ -3481,7 +3485,11 @@ class PGDialect(default.DefaultDialect):
                 "JOIN pg_namespace n ON n.oid = c.relnamespace "
                 "WHERE n.nspname = :schema AND c.relkind = 'f'"
             ).columns(relname=sqltypes.Unicode),
-            schema=schema if schema is not None else self.default_schema_name,
+            dict(
+                schema=schema
+                if schema is not None
+                else self.default_schema_name
+            ),
         )
         return [name for name, in result]
 
index a7876a766a40ec44bd3d4d8f661e693c625a8946..4b6d927b37dd2e9162c9ca652d6420f2a39dd0ff 100644 (file)
@@ -40,7 +40,33 @@ from sqlalchemy.testing.assertions import is_
 from sqlalchemy.testing.assertions import is_true
 
 
-class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
+class ReflectionFixtures(object):
+    @testing.fixture(
+        params=[
+            ("engine", True),
+            ("connection", True),
+            ("engine", False),
+            ("connection", False),
+        ]
+    )
+    def inspect_fixture(self, request, metadata, testing_engine):
+        engine, future = request.param
+
+        eng = testing_engine(future=future)
+
+        conn = eng.connect()
+
+        if engine == "connection":
+            yield inspect(eng), conn
+        else:
+            yield inspect(conn), conn
+
+        conn.close()
+
+
+class ForeignTableReflectionTest(
+    ReflectionFixtures, fixtures.TablesTest, AssertsExecutionResults
+):
     """Test reflection on foreign tables"""
 
     __requires__ = ("postgresql_test_dblink",)
@@ -90,8 +116,9 @@ class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
             "Columns of reflected foreign table didn't equal expected columns",
         )
 
-    def test_get_foreign_table_names(self, connection):
-        inspector = inspect(connection)
+    def test_get_foreign_table_names(self, inspect_fixture):
+        inspector, conn = inspect_fixture
+
         ft_names = inspector.get_foreign_table_names()
         eq_(ft_names, ["test_foreigntable"])
 
@@ -179,7 +206,7 @@ class PartitionedReflectionTest(fixtures.TablesTest, AssertsExecutionResults):
 
 
 class MaterializedViewReflectionTest(
-    fixtures.TablesTest, AssertsExecutionResults
+    ReflectionFixtures, fixtures.TablesTest, AssertsExecutionResults
 ):
     """Test reflection on materialized views"""
 
@@ -233,8 +260,8 @@ class MaterializedViewReflectionTest(
         table = Table("test_mview", metadata, autoload_with=connection)
         eq_(connection.execute(table.select()).fetchall(), [(89, "d1")])
 
-    def test_get_view_names(self, connection):
-        insp = inspect(connection)
+    def test_get_view_names(self, inspect_fixture):
+        insp, conn = inspect_fixture
         eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"]))
 
     def test_get_view_names_plain(self, connection):
@@ -471,7 +498,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             base.PGDialect.ischema_names = ischema_names
 
 
-class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
+class ReflectionTest(
+    ReflectionFixtures, AssertsCompiledSQL, fixtures.TestBase
+):
     __only_on__ = "postgresql"
     __backend__ = True
 
@@ -1302,12 +1331,17 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             ],
         )
 
-    def test_inspect_enums(self, metadata, connection):
+    def test_inspect_enums(self, metadata, inspect_fixture):
+
+        inspector, conn = inspect_fixture
+
         enum_type = postgresql.ENUM(
             "cat", "dog", "rat", name="pet", metadata=metadata
         )
-        enum_type.create(connection)
-        inspector = inspect(connection)
+
+        with conn.begin():
+            enum_type.create(conn)
+
         eq_(
             inspector.get_enums(),
             [
@@ -1320,6 +1354,15 @@ class ReflectionTest(AssertsCompiledSQL, fixtures.TestBase):
             ],
         )
 
+    def test_get_table_oid(self, metadata, inspect_fixture):
+
+        inspector, conn = inspect_fixture
+
+        with conn.begin():
+            Table("some_table", metadata, Column("q", Integer)).create(conn)
+
+        assert inspector.get_table_oid("some_table") is not None
+
     def test_inspect_enums_case_sensitive(self, metadata, connection):
         sa.event.listen(
             metadata,