]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement sessionmaker.begin(), scalar() for async session
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Dec 2020 16:46:42 +0000 (11:46 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 26 Dec 2020 20:56:12 +0000 (15:56 -0500)
Added :meth:`_asyncio.AsyncSession.scalar` as well as support for
:meth:`_orm.sessionmaker.begin` to work as an async context manager with
:class:`_asyncio.AsyncSession`.  Also added
:meth:`_asyncio.AsyncSession.in_transaction` accessor.

Fixes: #5796
Fixes: #5797
Change-Id: Id3d431421df0f8c38f356469a50a946ba9c38513

doc/build/changelog/unreleased_14/5797.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/session.py
test/ext/asyncio/test_session_py3k.py

diff --git a/doc/build/changelog/unreleased_14/5797.rst b/doc/build/changelog/unreleased_14/5797.rst
new file mode 100644 (file)
index 0000000..be32280
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: usecase, orm, asyncio
+    :tickets: 5796, 5797
+
+    Added :meth:`_asyncio.AsyncSession.scalar` as well as support for
+    :meth:`_orm.sessionmaker.begin` to work as an async context manager with
+    :class:`_asyncio.AsyncSession`.  Also added
+    :meth:`_asyncio.AsyncSession.in_transaction` accessor.
\ No newline at end of file
index c21b67954327b368b818813c7c89fecb64598a06..fc5b9cb4485239e438b4f9bf435596abbf771211 100644 (file)
@@ -1,3 +1,4 @@
+from typing import Any
 from typing import Callable
 from typing import Mapping
 from typing import Optional
@@ -34,6 +35,7 @@ T = TypeVar("T")
         "expunge_all",
         "get_bind",
         "is_modified",
+        "in_transaction",
     ],
     attributes=[
         "dirty",
@@ -144,6 +146,25 @@ class AsyncSession:
             **kw
         )
 
+    async def scalar(
+        self,
+        statement: Executable,
+        params: Optional[Mapping] = None,
+        execution_options: Mapping = util.EMPTY_DICT,
+        bind_arguments: Optional[Mapping] = None,
+        **kw
+    ) -> Any:
+        """Execute a statement and return a scalar result."""
+
+        result = await self.execute(
+            statement,
+            params=params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            **kw
+        )
+        return result.scalar()
+
     async def stream(
         self,
         statement,
@@ -262,6 +283,24 @@ class AsyncSession:
     async def __aexit__(self, type_, value, traceback):
         await self.close()
 
+    def _maker_context_manager(self):
+        # no @contextlib.asynccontextmanager until python3.7, gr
+        return _AsyncSessionContextManager(self)
+
+
+class _AsyncSessionContextManager:
+    def __init__(self, async_session):
+        self.async_session = async_session
+
+    async def __aenter__(self):
+        self.trans = self.async_session.begin()
+        await self.trans.__aenter__()
+        return self.async_session
+
+    async def __aexit__(self, type_, value, traceback):
+        await self.trans.__aexit__(type_, value, traceback)
+        await self.async_session.__aexit__(type_, value, traceback)
+
 
 class AsyncSessionTransaction(StartableContext):
     """A wrapper for the ORM :class:`_orm.SessionTransaction` object.
index 1ec63fa40a0678c69a8ec8ce743b858397f49dad..a5f0894f6ac2e4e68e818645bbb86270f0a01da5 100644 (file)
@@ -1152,6 +1152,12 @@ class Session(_SessionClassMethods):
     def __exit__(self, type_, value, traceback):
         self.close()
 
+    @util.contextmanager
+    def _maker_context_manager(self):
+        with self:
+            with self.begin():
+                yield self
+
     @property
     @util.deprecated_20(
         ":attr:`_orm.Session.transaction`",
@@ -3969,7 +3975,6 @@ class sessionmaker(_SessionClassMethods):
         # events can be associated with it specifically.
         self.class_ = type(class_.__name__, (class_,), {})
 
-    @util.contextmanager
     def begin(self):
         """Produce a context manager that both provides a new
         :class:`_orm.Session` as well as a transaction that commits.
@@ -3988,9 +3993,9 @@ class sessionmaker(_SessionClassMethods):
 
 
         """
-        with self() as session:
-            with session.begin():
-                yield session
+
+        session = self()
+        return session._maker_context_manager()
 
     def __call__(self, **local_kw):
         """Produce a new :class:`.Session` object using the configuration
index 44e2955428b2898fbb0655c1d080e40e6a5a458a..c0ba8c2b3c04ca2921800b653e305ac5edf740c9 100644 (file)
@@ -7,6 +7,7 @@ from sqlalchemy import update
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlalchemy.orm import selectinload
+from sqlalchemy.orm import sessionmaker
 from sqlalchemy.testing import async_test
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import is_
@@ -54,6 +55,15 @@ class AsyncSessionQueryTest(AsyncFixture):
         result = await async_session.execute(stmt)
         eq_(result.scalars().all(), self.static.user_address_result)
 
+    @async_test
+    async def test_scalar(self, async_session):
+        User = self.classes.User
+
+        stmt = select(User.id).order_by(User.id).limit(1)
+
+        result = await async_session.scalar(stmt)
+        eq_(result, 7)
+
     @async_test
     @testing.requires.independent_cursors
     async def test_stream_partitions(self, async_session):
@@ -83,6 +93,52 @@ class AsyncSessionQueryTest(AsyncFixture):
 class AsyncSessionTransactionTest(AsyncFixture):
     run_inserts = None
 
+    @async_test
+    async def test_sessionmaker_block_one(self, async_engine):
+
+        User = self.classes.User
+        maker = sessionmaker(async_engine, class_=AsyncSession)
+
+        session = maker()
+
+        async with session.begin():
+            u1 = User(name="u1")
+            assert session.in_transaction()
+            session.add(u1)
+
+        assert not session.in_transaction()
+
+        async with maker() as session:
+            result = await session.execute(
+                select(User).where(User.name == "u1")
+            )
+
+            u1 = result.scalar_one()
+
+            eq_(u1.name, "u1")
+
+    @async_test
+    async def test_sessionmaker_block_two(self, async_engine):
+
+        User = self.classes.User
+        maker = sessionmaker(async_engine, class_=AsyncSession)
+
+        async with maker.begin() as session:
+            u1 = User(name="u1")
+            assert session.in_transaction()
+            session.add(u1)
+
+        assert not session.in_transaction()
+
+        async with maker() as session:
+            result = await session.execute(
+                select(User).where(User.name == "u1")
+            )
+
+            u1 = result.scalar_one()
+
+            eq_(u1.name, "u1")
+
     @async_test
     async def test_trans(self, async_session, async_engine):
         async with async_engine.connect() as outer_conn: