]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
establish sessionmaker and async_sessionmaker as generic
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 26 May 2022 18:35:03 +0000 (14:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 31 May 2022 19:17:48 +0000 (15:17 -0400)
This is so that custom Session and AsyncSession classes
can be typed for these factories.  Added appropriate
typevars to  `__call__()`, `__enter__()` and other methods
so that a custom Session or AsyncSession subclass is carried
through.

Fixes: #7656
Change-Id: Ia2b8c1f22b4410db26005c3285f6ba3d13d7f0e0

examples/asyncio/gather_orm_statements.py
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
test/ext/mypy/plain_files/sessionmakers.py [new file with mode: 0644]

index edcdc1fe843240ccd461f6f6de0f35bc22c5347a..a67b5e669d8352b8c735715cee808d066f165f0c 100644 (file)
@@ -22,12 +22,11 @@ import random
 from sqlalchemy import Column
 from sqlalchemy import Integer
 from sqlalchemy import String
-from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import async_sessionmaker
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.future import select
 from sqlalchemy.orm import merge_frozen_result
-from sqlalchemy.orm import sessionmaker
 
 Base = declarative_base()
 
@@ -40,14 +39,14 @@ class A(Base):
 
 
 async def run_out_of_band(
-    sessionmaker, session, statement, merge_results=True
+    async_sessionmaker, session, statement, merge_results=True
 ):
     """run an ORM statement in a distinct session, merging the result
     back into the given session.
 
     """
 
-    async with sessionmaker() as oob_session:
+    async with async_sessionmaker() as oob_session:
 
         # use AUTOCOMMIT for each connection to reduce transaction
         # overhead / contention
@@ -94,9 +93,7 @@ async def async_main():
         await conn.run_sync(Base.metadata.drop_all)
         await conn.run_sync(Base.metadata.create_all)
 
-    async_session = sessionmaker(
-        engine, expire_on_commit=False, class_=AsyncSession
-    )
+    async_session = async_sessionmaker(engine, expire_on_commit=False)
 
     async with async_session() as session, session.begin():
         session.add_all([A(data="a_%d" % i) for i in range(100)])
index 6d07d843c9043a853ad43a14b9126e7211938b09..97d69fcbd29a1eaed4d734052c88af58ec4e609d 100644 (file)
@@ -1068,10 +1068,15 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable):
 
     @property
     def engine(self) -> Any:
-        r""".. container:: class_bases
+        r"""Returns this :class:`.Engine`.
 
-        Proxied for the :class:`_engine.Engine` class
-        on behalf of the :class:`_asyncio.AsyncEngine` class.
+        .. container:: class_bases
+
+            Proxied for the :class:`_engine.Engine` class
+            on behalf of the :class:`_asyncio.AsyncEngine` class.
+
+        Used for legacy schemes that accept :class:`.Connection` /
+        :class:`.Engine` objects within the same variable.
 
 
         """  # noqa: E501
index 22a060a0d42f2ef9e201e166609339b87ac5d5f7..8d31dd07d46d8ae989a80ca55daba197f79ea766 100644 (file)
@@ -9,6 +9,7 @@ from __future__ import annotations
 
 from typing import Any
 from typing import Callable
+from typing import Generic
 from typing import Iterable
 from typing import Iterator
 from typing import Optional
@@ -20,6 +21,7 @@ from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
+from .session import _AS
 from .session import async_sessionmaker
 from .session import AsyncSession
 from ... import exc as sa_exc
@@ -104,7 +106,7 @@ _T = TypeVar("_T", bound=Any)
         "info",
     ],
 )
-class async_scoped_session:
+class async_scoped_session(Generic[_AS]):
     """Provides scoped management of :class:`.AsyncSession` objects.
 
     See the section :ref:`asyncio_scoped_session` for usage details.
@@ -116,16 +118,16 @@ class async_scoped_session:
 
     _support_async = True
 
-    session_factory: async_sessionmaker
+    session_factory: async_sessionmaker[_AS]
     """The `session_factory` provided to `__init__` is stored in this
     attribute and may be accessed at a later time.  This can be useful when
     a new non-scoped :class:`.AsyncSession` is needed."""
 
-    registry: ScopedRegistry[AsyncSession]
+    registry: ScopedRegistry[_AS]
 
     def __init__(
         self,
-        session_factory: async_sessionmaker,
+        session_factory: async_sessionmaker[_AS],
         scopefunc: Callable[[], Any],
     ):
         """Construct a new :class:`_asyncio.async_scoped_session`.
@@ -144,10 +146,10 @@ class async_scoped_session:
         self.registry = ScopedRegistry(session_factory, scopefunc)
 
     @property
-    def _proxied(self) -> AsyncSession:
+    def _proxied(self) -> _AS:
         return self.registry()
 
-    def __call__(self, **kw: Any) -> AsyncSession:
+    def __call__(self, **kw: Any) -> _AS:
         r"""Return the current :class:`.AsyncSession`, creating it
         using the :attr:`.scoped_session.session_factory` if not present.
 
@@ -450,8 +452,8 @@ class async_scoped_session:
         This method may also be used to establish execution options for the
         database connection used by the current transaction.
 
-        .. versionadded:: 1.4.24  Added **kw arguments which are passed through
-           to the underlying :meth:`_orm.Session.connection` method.
+        .. versionadded:: 1.4.24  Added \**kw arguments which are passed
+           through to the underlying :meth:`_orm.Session.connection` method.
 
         .. seealso::
 
@@ -752,7 +754,8 @@ class async_scoped_session:
 
             from sqlalchemy.ext.asyncio import AsyncSession
             from sqlalchemy.ext.asyncio import create_async_engine
-            from sqlalchemy.orm import Session, sessionmaker
+            from sqlalchemy.ext.asyncio import async_sessionmaker
+            from sqlalchemy.orm import Session
 
             # construct async engines w/ async drivers
             engines = {
@@ -775,8 +778,7 @@ class async_scoped_session:
                         ].sync_engine
 
             # apply to AsyncSession using sync_session_class
-            AsyncSessionMaker = sessionmaker(
-                class_=AsyncSession,
+            AsyncSessionMaker = async_sessionmaker(
                 sync_session_class=RoutingSession
             )
 
index eac2e58063b4c204b3a51e8706696310e0939eca..be3414cef4aa39f13194a3bdf0cc8bd3f6170ee7 100644 (file)
@@ -8,6 +8,7 @@ from __future__ import annotations
 
 from typing import Any
 from typing import Dict
+from typing import Generic
 from typing import Iterable
 from typing import Iterator
 from typing import NoReturn
@@ -698,7 +699,8 @@ class AsyncSession(ReversibleProxy[Session]):
 
             from sqlalchemy.ext.asyncio import AsyncSession
             from sqlalchemy.ext.asyncio import create_async_engine
-            from sqlalchemy.orm import Session, sessionmaker
+            from sqlalchemy.ext.asyncio import async_sessionmaker
+            from sqlalchemy.orm import Session
 
             # construct async engines w/ async drivers
             engines = {
@@ -721,8 +723,7 @@ class AsyncSession(ReversibleProxy[Session]):
                         ].sync_engine
 
             # apply to AsyncSession using sync_session_class
-            AsyncSessionMaker = sessionmaker(
-                class_=AsyncSession,
+            AsyncSessionMaker = async_sessionmaker(
                 sync_session_class=RoutingSession
             )
 
@@ -850,14 +851,13 @@ class AsyncSession(ReversibleProxy[Session]):
         """Close all :class:`_asyncio.AsyncSession` sessions."""
         await greenlet_spawn(self.sync_session.close_all)
 
-    async def __aenter__(self) -> AsyncSession:
+    async def __aenter__(self: _AS) -> _AS:
         return self
 
     async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None:
         await self.close()
 
-    def _maker_context_manager(self) -> _AsyncSessionContextManager:
-        # TODO: can this use asynccontextmanager ??
+    def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]:
         return _AsyncSessionContextManager(self)
 
     # START PROXY METHODS AsyncSession
@@ -1367,7 +1367,10 @@ class AsyncSession(ReversibleProxy[Session]):
     # END PROXY METHODS AsyncSession
 
 
-class async_sessionmaker:
+_AS = TypeVar("_AS", bound="AsyncSession")
+
+
+class async_sessionmaker(Generic[_AS]):
     """A configurable :class:`.AsyncSession` factory.
 
     The :class:`.async_sessionmaker` factory works in the same way as the
@@ -1409,12 +1412,12 @@ class async_sessionmaker:
 
     """  # noqa E501
 
-    class_: Type[AsyncSession]
+    class_: Type[_AS]
 
     def __init__(
         self,
         bind: Optional[_AsyncSessionBind] = None,
-        class_: Type[AsyncSession] = AsyncSession,
+        class_: Type[_AS] = AsyncSession,  # type: ignore
         autoflush: bool = True,
         expire_on_commit: bool = True,
         info: Optional[_InfoType] = None,
@@ -1437,7 +1440,7 @@ class async_sessionmaker:
         self.kw = kw
         self.class_ = class_
 
-    def begin(self) -> _AsyncSessionContextManager:
+    def begin(self) -> _AsyncSessionContextManager[_AS]:
         """Produce a context manager that both provides a new
         :class:`_orm.AsyncSession` as well as a transaction that commits.
 
@@ -1458,7 +1461,7 @@ class async_sessionmaker:
         session = self()
         return session._maker_context_manager()
 
-    def __call__(self, **local_kw: Any) -> AsyncSession:
+    def __call__(self, **local_kw: Any) -> _AS:
         """Produce a new :class:`.AsyncSession` object using the configuration
         established in this :class:`.async_sessionmaker`.
 
@@ -1498,16 +1501,16 @@ class async_sessionmaker:
         )
 
 
-class _AsyncSessionContextManager:
+class _AsyncSessionContextManager(Generic[_AS]):
     __slots__ = ("async_session", "trans")
 
-    async_session: AsyncSession
+    async_session: _AS
     trans: AsyncSessionTransaction
 
-    def __init__(self, async_session: AsyncSession):
+    def __init__(self, async_session: _AS):
         self.async_session = async_session
 
-    async def __aenter__(self) -> AsyncSession:
+    async def __aenter__(self) -> _AS:
         self.trans = self.async_session.begin()
         await self.trans.__aenter__()
         return self.async_session
index 9220c44c7fb8790efe282e36d3a3cea75978fb19..c00508385cf8ed5d1b2ff46270bb177fd5a7ab6f 100644 (file)
@@ -10,6 +10,7 @@ from __future__ import annotations
 from typing import Any
 from typing import Callable
 from typing import Dict
+from typing import Generic
 from typing import Iterable
 from typing import Iterator
 from typing import Optional
@@ -21,6 +22,7 @@ from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
+from .session import _S
 from .session import Session
 from .. import exc as sa_exc
 from .. import util
@@ -131,7 +133,7 @@ __all__ = ["scoped_session"]
         "info",
     ],
 )
-class scoped_session:
+class scoped_session(Generic[_S]):
     """Provides scoped management of :class:`.Session` objects.
 
     See :ref:`unitofwork_contextual` for a tutorial.
@@ -146,16 +148,16 @@ class scoped_session:
 
     _support_async: bool = False
 
-    session_factory: sessionmaker
+    session_factory: sessionmaker[_S]
     """The `session_factory` provided to `__init__` is stored in this
     attribute and may be accessed at a later time.  This can be useful when
     a new non-scoped :class:`.Session` is needed."""
 
-    registry: ScopedRegistry[Session]
+    registry: ScopedRegistry[_S]
 
     def __init__(
         self,
-        session_factory: sessionmaker,
+        session_factory: sessionmaker[_S],
         scopefunc: Optional[Callable[[], Any]] = None,
     ):
 
@@ -182,10 +184,10 @@ class scoped_session:
             self.registry = ThreadLocalRegistry(session_factory)
 
     @property
-    def _proxied(self) -> Session:
+    def _proxied(self) -> _S:
         return self.registry()
 
-    def __call__(self, **kw: Any) -> Session:
+    def __call__(self, **kw: Any) -> _S:
         r"""Return the current :class:`.Session`, creating it
         using the :attr:`.scoped_session.session_factory` if not present.
 
@@ -479,8 +481,22 @@ class scoped_session:
             Proxied for the :class:`_orm.Session` class on
             behalf of the :class:`_orm.scoping.scoped_session` class.
 
-        If no transaction is in progress, the method will first
-        "autobegin" a new transaction and commit.
+        When the COMMIT operation is complete, all objects are fully
+        :term:`expired`, erasing their internal contents, which will be
+        automatically re-loaded when the objects are next accessed. In the
+        interim, these objects are in an expired state and will not function if
+        they are :term:`detached` from the :class:`.Session`. Additionally,
+        this re-load operation is not supported when using asyncio-oriented
+        APIs. The :paramref:`.Session.expire_on_commit` parameter may be used
+        to disable this behavior.
+
+        When there is no transaction in place for the :class:`.Session`,
+        indicating that no operations were invoked on this :class:`.Session`
+        since the previous call to :meth:`.Session.commit`, the method will
+        begin and commit an internal-only "logical" transaction, that does not
+        normally affect the database unless pending flush changes were
+        detected, but will still invoke event handlers and object expiration
+        rules.
 
         The outermost database transaction is committed unconditionally,
         automatically releasing any SAVEPOINTs in effect.
@@ -491,6 +507,8 @@ class scoped_session:
 
             :ref:`unitofwork_transaction`
 
+            :ref:`asyncio_orm_avoid_lazyloads`
+
 
         """  # noqa: E501
 
index d72e78c9e69be26a0aabde5d7a23766b59d36402..788821b98737932a366098e8c8b0194cd0eafae3 100644 (file)
@@ -17,6 +17,7 @@ from typing import Any
 from typing import Callable
 from typing import cast
 from typing import Dict
+from typing import Generic
 from typing import Iterable
 from typing import Iterator
 from typing import List
@@ -1420,14 +1421,14 @@ class Session(_SessionClassMethods, EventTarget):
 
     connection_callable: Optional[_ConnectionCallableProto] = None
 
-    def __enter__(self) -> Session:
+    def __enter__(self: _S) -> _S:
         return self
 
     def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
         self.close()
 
     @contextlib.contextmanager
-    def _maker_context_manager(self) -> Iterator[Session]:
+    def _maker_context_manager(self: _S) -> Iterator[_S]:
         with self:
             with self.begin():
                 yield self
@@ -4398,7 +4399,10 @@ class Session(_SessionClassMethods, EventTarget):
         return util.IdentitySet(list(self._new.values()))
 
 
-class sessionmaker(_SessionClassMethods):
+_S = TypeVar("_S", bound="Session")
+
+
+class sessionmaker(_SessionClassMethods, Generic[_S]):
     """A configurable :class:`.Session` factory.
 
     The :class:`.sessionmaker` factory generates new
@@ -4493,12 +4497,12 @@ class sessionmaker(_SessionClassMethods):
 
     """
 
-    class_: Type[Session]
+    class_: Type[_S]
 
     def __init__(
         self,
         bind: Optional[_SessionBind] = None,
-        class_: Type[Session] = Session,
+        class_: Type[_S] = Session,  # type: ignore
         autoflush: bool = True,
         expire_on_commit: bool = True,
         info: Optional[_InfoType] = None,
@@ -4545,7 +4549,7 @@ class sessionmaker(_SessionClassMethods):
         # events can be associated with it specifically.
         self.class_ = type(class_.__name__, (class_,), {})
 
-    def begin(self) -> contextlib.AbstractContextManager[Session]:
+    def begin(self) -> contextlib.AbstractContextManager[_S]:
         """Produce a context manager that both provides a new
         :class:`_orm.Session` as well as a transaction that commits.
 
@@ -4567,7 +4571,7 @@ class sessionmaker(_SessionClassMethods):
         session = self()
         return session._maker_context_manager()
 
-    def __call__(self, **local_kw: Any) -> Session:
+    def __call__(self, **local_kw: Any) -> _S:
         """Produce a new :class:`.Session` object using the configuration
         established in this :class:`.sessionmaker`.
 
diff --git a/test/ext/mypy/plain_files/sessionmakers.py b/test/ext/mypy/plain_files/sessionmakers.py
new file mode 100644 (file)
index 0000000..ce9b766
--- /dev/null
@@ -0,0 +1,88 @@
+"""test #7656"""
+
+from sqlalchemy import create_engine
+from sqlalchemy import Engine
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import async_sessionmaker
+from sqlalchemy.ext.asyncio import AsyncEngine
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.orm import scoped_session
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
+
+
+async_engine = create_async_engine("...")
+
+
+class MyAsyncSession(AsyncSession):
+    pass
+
+
+def async_session_factory(
+    engine: AsyncEngine,
+) -> async_sessionmaker[MyAsyncSession]:
+    return async_sessionmaker(engine, class_=MyAsyncSession)
+
+
+def async_scoped_session_factory(
+    engine: AsyncEngine,
+) -> async_scoped_session[MyAsyncSession]:
+    return async_scoped_session(
+        async_sessionmaker(engine, class_=MyAsyncSession),
+        scopefunc=lambda: None,
+    )
+
+
+async def async_main() -> None:
+    fac = async_session_factory(async_engine)
+
+    async with fac() as sess:
+        # EXPECTED_TYPE: MyAsyncSession
+        reveal_type(sess)
+
+    async with fac.begin() as sess:
+        # EXPECTED_TYPE: MyAsyncSession
+        reveal_type(sess)
+
+    scoped_fac = async_scoped_session_factory(async_engine)
+
+    sess = scoped_fac()
+
+    # EXPECTED_TYPE: MyAsyncSession
+    reveal_type(sess)
+
+
+engine = create_engine("...")
+
+
+class MySession(Session):
+    pass
+
+
+def session_factory(
+    engine: Engine,
+) -> sessionmaker[MySession]:
+    return sessionmaker(engine, class_=MySession)
+
+
+def scoped_session_factory(engine: Engine) -> scoped_session[MySession]:
+    return scoped_session(sessionmaker(engine, class_=MySession))
+
+
+def main() -> None:
+    fac = session_factory(engine)
+
+    with fac() as sess:
+        # EXPECTED_TYPE: MySession
+        reveal_type(sess)
+
+    with fac.begin() as sess:
+        # EXPECTED_TYPE: MySession
+        reveal_type(sess)
+
+    scoped_fac = scoped_session_factory(engine)
+
+    sess = scoped_fac()
+    # EXPECTED_TYPE: MySession
+    reveal_type(sess)