From d24cd5e96d7f8e47c86b5013a7f989a15e2eec89 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 26 May 2022 14:35:03 -0400 Subject: [PATCH] establish sessionmaker and async_sessionmaker as generic This is so that custom Session and AsyncSession classes can be typed for these factories. Added appropriate typevars to `__call__()`, `__enter__()` and other methods so that a custom Session or AsyncSession subclass is carried through. Fixes: #7656 Change-Id: Ia2b8c1f22b4410db26005c3285f6ba3d13d7f0e0 --- examples/asyncio/gather_orm_statements.py | 11 +-- lib/sqlalchemy/ext/asyncio/engine.py | 11 ++- lib/sqlalchemy/ext/asyncio/scoping.py | 24 +++--- lib/sqlalchemy/ext/asyncio/session.py | 33 ++++---- lib/sqlalchemy/orm/scoping.py | 34 +++++++-- lib/sqlalchemy/orm/session.py | 18 +++-- test/ext/mypy/plain_files/sessionmakers.py | 88 ++++++++++++++++++++++ 7 files changed, 168 insertions(+), 51 deletions(-) create mode 100644 test/ext/mypy/plain_files/sessionmakers.py diff --git a/examples/asyncio/gather_orm_statements.py b/examples/asyncio/gather_orm_statements.py index edcdc1fe84..a67b5e669d 100644 --- a/examples/asyncio/gather_orm_statements.py +++ b/examples/asyncio/gather_orm_statements.py @@ -22,12 +22,11 @@ import random from sqlalchemy import Column from sqlalchemy import Integer from sqlalchemy import String -from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.future import select from sqlalchemy.orm import merge_frozen_result -from sqlalchemy.orm import sessionmaker Base = declarative_base() @@ -40,14 +39,14 @@ class A(Base): async def run_out_of_band( - sessionmaker, session, statement, merge_results=True + async_sessionmaker, session, statement, merge_results=True ): """run an ORM statement in a distinct session, merging the result back into the given session. """ - async with sessionmaker() as oob_session: + async with async_sessionmaker() as oob_session: # use AUTOCOMMIT for each connection to reduce transaction # overhead / contention @@ -94,9 +93,7 @@ async def async_main(): await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) - async_session = sessionmaker( - engine, expire_on_commit=False, class_=AsyncSession - ) + async_session = async_sessionmaker(engine, expire_on_commit=False) async with async_session() as session, session.begin(): session.add_all([A(data="a_%d" % i) for i in range(100)]) diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 6d07d843c9..97d69fcbd2 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -1068,10 +1068,15 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): @property def engine(self) -> Any: - r""".. container:: class_bases + r"""Returns this :class:`.Engine`. - Proxied for the :class:`_engine.Engine` class - on behalf of the :class:`_asyncio.AsyncEngine` class. + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + Used for legacy schemes that accept :class:`.Connection` / + :class:`.Engine` objects within the same variable. """ # noqa: E501 diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 22a060a0d4..8d31dd07d4 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -9,6 +9,7 @@ from __future__ import annotations from typing import Any from typing import Callable +from typing import Generic from typing import Iterable from typing import Iterator from typing import Optional @@ -20,6 +21,7 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from .session import _AS from .session import async_sessionmaker from .session import AsyncSession from ... import exc as sa_exc @@ -104,7 +106,7 @@ _T = TypeVar("_T", bound=Any) "info", ], ) -class async_scoped_session: +class async_scoped_session(Generic[_AS]): """Provides scoped management of :class:`.AsyncSession` objects. See the section :ref:`asyncio_scoped_session` for usage details. @@ -116,16 +118,16 @@ class async_scoped_session: _support_async = True - session_factory: async_sessionmaker + session_factory: async_sessionmaker[_AS] """The `session_factory` provided to `__init__` is stored in this attribute and may be accessed at a later time. This can be useful when a new non-scoped :class:`.AsyncSession` is needed.""" - registry: ScopedRegistry[AsyncSession] + registry: ScopedRegistry[_AS] def __init__( self, - session_factory: async_sessionmaker, + session_factory: async_sessionmaker[_AS], scopefunc: Callable[[], Any], ): """Construct a new :class:`_asyncio.async_scoped_session`. @@ -144,10 +146,10 @@ class async_scoped_session: self.registry = ScopedRegistry(session_factory, scopefunc) @property - def _proxied(self) -> AsyncSession: + def _proxied(self) -> _AS: return self.registry() - def __call__(self, **kw: Any) -> AsyncSession: + def __call__(self, **kw: Any) -> _AS: r"""Return the current :class:`.AsyncSession`, creating it using the :attr:`.scoped_session.session_factory` if not present. @@ -450,8 +452,8 @@ class async_scoped_session: This method may also be used to establish execution options for the database connection used by the current transaction. - .. versionadded:: 1.4.24 Added **kw arguments which are passed through - to the underlying :meth:`_orm.Session.connection` method. + .. versionadded:: 1.4.24 Added \**kw arguments which are passed + through to the underlying :meth:`_orm.Session.connection` method. .. seealso:: @@ -752,7 +754,8 @@ class async_scoped_session: from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.orm import Session, sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session # construct async engines w/ async drivers engines = { @@ -775,8 +778,7 @@ class async_scoped_session: ].sync_engine # apply to AsyncSession using sync_session_class - AsyncSessionMaker = sessionmaker( - class_=AsyncSession, + AsyncSessionMaker = async_sessionmaker( sync_session_class=RoutingSession ) diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index eac2e58063..be3414cef4 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -8,6 +8,7 @@ from __future__ import annotations from typing import Any from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import NoReturn @@ -698,7 +699,8 @@ class AsyncSession(ReversibleProxy[Session]): from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.orm import Session, sessionmaker + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session # construct async engines w/ async drivers engines = { @@ -721,8 +723,7 @@ class AsyncSession(ReversibleProxy[Session]): ].sync_engine # apply to AsyncSession using sync_session_class - AsyncSessionMaker = sessionmaker( - class_=AsyncSession, + AsyncSessionMaker = async_sessionmaker( sync_session_class=RoutingSession ) @@ -850,14 +851,13 @@ class AsyncSession(ReversibleProxy[Session]): """Close all :class:`_asyncio.AsyncSession` sessions.""" await greenlet_spawn(self.sync_session.close_all) - async def __aenter__(self) -> AsyncSession: + async def __aenter__(self: _AS) -> _AS: return self async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: await self.close() - def _maker_context_manager(self) -> _AsyncSessionContextManager: - # TODO: can this use asynccontextmanager ?? + def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]: return _AsyncSessionContextManager(self) # START PROXY METHODS AsyncSession @@ -1367,7 +1367,10 @@ class AsyncSession(ReversibleProxy[Session]): # END PROXY METHODS AsyncSession -class async_sessionmaker: +_AS = TypeVar("_AS", bound="AsyncSession") + + +class async_sessionmaker(Generic[_AS]): """A configurable :class:`.AsyncSession` factory. The :class:`.async_sessionmaker` factory works in the same way as the @@ -1409,12 +1412,12 @@ class async_sessionmaker: """ # noqa E501 - class_: Type[AsyncSession] + class_: Type[_AS] def __init__( self, bind: Optional[_AsyncSessionBind] = None, - class_: Type[AsyncSession] = AsyncSession, + class_: Type[_AS] = AsyncSession, # type: ignore autoflush: bool = True, expire_on_commit: bool = True, info: Optional[_InfoType] = None, @@ -1437,7 +1440,7 @@ class async_sessionmaker: self.kw = kw self.class_ = class_ - def begin(self) -> _AsyncSessionContextManager: + def begin(self) -> _AsyncSessionContextManager[_AS]: """Produce a context manager that both provides a new :class:`_orm.AsyncSession` as well as a transaction that commits. @@ -1458,7 +1461,7 @@ class async_sessionmaker: session = self() return session._maker_context_manager() - def __call__(self, **local_kw: Any) -> AsyncSession: + def __call__(self, **local_kw: Any) -> _AS: """Produce a new :class:`.AsyncSession` object using the configuration established in this :class:`.async_sessionmaker`. @@ -1498,16 +1501,16 @@ class async_sessionmaker: ) -class _AsyncSessionContextManager: +class _AsyncSessionContextManager(Generic[_AS]): __slots__ = ("async_session", "trans") - async_session: AsyncSession + async_session: _AS trans: AsyncSessionTransaction - def __init__(self, async_session: AsyncSession): + def __init__(self, async_session: _AS): self.async_session = async_session - async def __aenter__(self) -> AsyncSession: + async def __aenter__(self) -> _AS: self.trans = self.async_session.begin() await self.trans.__aenter__() return self.async_session diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 9220c44c7f..c00508385c 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -10,6 +10,7 @@ from __future__ import annotations from typing import Any from typing import Callable from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import Optional @@ -21,6 +22,7 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from .session import _S from .session import Session from .. import exc as sa_exc from .. import util @@ -131,7 +133,7 @@ __all__ = ["scoped_session"] "info", ], ) -class scoped_session: +class scoped_session(Generic[_S]): """Provides scoped management of :class:`.Session` objects. See :ref:`unitofwork_contextual` for a tutorial. @@ -146,16 +148,16 @@ class scoped_session: _support_async: bool = False - session_factory: sessionmaker + session_factory: sessionmaker[_S] """The `session_factory` provided to `__init__` is stored in this attribute and may be accessed at a later time. This can be useful when a new non-scoped :class:`.Session` is needed.""" - registry: ScopedRegistry[Session] + registry: ScopedRegistry[_S] def __init__( self, - session_factory: sessionmaker, + session_factory: sessionmaker[_S], scopefunc: Optional[Callable[[], Any]] = None, ): @@ -182,10 +184,10 @@ class scoped_session: self.registry = ThreadLocalRegistry(session_factory) @property - def _proxied(self) -> Session: + def _proxied(self) -> _S: return self.registry() - def __call__(self, **kw: Any) -> Session: + def __call__(self, **kw: Any) -> _S: r"""Return the current :class:`.Session`, creating it using the :attr:`.scoped_session.session_factory` if not present. @@ -479,8 +481,22 @@ class scoped_session: Proxied for the :class:`_orm.Session` class on behalf of the :class:`_orm.scoping.scoped_session` class. - If no transaction is in progress, the method will first - "autobegin" a new transaction and commit. + When the COMMIT operation is complete, all objects are fully + :term:`expired`, erasing their internal contents, which will be + automatically re-loaded when the objects are next accessed. In the + interim, these objects are in an expired state and will not function if + they are :term:`detached` from the :class:`.Session`. Additionally, + this re-load operation is not supported when using asyncio-oriented + APIs. The :paramref:`.Session.expire_on_commit` parameter may be used + to disable this behavior. + + When there is no transaction in place for the :class:`.Session`, + indicating that no operations were invoked on this :class:`.Session` + since the previous call to :meth:`.Session.commit`, the method will + begin and commit an internal-only "logical" transaction, that does not + normally affect the database unless pending flush changes were + detected, but will still invoke event handlers and object expiration + rules. The outermost database transaction is committed unconditionally, automatically releasing any SAVEPOINTs in effect. @@ -491,6 +507,8 @@ class scoped_session: :ref:`unitofwork_transaction` + :ref:`asyncio_orm_avoid_lazyloads` + """ # noqa: E501 diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index d72e78c9e6..788821b987 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -17,6 +17,7 @@ from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import Generic from typing import Iterable from typing import Iterator from typing import List @@ -1420,14 +1421,14 @@ class Session(_SessionClassMethods, EventTarget): connection_callable: Optional[_ConnectionCallableProto] = None - def __enter__(self) -> Session: + def __enter__(self: _S) -> _S: return self def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: self.close() @contextlib.contextmanager - def _maker_context_manager(self) -> Iterator[Session]: + def _maker_context_manager(self: _S) -> Iterator[_S]: with self: with self.begin(): yield self @@ -4398,7 +4399,10 @@ class Session(_SessionClassMethods, EventTarget): return util.IdentitySet(list(self._new.values())) -class sessionmaker(_SessionClassMethods): +_S = TypeVar("_S", bound="Session") + + +class sessionmaker(_SessionClassMethods, Generic[_S]): """A configurable :class:`.Session` factory. The :class:`.sessionmaker` factory generates new @@ -4493,12 +4497,12 @@ class sessionmaker(_SessionClassMethods): """ - class_: Type[Session] + class_: Type[_S] def __init__( self, bind: Optional[_SessionBind] = None, - class_: Type[Session] = Session, + class_: Type[_S] = Session, # type: ignore autoflush: bool = True, expire_on_commit: bool = True, info: Optional[_InfoType] = None, @@ -4545,7 +4549,7 @@ class sessionmaker(_SessionClassMethods): # events can be associated with it specifically. self.class_ = type(class_.__name__, (class_,), {}) - def begin(self) -> contextlib.AbstractContextManager[Session]: + def begin(self) -> contextlib.AbstractContextManager[_S]: """Produce a context manager that both provides a new :class:`_orm.Session` as well as a transaction that commits. @@ -4567,7 +4571,7 @@ class sessionmaker(_SessionClassMethods): session = self() return session._maker_context_manager() - def __call__(self, **local_kw: Any) -> Session: + def __call__(self, **local_kw: Any) -> _S: """Produce a new :class:`.Session` object using the configuration established in this :class:`.sessionmaker`. diff --git a/test/ext/mypy/plain_files/sessionmakers.py b/test/ext/mypy/plain_files/sessionmakers.py new file mode 100644 index 0000000000..ce9b766385 --- /dev/null +++ b/test/ext/mypy/plain_files/sessionmakers.py @@ -0,0 +1,88 @@ +"""test #7656""" + +from sqlalchemy import create_engine +from sqlalchemy import Engine +from sqlalchemy.ext.asyncio import async_scoped_session +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker + + +async_engine = create_async_engine("...") + + +class MyAsyncSession(AsyncSession): + pass + + +def async_session_factory( + engine: AsyncEngine, +) -> async_sessionmaker[MyAsyncSession]: + return async_sessionmaker(engine, class_=MyAsyncSession) + + +def async_scoped_session_factory( + engine: AsyncEngine, +) -> async_scoped_session[MyAsyncSession]: + return async_scoped_session( + async_sessionmaker(engine, class_=MyAsyncSession), + scopefunc=lambda: None, + ) + + +async def async_main() -> None: + fac = async_session_factory(async_engine) + + async with fac() as sess: + # EXPECTED_TYPE: MyAsyncSession + reveal_type(sess) + + async with fac.begin() as sess: + # EXPECTED_TYPE: MyAsyncSession + reveal_type(sess) + + scoped_fac = async_scoped_session_factory(async_engine) + + sess = scoped_fac() + + # EXPECTED_TYPE: MyAsyncSession + reveal_type(sess) + + +engine = create_engine("...") + + +class MySession(Session): + pass + + +def session_factory( + engine: Engine, +) -> sessionmaker[MySession]: + return sessionmaker(engine, class_=MySession) + + +def scoped_session_factory(engine: Engine) -> scoped_session[MySession]: + return scoped_session(sessionmaker(engine, class_=MySession)) + + +def main() -> None: + fac = session_factory(engine) + + with fac() as sess: + # EXPECTED_TYPE: MySession + reveal_type(sess) + + with fac.begin() as sess: + # EXPECTED_TYPE: MySession + reveal_type(sess) + + scoped_fac = scoped_session_factory(engine) + + sess = scoped_fac() + # EXPECTED_TYPE: MySession + reveal_type(sess) -- 2.47.2