From 7576286a498ec32202dbc29b053b3a6b49dcab3d Mon Sep 17 00:00:00 2001 From: semen603089 Date: Sat, 15 Jul 2023 06:00:35 +0300 Subject: [PATCH] `aclose` implementation --- lib/sqlalchemy/ext/asyncio/engine.py | 5 +++++ lib/sqlalchemy/ext/asyncio/session.py | 5 +++++ test/ext/asyncio/test_engine_py3k.py | 13 +++++++++++++ test/ext/asyncio/test_scoping_py3k.py | 1 + test/ext/asyncio/test_session_py3k.py | 11 +++++++++++ 5 files changed, 35 insertions(+) diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index 594eb02a7d..916e13dcad 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -476,6 +476,11 @@ class AsyncConnection( """ await greenlet_spawn(self._proxied.close) + async def aclose(self) -> None: + """Call the close() method of :class:`_asyncio.AsyncConnection` + """ + await self.close() + async def exec_driver_sql( self, statement: str, diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 19a441ca61..25137dce73 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -973,6 +973,11 @@ class AsyncSession(ReversibleProxy[Session]): """ await greenlet_spawn(self.sync_session.close) + async def aclose(self) -> None: + """Call the close() method of :class:`_asyncio.AsyncSession` + """ + await self.close() + async def invalidate(self) -> None: """Close this Session, using connection invalidation. diff --git a/test/ext/asyncio/test_engine_py3k.py b/test/ext/asyncio/test_engine_py3k.py index bbbdbf512f..7289d5494e 100644 --- a/test/ext/asyncio/test_engine_py3k.py +++ b/test/ext/asyncio/test_engine_py3k.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import inspect as stdlib_inspect from unittest.mock import patch @@ -663,6 +664,18 @@ class AsyncEngineTest(EngineFixture): with expect_raises(exc.TimeoutError): await engine.connect() + @testing.requires.python310 + @async_test + async def test_engine_aclose(self, async_engine): + users = self.tables.users + async with contextlib.aclosing(async_engine.connect()) as conn: + await conn.start() + trans = conn.begin() + await trans.start() + await conn.execute(delete(users)) + await trans.commit() + assert conn.closed + @testing.requires.queue_pool @async_test async def test_pool_exhausted_no_timeout(self, async_engine): diff --git a/test/ext/asyncio/test_scoping_py3k.py b/test/ext/asyncio/test_scoping_py3k.py index caba1c6600..1b00e484d5 100644 --- a/test/ext/asyncio/test_scoping_py3k.py +++ b/test/ext/asyncio/test_scoping_py3k.py @@ -63,6 +63,7 @@ class AsyncScopedSessionTest(AsyncFixture): "get_nested_transaction", "in_transaction", "in_nested_transaction", + "aclose", } SM = async_scoped_session( diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index 228489349a..8fa174eeba 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib from typing import List from typing import Optional @@ -148,6 +149,16 @@ class AsyncSessionQueryTest(AsyncFixture): result = await async_session.scalar(stmt) eq_(result, 7) + @testing.requires.python310 + @async_test + async def test_session_aclose(self, async_session): + User = self.classes.User + u = User(name="u") + async with contextlib.aclosing(async_session) as session: + session.add(u) + await session.commit() + assert async_session.sync_session.identity_map.values() == [] + @testing.combinations( ("scalars",), ("stream_scalars",), argnames="filter_" ) -- 2.47.3