]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add reflection arguments, engine/conn bind to DeferredReflection.prepare
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 24 May 2023 02:21:30 +0000 (22:21 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 24 May 2023 04:13:13 +0000 (00:13 -0400)
Improved :meth:`.DeferredReflection.prepare` to accept arbitrary ``**kw``
arguments that are passed to :meth:`_schema.MetaData.reflect`, allowing use
cases such as reflection of views as well as dialect-specific arguments to
be passed. Additionally, modernized the
:paramref:`.DeferredReflection.prepare.bind` argument so that either an
:class:`.Engine` or :class:`.Connection` are accepted as the "bind"
argument.

Fixes: #9828
Change-Id: Ie93cd1147611a92f07d05df8a39052f61ee692f2

doc/build/changelog/unreleased_20/9828.rst [new file with mode: 0644]
lib/sqlalchemy/ext/declarative/extensions.py
test/ext/declarative/test_reflection.py

diff --git a/doc/build/changelog/unreleased_20/9828.rst b/doc/build/changelog/unreleased_20/9828.rst
new file mode 100644 (file)
index 0000000..b6fa255
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 9828
+
+    Improved :meth:`.DeferredReflection.prepare` to accept arbitrary ``**kw``
+    arguments that are passed to :meth:`_schema.MetaData.reflect`, allowing use
+    cases such as reflection of views as well as dialect-specific arguments to
+    be passed. Additionally, modernized the
+    :paramref:`.DeferredReflection.prepare.bind` argument so that either an
+    :class:`.Engine` or :class:`.Connection` are accepted as the "bind"
+    argument.
index 2cb55a5ae8835dd9bc4996b68f450bfe8d85a77a..62a0e0754059c91e98df063e7f59fce3190777c0 100644 (file)
 from __future__ import annotations
 
 import collections
+import contextlib
+from typing import Any
 from typing import Callable
 from typing import TYPE_CHECKING
+from typing import Union
 
+from ... import exc as sa_exc
+from ...engine import Connection
+from ...engine import Engine
 from ...orm import exc as orm_exc
 from ...orm import relationships
 from ...orm.base import _mapper_or_none
@@ -414,9 +420,25 @@ class DeferredReflection:
     """
 
     @classmethod
-    def prepare(cls, engine):
-        """Reflect all :class:`_schema.Table` objects for all current
-        :class:`.DeferredReflection` subclasses"""
+    def prepare(
+        cls, bind: Union[Engine, Connection], **reflect_kw: Any
+    ) -> None:
+        r"""Reflect all :class:`_schema.Table` objects for all current
+        :class:`.DeferredReflection` subclasses
+
+        :param bind: :class:`_engine.Engine` or :class:`_engine.Connection`
+         instance
+
+         ..versionchanged:: 2.0.16 a :class:`_engine.Connection` is also
+         accepted.
+
+        :param \**reflect_kw: additional keyword arguments passed to
+         :meth:`_schema.MetaData.reflect`, such as
+         :paramref:`_schema.MetaData.reflect.views`.
+
+         .. versionadded:: 2.0.16
+
+        """
 
         to_map = _DeferredMapperConfig.classes_for_base(cls)
 
@@ -432,7 +454,18 @@ class DeferredReflection:
                 ].add(thingy.local_table.name)
 
         # then reflect all those tables into their metadatas
-        with engine.connect() as conn:
+
+        if isinstance(bind, Connection):
+            conn = bind
+            ctx = contextlib.nullcontext(enter_result=conn)
+        elif isinstance(bind, Engine):
+            ctx = bind.connect()
+        else:
+            raise sa_exc.ArgumentError(
+                f"Expected Engine or Connection, got {bind!r}"
+            )
+
+        with ctx as conn:
             for (metadata, schema), table_names in metadata_to_table.items():
                 metadata.reflect(
                     conn,
@@ -440,6 +473,7 @@ class DeferredReflection:
                     schema=schema,
                     extend_existing=True,
                     autoload_replace=False,
+                    **reflect_kw,
                 )
 
             metadata_to_table.clear()
index 53f518a27f7a0796160b50e56863fbc6275d99d1..1a4effa6449bbed9a5d54b244d615da610a1ffe5 100644 (file)
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
+from sqlalchemy import DDL
+from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import Integer
+from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy.ext.declarative import DeferredReflection
@@ -81,7 +84,7 @@ class DeferredReflectPKFKTest(DeferredReflectBase):
         DeferredReflection.prepare(testing.db)
 
 
-class DeferredReflectionTest(DeferredReflectBase):
+class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase):
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -148,7 +151,8 @@ class DeferredReflectionTest(DeferredReflectBase):
             User,
         )
 
-    def test_basic_deferred(self):
+    @testing.variation("bind", ["engine", "connection", "raise_"])
+    def test_basic_deferred(self, bind):
         class User(DeferredReflection, fixtures.ComparableEntity, Base):
             __tablename__ = "users"
             addresses = relationship("Address", backref="user")
@@ -156,9 +160,58 @@ class DeferredReflectionTest(DeferredReflectBase):
         class Address(DeferredReflection, fixtures.ComparableEntity, Base):
             __tablename__ = "addresses"
 
-        DeferredReflection.prepare(testing.db)
+        if bind.engine:
+            DeferredReflection.prepare(testing.db)
+        elif bind.connection:
+            with testing.db.connect() as conn:
+                DeferredReflection.prepare(conn)
+        elif bind.raise_:
+            with expect_raises_message(
+                exc.ArgumentError, "Expected Engine or Connection, got 'foo'"
+            ):
+                DeferredReflection.prepare("foo")
+            return
+        else:
+            bind.fail()
+
         self._roundtrip()
 
+    @testing.requires.view_reflection
+    @testing.variation("include_views", [True, False])
+    def test_views(self, metadata, connection, include_views):
+        Table(
+            "test_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("data", String(50)),
+        )
+        query = "CREATE VIEW view_name AS SELECT id, data FROM test_table"
+
+        event.listen(metadata, "after_create", DDL(query))
+        event.listen(
+            metadata, "before_drop", DDL("DROP VIEW IF EXISTS view_name")
+        )
+        metadata.create_all(connection)
+
+        class ViewName(DeferredReflection, Base):
+            __tablename__ = "view_name"
+
+            id = Column(Integer, primary_key=True)
+
+        if include_views:
+            DeferredReflection.prepare(connection, views=True)
+        else:
+            with expect_raises_message(
+                exc.InvalidRequestError, r"Could not reflect: .*view_name"
+            ):
+                DeferredReflection.prepare(connection)
+            return
+
+        self.assert_compile(
+            select(ViewName),
+            "SELECT view_name.id, view_name.data FROM view_name",
+        )
+
     def test_abstract_base(self):
         class DefBase(DeferredReflection, Base):
             __abstract__ = True