From: Mike Bayer Date: Wed, 24 May 2023 02:21:30 +0000 (-0400) Subject: add reflection arguments, engine/conn bind to DeferredReflection.prepare X-Git-Tag: rel_2_0_16~30^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=f1f6b296c559cedf021e3611c293841191f25317;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add reflection arguments, engine/conn bind to DeferredReflection.prepare Improved :meth:`.DeferredReflection.prepare` to accept arbitrary ``**kw`` arguments that are passed to :meth:`_schema.MetaData.reflect`, allowing use cases such as reflection of views as well as dialect-specific arguments to be passed. Additionally, modernized the :paramref:`.DeferredReflection.prepare.bind` argument so that either an :class:`.Engine` or :class:`.Connection` are accepted as the "bind" argument. Fixes: #9828 Change-Id: Ie93cd1147611a92f07d05df8a39052f61ee692f2 --- diff --git a/doc/build/changelog/unreleased_20/9828.rst b/doc/build/changelog/unreleased_20/9828.rst new file mode 100644 index 0000000000..b6fa2559fc --- /dev/null +++ b/doc/build/changelog/unreleased_20/9828.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: usecase, orm + :tickets: 9828 + + Improved :meth:`.DeferredReflection.prepare` to accept arbitrary ``**kw`` + arguments that are passed to :meth:`_schema.MetaData.reflect`, allowing use + cases such as reflection of views as well as dialect-specific arguments to + be passed. Additionally, modernized the + :paramref:`.DeferredReflection.prepare.bind` argument so that either an + :class:`.Engine` or :class:`.Connection` are accepted as the "bind" + argument. diff --git a/lib/sqlalchemy/ext/declarative/extensions.py b/lib/sqlalchemy/ext/declarative/extensions.py index 2cb55a5ae8..62a0e07540 100644 --- a/lib/sqlalchemy/ext/declarative/extensions.py +++ b/lib/sqlalchemy/ext/declarative/extensions.py @@ -11,9 +11,15 @@ from __future__ import annotations import collections +import contextlib +from typing import Any from typing import Callable from typing import TYPE_CHECKING +from typing import Union +from ... import exc as sa_exc +from ...engine import Connection +from ...engine import Engine from ...orm import exc as orm_exc from ...orm import relationships from ...orm.base import _mapper_or_none @@ -414,9 +420,25 @@ class DeferredReflection: """ @classmethod - def prepare(cls, engine): - """Reflect all :class:`_schema.Table` objects for all current - :class:`.DeferredReflection` subclasses""" + def prepare( + cls, bind: Union[Engine, Connection], **reflect_kw: Any + ) -> None: + r"""Reflect all :class:`_schema.Table` objects for all current + :class:`.DeferredReflection` subclasses + + :param bind: :class:`_engine.Engine` or :class:`_engine.Connection` + instance + + ..versionchanged:: 2.0.16 a :class:`_engine.Connection` is also + accepted. + + :param \**reflect_kw: additional keyword arguments passed to + :meth:`_schema.MetaData.reflect`, such as + :paramref:`_schema.MetaData.reflect.views`. + + .. versionadded:: 2.0.16 + + """ to_map = _DeferredMapperConfig.classes_for_base(cls) @@ -432,7 +454,18 @@ class DeferredReflection: ].add(thingy.local_table.name) # then reflect all those tables into their metadatas - with engine.connect() as conn: + + if isinstance(bind, Connection): + conn = bind + ctx = contextlib.nullcontext(enter_result=conn) + elif isinstance(bind, Engine): + ctx = bind.connect() + else: + raise sa_exc.ArgumentError( + f"Expected Engine or Connection, got {bind!r}" + ) + + with ctx as conn: for (metadata, schema), table_names in metadata_to_table.items(): metadata.reflect( conn, @@ -440,6 +473,7 @@ class DeferredReflection: schema=schema, extend_existing=True, autoload_replace=False, + **reflect_kw, ) metadata_to_table.clear() diff --git a/test/ext/declarative/test_reflection.py b/test/ext/declarative/test_reflection.py index 53f518a27f..1a4effa644 100644 --- a/test/ext/declarative/test_reflection.py +++ b/test/ext/declarative/test_reflection.py @@ -1,8 +1,11 @@ from __future__ import annotations +from sqlalchemy import DDL +from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import Integer +from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy.ext.declarative import DeferredReflection @@ -81,7 +84,7 @@ class DeferredReflectPKFKTest(DeferredReflectBase): DeferredReflection.prepare(testing.db) -class DeferredReflectionTest(DeferredReflectBase): +class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase): @classmethod def define_tables(cls, metadata): Table( @@ -148,7 +151,8 @@ class DeferredReflectionTest(DeferredReflectBase): User, ) - def test_basic_deferred(self): + @testing.variation("bind", ["engine", "connection", "raise_"]) + def test_basic_deferred(self, bind): class User(DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = "users" addresses = relationship("Address", backref="user") @@ -156,9 +160,58 @@ class DeferredReflectionTest(DeferredReflectBase): class Address(DeferredReflection, fixtures.ComparableEntity, Base): __tablename__ = "addresses" - DeferredReflection.prepare(testing.db) + if bind.engine: + DeferredReflection.prepare(testing.db) + elif bind.connection: + with testing.db.connect() as conn: + DeferredReflection.prepare(conn) + elif bind.raise_: + with expect_raises_message( + exc.ArgumentError, "Expected Engine or Connection, got 'foo'" + ): + DeferredReflection.prepare("foo") + return + else: + bind.fail() + self._roundtrip() + @testing.requires.view_reflection + @testing.variation("include_views", [True, False]) + def test_views(self, metadata, connection, include_views): + Table( + "test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("data", String(50)), + ) + query = "CREATE VIEW view_name AS SELECT id, data FROM test_table" + + event.listen(metadata, "after_create", DDL(query)) + event.listen( + metadata, "before_drop", DDL("DROP VIEW IF EXISTS view_name") + ) + metadata.create_all(connection) + + class ViewName(DeferredReflection, Base): + __tablename__ = "view_name" + + id = Column(Integer, primary_key=True) + + if include_views: + DeferredReflection.prepare(connection, views=True) + else: + with expect_raises_message( + exc.InvalidRequestError, r"Could not reflect: .*view_name" + ): + DeferredReflection.prepare(connection) + return + + self.assert_compile( + select(ViewName), + "SELECT view_name.id, view_name.data FROM view_name", + ) + def test_abstract_base(self): class DefBase(DeferredReflection, Base): __abstract__ = True