From: Mike Bayer Date: Sat, 26 Dec 2020 16:46:42 +0000 (-0500) Subject: implement sessionmaker.begin(), scalar() for async session X-Git-Tag: rel_1_4_0b2~80 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=93d5904161c310ffe843ed79e7e7bef13ab11798;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement sessionmaker.begin(), scalar() for async session Added :meth:`_asyncio.AsyncSession.scalar` as well as support for :meth:`_orm.sessionmaker.begin` to work as an async context manager with :class:`_asyncio.AsyncSession`. Also added :meth:`_asyncio.AsyncSession.in_transaction` accessor. Fixes: #5796 Fixes: #5797 Change-Id: Id3d431421df0f8c38f356469a50a946ba9c38513 --- diff --git a/doc/build/changelog/unreleased_14/5797.rst b/doc/build/changelog/unreleased_14/5797.rst new file mode 100644 index 0000000000..be322807af --- /dev/null +++ b/doc/build/changelog/unreleased_14/5797.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: usecase, orm, asyncio + :tickets: 5796, 5797 + + Added :meth:`_asyncio.AsyncSession.scalar` as well as support for + :meth:`_orm.sessionmaker.begin` to work as an async context manager with + :class:`_asyncio.AsyncSession`. Also added + :meth:`_asyncio.AsyncSession.in_transaction` accessor. \ No newline at end of file diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index c21b679543..fc5b9cb448 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -1,3 +1,4 @@ +from typing import Any from typing import Callable from typing import Mapping from typing import Optional @@ -34,6 +35,7 @@ T = TypeVar("T") "expunge_all", "get_bind", "is_modified", + "in_transaction", ], attributes=[ "dirty", @@ -144,6 +146,25 @@ class AsyncSession: **kw ) + async def scalar( + self, + statement: Executable, + params: Optional[Mapping] = None, + execution_options: Mapping = util.EMPTY_DICT, + bind_arguments: Optional[Mapping] = None, + **kw + ) -> Any: + """Execute a statement and return a scalar result.""" + + result = await self.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw + ) + return result.scalar() + async def stream( self, statement, @@ -262,6 +283,24 @@ class AsyncSession: async def __aexit__(self, type_, value, traceback): await self.close() + def _maker_context_manager(self): + # no @contextlib.asynccontextmanager until python3.7, gr + return _AsyncSessionContextManager(self) + + +class _AsyncSessionContextManager: + def __init__(self, async_session): + self.async_session = async_session + + async def __aenter__(self): + self.trans = self.async_session.begin() + await self.trans.__aenter__() + return self.async_session + + async def __aexit__(self, type_, value, traceback): + await self.trans.__aexit__(type_, value, traceback) + await self.async_session.__aexit__(type_, value, traceback) + class AsyncSessionTransaction(StartableContext): """A wrapper for the ORM :class:`_orm.SessionTransaction` object. diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 1ec63fa40a..a5f0894f6a 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -1152,6 +1152,12 @@ class Session(_SessionClassMethods): def __exit__(self, type_, value, traceback): self.close() + @util.contextmanager + def _maker_context_manager(self): + with self: + with self.begin(): + yield self + @property @util.deprecated_20( ":attr:`_orm.Session.transaction`", @@ -3969,7 +3975,6 @@ class sessionmaker(_SessionClassMethods): # events can be associated with it specifically. self.class_ = type(class_.__name__, (class_,), {}) - @util.contextmanager def begin(self): """Produce a context manager that both provides a new :class:`_orm.Session` as well as a transaction that commits. @@ -3988,9 +3993,9 @@ class sessionmaker(_SessionClassMethods): """ - with self() as session: - with session.begin(): - yield session + + session = self() + return session._maker_context_manager() def __call__(self, **local_kw): """Produce a new :class:`.Session` object using the configuration diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index 44e2955428..c0ba8c2b3c 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -7,6 +7,7 @@ from sqlalchemy import update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import selectinload +from sqlalchemy.orm import sessionmaker from sqlalchemy.testing import async_test from sqlalchemy.testing import eq_ from sqlalchemy.testing import is_ @@ -54,6 +55,15 @@ class AsyncSessionQueryTest(AsyncFixture): result = await async_session.execute(stmt) eq_(result.scalars().all(), self.static.user_address_result) + @async_test + async def test_scalar(self, async_session): + User = self.classes.User + + stmt = select(User.id).order_by(User.id).limit(1) + + result = await async_session.scalar(stmt) + eq_(result, 7) + @async_test @testing.requires.independent_cursors async def test_stream_partitions(self, async_session): @@ -83,6 +93,52 @@ class AsyncSessionQueryTest(AsyncFixture): class AsyncSessionTransactionTest(AsyncFixture): run_inserts = None + @async_test + async def test_sessionmaker_block_one(self, async_engine): + + User = self.classes.User + maker = sessionmaker(async_engine, class_=AsyncSession) + + session = maker() + + async with session.begin(): + u1 = User(name="u1") + assert session.in_transaction() + session.add(u1) + + assert not session.in_transaction() + + async with maker() as session: + result = await session.execute( + select(User).where(User.name == "u1") + ) + + u1 = result.scalar_one() + + eq_(u1.name, "u1") + + @async_test + async def test_sessionmaker_block_two(self, async_engine): + + User = self.classes.User + maker = sessionmaker(async_engine, class_=AsyncSession) + + async with maker.begin() as session: + u1 = User(name="u1") + assert session.in_transaction() + session.add(u1) + + assert not session.in_transaction() + + async with maker() as session: + result = await session.execute( + select(User).where(User.name == "u1") + ) + + u1 = result.scalar_one() + + eq_(u1.name, "u1") + @async_test async def test_trans(self, async_session, async_engine): async with async_engine.connect() as outer_conn: