from __future__ import annotations
import collections
+import contextlib
+from typing import Any
from typing import Callable
from typing import TYPE_CHECKING
+from typing import Union
+from ... import exc as sa_exc
+from ...engine import Connection
+from ...engine import Engine
from ...orm import exc as orm_exc
from ...orm import relationships
from ...orm.base import _mapper_or_none
"""
@classmethod
- def prepare(cls, engine):
- """Reflect all :class:`_schema.Table` objects for all current
- :class:`.DeferredReflection` subclasses"""
+ def prepare(
+ cls, bind: Union[Engine, Connection], **reflect_kw: Any
+ ) -> None:
+ r"""Reflect all :class:`_schema.Table` objects for all current
+ :class:`.DeferredReflection` subclasses
+
+ :param bind: :class:`_engine.Engine` or :class:`_engine.Connection`
+ instance
+
+ ..versionchanged:: 2.0.16 a :class:`_engine.Connection` is also
+ accepted.
+
+ :param \**reflect_kw: additional keyword arguments passed to
+ :meth:`_schema.MetaData.reflect`, such as
+ :paramref:`_schema.MetaData.reflect.views`.
+
+ .. versionadded:: 2.0.16
+
+ """
to_map = _DeferredMapperConfig.classes_for_base(cls)
].add(thingy.local_table.name)
# then reflect all those tables into their metadatas
- with engine.connect() as conn:
+
+ if isinstance(bind, Connection):
+ conn = bind
+ ctx = contextlib.nullcontext(enter_result=conn)
+ elif isinstance(bind, Engine):
+ ctx = bind.connect()
+ else:
+ raise sa_exc.ArgumentError(
+ f"Expected Engine or Connection, got {bind!r}"
+ )
+
+ with ctx as conn:
for (metadata, schema), table_names in metadata_to_table.items():
metadata.reflect(
conn,
schema=schema,
extend_existing=True,
autoload_replace=False,
+ **reflect_kw,
)
metadata_to_table.clear()
from __future__ import annotations
+from sqlalchemy import DDL
+from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
+from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.ext.declarative import DeferredReflection
DeferredReflection.prepare(testing.db)
-class DeferredReflectionTest(DeferredReflectBase):
+class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase):
@classmethod
def define_tables(cls, metadata):
Table(
User,
)
- def test_basic_deferred(self):
+ @testing.variation("bind", ["engine", "connection", "raise_"])
+ def test_basic_deferred(self, bind):
class User(DeferredReflection, fixtures.ComparableEntity, Base):
__tablename__ = "users"
addresses = relationship("Address", backref="user")
class Address(DeferredReflection, fixtures.ComparableEntity, Base):
__tablename__ = "addresses"
- DeferredReflection.prepare(testing.db)
+ if bind.engine:
+ DeferredReflection.prepare(testing.db)
+ elif bind.connection:
+ with testing.db.connect() as conn:
+ DeferredReflection.prepare(conn)
+ elif bind.raise_:
+ with expect_raises_message(
+ exc.ArgumentError, "Expected Engine or Connection, got 'foo'"
+ ):
+ DeferredReflection.prepare("foo")
+ return
+ else:
+ bind.fail()
+
self._roundtrip()
+ @testing.requires.view_reflection
+ @testing.variation("include_views", [True, False])
+ def test_views(self, metadata, connection, include_views):
+ Table(
+ "test_table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("data", String(50)),
+ )
+ query = "CREATE VIEW view_name AS SELECT id, data FROM test_table"
+
+ event.listen(metadata, "after_create", DDL(query))
+ event.listen(
+ metadata, "before_drop", DDL("DROP VIEW IF EXISTS view_name")
+ )
+ metadata.create_all(connection)
+
+ class ViewName(DeferredReflection, Base):
+ __tablename__ = "view_name"
+
+ id = Column(Integer, primary_key=True)
+
+ if include_views:
+ DeferredReflection.prepare(connection, views=True)
+ else:
+ with expect_raises_message(
+ exc.InvalidRequestError, r"Could not reflect: .*view_name"
+ ):
+ DeferredReflection.prepare(connection)
+ return
+
+ self.assert_compile(
+ select(ViewName),
+ "SELECT view_name.id, view_name.data FROM view_name",
+ )
+
def test_abstract_base(self):
class DefBase(DeferredReflection, Base):
__abstract__ = True