]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
`aclose` implementation
authorsemen603089 <semen603089@mail.ru>
Sat, 15 Jul 2023 03:00:35 +0000 (06:00 +0300)
committersemen603089 <semen603089@mail.ru>
Sat, 15 Jul 2023 03:00:35 +0000 (06:00 +0300)
lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/ext/asyncio/session.py
test/ext/asyncio/test_engine_py3k.py
test/ext/asyncio/test_scoping_py3k.py
test/ext/asyncio/test_session_py3k.py

index 594eb02a7df29d01d0e334cae8deb85a839516d8..916e13dcad5884a05a5c6a8132e535c96ac00fd8 100644 (file)
@@ -476,6 +476,11 @@ class AsyncConnection(
         """
         await greenlet_spawn(self._proxied.close)
 
+    async def aclose(self) -> None:
+        """Call the close() method of :class:`_asyncio.AsyncConnection`
+        """
+        await self.close()
+
     async def exec_driver_sql(
         self,
         statement: str,
index 19a441ca613943ddc2c0b677e28b30e6a1a9991c..25137dce73e3f9b9bb969153574c07e9588f0b31 100644 (file)
@@ -973,6 +973,11 @@ class AsyncSession(ReversibleProxy[Session]):
         """
         await greenlet_spawn(self.sync_session.close)
 
+    async def aclose(self) -> None:
+        """Call the close() method of :class:`_asyncio.AsyncSession`
+        """
+        await self.close()
+
     async def invalidate(self) -> None:
         """Close this Session, using connection invalidation.
 
index bbbdbf512f4b528026653e9a9700c198b86e2a60..7289d5494ebcaa172254fad8cc05915798c5ab0b 100644 (file)
@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 import inspect as stdlib_inspect
 from unittest.mock import patch
 
@@ -663,6 +664,18 @@ class AsyncEngineTest(EngineFixture):
             with expect_raises(exc.TimeoutError):
                 await engine.connect()
 
+    @testing.requires.python310
+    @async_test
+    async def test_engine_aclose(self, async_engine):
+        users = self.tables.users
+        async with contextlib.aclosing(async_engine.connect()) as conn:
+            await conn.start()
+            trans = conn.begin()
+            await trans.start()
+            await conn.execute(delete(users))
+            await trans.commit()
+        assert conn.closed
+
     @testing.requires.queue_pool
     @async_test
     async def test_pool_exhausted_no_timeout(self, async_engine):
index caba1c66001fe2be4340186a69d11eda536e1385..1b00e484d5bc9e8dde55418b56fdcaf7736edb18 100644 (file)
@@ -63,6 +63,7 @@ class AsyncScopedSessionTest(AsyncFixture):
             "get_nested_transaction",
             "in_transaction",
             "in_nested_transaction",
+            "aclose",
         }
 
         SM = async_scoped_session(
index 228489349a15c339cc285ff6836d79254255c9be..8fa174eebaaea5f0e3f8ed220f8e86cb03bd1922 100644 (file)
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 from typing import List
 from typing import Optional
 
@@ -148,6 +149,16 @@ class AsyncSessionQueryTest(AsyncFixture):
         result = await async_session.scalar(stmt)
         eq_(result, 7)
 
+    @testing.requires.python310
+    @async_test
+    async def test_session_aclose(self, async_session):
+        User = self.classes.User
+        u = User(name="u")
+        async with contextlib.aclosing(async_session) as session:
+            session.add(u)
+            await session.commit()
+        assert async_session.sync_session.identity_map.values() == []
+
     @testing.combinations(
         ("scalars",), ("stream_scalars",), argnames="filter_"
     )