from .result import AsyncMappingResult
from .result import AsyncResult
from .result import AsyncScalarResult
+from .scoping import async_scoped_session
from .session import async_object_session
from .session import async_session
from .session import AsyncSession
--- /dev/null
+# ext/asyncio/scoping.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+from .session import AsyncSession
+from ...orm.scoping import ScopedSessionMixin
+from ...util import create_proxy_methods
+from ...util import ScopedRegistry
+
+
+@create_proxy_methods(
+ AsyncSession,
+ ":class:`_asyncio.AsyncSession`",
+ ":class:`_asyncio.scoping.async_scoped_session`",
+ classmethods=["close_all", "object_session", "identity_key"],
+ methods=[
+ "__contains__",
+ "__iter__",
+ "add",
+ "add_all",
+ "begin",
+ "begin_nested",
+ "close",
+ "commit",
+ "connection",
+ "delete",
+ "execute",
+ "expire",
+ "expire_all",
+ "expunge",
+ "expunge_all",
+ "flush",
+ "get",
+ "get_bind",
+ "is_modified",
+ "merge",
+ "refresh",
+ "rollback",
+ "scalar",
+ ],
+ attributes=[
+ "bind",
+ "dirty",
+ "deleted",
+ "new",
+ "identity_map",
+ "is_active",
+ "autoflush",
+ "no_autoflush",
+ "info",
+ ],
+)
+class async_scoped_session(ScopedSessionMixin):
+ """Provides scoped management of :class:`.AsyncSession` objects."""
+
+ def __init__(self, session_factory, scopefunc):
+ """Construct a new :class:`.scoped_session`.
+
+ :param session_factory: a factory to create new :class:`.Session`
+ instances. This is usually, but not necessarily, an instance
+ of :class:`.sessionmaker`.
+ :param scopefunc: function which defines
+ the current scope. Different from scoped_session's __init__ signature,
+ which has a default 'thread-local' scope, the scopefunc must be passed
+ in and it should return a hashable token with the same requirement as
+ in scoped_session.
+
+ """
+ self.session_factory = session_factory
+ self.registry = ScopedRegistry(session_factory, scopefunc)
+
+ @property
+ def _proxied(self):
+ return self.registry()
+
+ async def remove(self):
+ """Dispose of the current :class:`.AsyncSession`, if present.
+
+ Different from scoped_session's remove method, this method would use
+ await to wait for the close method of AsyncSession.
+
+ """
+
+ if self.registry.has():
+ await self.registry().close()
+ self.registry.clear()
from ..util import ThreadLocalRegistry
from ..util import warn
-__all__ = ["scoped_session"]
+__all__ = ["scoped_session", "ScopedSessionMixin"]
+
+
+class ScopedSessionMixin(object):
+ @property
+ def _proxied(self):
+ return self.registry()
+
+ def __call__(self, **kw):
+ r"""Return the current :class:`.Session`, creating it
+ using the :attr:`.scoped_session.session_factory` if not present.
+
+ :param \**kw: Keyword arguments will be passed to the
+ :attr:`.scoped_session.session_factory` callable, if an existing
+ :class:`.Session` is not present. If the :class:`.Session` is present
+ and keyword arguments have been passed,
+ :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
+
+ """
+ if kw:
+ if self.registry.has():
+ raise sa_exc.InvalidRequestError(
+ "Scoped session is already present; "
+ "no new arguments may be specified."
+ )
+ else:
+ sess = self.session_factory(**kw)
+ self.registry.set(sess)
+ return sess
+ else:
+ return self.registry()
+
+ def configure(self, **kwargs):
+ """reconfigure the :class:`.sessionmaker` used by this
+ :class:`.scoped_session`.
+
+ See :meth:`.sessionmaker.configure`.
+
+ """
+
+ if self.registry.has():
+ warn(
+ "At least one scoped session is already present. "
+ " configure() can not affect sessions that have "
+ "already been created."
+ )
+
+ self.session_factory.configure(**kwargs)
@create_proxy_methods(
"autocommit",
],
)
-class scoped_session(object):
+class scoped_session(ScopedSessionMixin):
"""Provides scoped management of :class:`.Session` objects.
See :ref:`unitofwork_contextual` for a tutorial.
else:
self.registry = ThreadLocalRegistry(session_factory)
- @property
- def _proxied(self):
- return self.registry()
-
- def __call__(self, **kw):
- r"""Return the current :class:`.Session`, creating it
- using the :attr:`.scoped_session.session_factory` if not present.
-
- :param \**kw: Keyword arguments will be passed to the
- :attr:`.scoped_session.session_factory` callable, if an existing
- :class:`.Session` is not present. If the :class:`.Session` is present
- and keyword arguments have been passed,
- :exc:`~sqlalchemy.exc.InvalidRequestError` is raised.
-
- """
- if kw:
- if self.registry.has():
- raise sa_exc.InvalidRequestError(
- "Scoped session is already present; "
- "no new arguments may be specified."
- )
- else:
- sess = self.session_factory(**kw)
- self.registry.set(sess)
- return sess
- else:
- return self.registry()
-
def remove(self):
"""Dispose of the current :class:`.Session`, if present.
self.registry().close()
self.registry.clear()
- def configure(self, **kwargs):
- """reconfigure the :class:`.sessionmaker` used by this
- :class:`.scoped_session`.
-
- See :meth:`.sessionmaker.configure`.
-
- """
-
- if self.registry.has():
- warn(
- "At least one scoped session is already present. "
- " configure() can not affect sessions that have "
- "already been created."
- )
-
- self.session_factory.configure(**kwargs)
-
def query_property(self, query_cls=None):
"""return a class property which produces a :class:`_query.Query`
object
--- /dev/null
+from asyncio import current_task
+
+import sqlalchemy as sa
+from sqlalchemy import func
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
+from sqlalchemy.testing import async_test
+from sqlalchemy.testing import eq_
+from sqlalchemy.testing import is_
+from .test_session_py3k import AsyncFixture
+
+
+class AsyncScopedSessionTest(AsyncFixture):
+ @async_test
+ async def test_basic(self, async_engine):
+ AsyncSession = async_scoped_session(
+ sa.orm.sessionmaker(async_engine, class_=_AsyncSession),
+ scopefunc=current_task,
+ )
+
+ some_async_session = AsyncSession()
+ some_other_async_session = AsyncSession()
+
+ is_(some_async_session, some_other_async_session)
+ is_(some_async_session.bind, async_engine)
+
+ User = self.classes.User
+
+ async with AsyncSession.begin():
+ user_name = "scoped_async_session_u1"
+ u1 = User(name=user_name)
+
+ AsyncSession.add(u1)
+
+ await AsyncSession.flush()
+
+ conn = await AsyncSession.connection()
+ stmt = select(func.count(User.id)).where(User.name == user_name)
+ eq_(await conn.scalar(stmt), 1)
+
+ await AsyncSession.delete(u1)
+ await AsyncSession.flush()
+ eq_(await conn.scalar(stmt), 0)