]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
close aio cursors etc. that require await close
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Aug 2025 19:25:15 +0000 (15:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Aug 2025 03:00:58 +0000 (23:00 -0400)
Improved the base implementation of the asyncio cursor such that it
includes the option for the underlying driver's cursor to be actively
closed in those cases where it requires ``await`` in order to complete the
close sequence, rather than relying on garbage collection to "close" it,
when a plain :class:`.Result` is returned that does not use ``await`` for
any of its methods.  The previous approach of relying on gc was fine for
MySQL and SQLite dialects but has caused problems with the aioodbc
implementation on top of SQL Server.   The new option is enabled
for those dialects which have an "awaitable" ``cursor.close()``, which
includes the aioodbc, aiomysql, and asyncmy dialects (aiosqlite is also
modified for 2.1 only).

Fixes: #12798
Change-Id: Ib17d611201fedf9780dfe3d760760ace99a8835c
(cherry picked from commit 5dbb5ec0e4ce71f7b806b87808a504083a7e8ffa)

doc/build/changelog/unreleased_20/12798.rst [new file with mode: 0644]
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/dialects/oracle/oracledb.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/sqlite/aiosqlite.py
lib/sqlalchemy/ext/asyncio/result.py
test/ext/asyncio/test_engine_py3k.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/12798.rst b/doc/build/changelog/unreleased_20/12798.rst
new file mode 100644 (file)
index 0000000..0161026
--- /dev/null
@@ -0,0 +1,15 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 12798
+
+    Improved the base implementation of the asyncio cursor such that it
+    includes the option for the underlying driver's cursor to be actively
+    closed in those cases where it requires ``await`` in order to complete the
+    close sequence, rather than relying on garbage collection to "close" it,
+    when a plain :class:`.Result` is returned that does not use ``await`` for
+    any of its methods.  The previous approach of relying on gc was fine for
+    MySQL and SQLite dialects but has caused problems with the aioodbc
+    implementation on top of SQL Server.   The new option is enabled
+    for those dialects which have an "awaitable" ``cursor.close()``, which
+    includes the aioodbc, aiomysql, and asyncmy dialects (aiosqlite is also
+    modified for 2.1 only).
index fda21b6d6f08e8ca15c50159e1c607ab867886f3..68819a1f3b485231d38f999344349bcdd1eee3c2 100644 (file)
@@ -22,8 +22,10 @@ from typing import Sequence
 from typing import TYPE_CHECKING
 
 from ..engine import AdaptedConnection
+from ..util import EMPTY_DICT
 from ..util.concurrency import await_fallback
 from ..util.concurrency import await_only
+from ..util.concurrency import in_greenlet
 from ..util.typing import Protocol
 
 if TYPE_CHECKING:
@@ -129,8 +131,11 @@ class AsyncAdapt_dbapi_cursor:
         "await_",
         "_cursor",
         "_rows",
+        "_soft_closed_memoized",
     )
 
+    _awaitable_cursor_close: bool = True
+
     _cursor: AsyncIODBAPICursor
     _adapt_connection: AsyncAdapt_dbapi_connection
     _connection: AsyncIODBAPIConnection
@@ -144,7 +149,7 @@ class AsyncAdapt_dbapi_cursor:
 
         cursor = self._make_new_cursor(self._connection)
         self._cursor = self._aenter_cursor(cursor)
-
+        self._soft_closed_memoized = EMPTY_DICT
         if not self.server_side:
             self._rows = collections.deque()
 
@@ -158,6 +163,8 @@ class AsyncAdapt_dbapi_cursor:
 
     @property
     def description(self) -> Optional[_DBAPICursorDescription]:
+        if "description" in self._soft_closed_memoized:
+            return self._soft_closed_memoized["description"]  # type: ignore[no-any-return]  # noqa: E501
         return self._cursor.description
 
     @property
@@ -176,11 +183,40 @@ class AsyncAdapt_dbapi_cursor:
     def lastrowid(self) -> int:
         return self._cursor.lastrowid
 
+    async def _async_soft_close(self) -> None:
+        """close the cursor but keep the results pending, and memoize the
+        description.
+
+        .. versionadded:: 2.0.44
+
+        """
+
+        if not self._awaitable_cursor_close or self.server_side:
+            return
+
+        self._soft_closed_memoized = self._soft_closed_memoized.union(
+            {
+                "description": self._cursor.description,
+            }
+        )
+        await self._cursor.close()
+
     def close(self) -> None:
-        # note we aren't actually closing the cursor here,
-        # we are just letting GC do it.  see notes in aiomysql dialect
         self._rows.clear()
 
+        # updated as of 2.0.44
+        # try to "close" the cursor based on what we know about the driver
+        # and if we are able to.  otherwise, hope that the asyncio
+        # extension called _async_soft_close() if the cursor is going into
+        # a sync context
+        if self._cursor is None or bool(self._soft_closed_memoized):
+            return
+
+        if not self._awaitable_cursor_close:
+            self._cursor.close()  # type: ignore[unused-coroutine]
+        elif in_greenlet():
+            self.await_(self._cursor.close())
+
     def execute(
         self,
         operation: Any,
index c09d2bae0df6a0590770dee3103a75370e39aec7..cce7ad7b58fb6d44c98ed676b26d5095c32fd2a0 100644 (file)
@@ -736,6 +736,8 @@ class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle):
 
 class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
     _cursor: AsyncCursor
+    _awaitable_cursor_close: bool = False
+
     __slots__ = ()
 
     @property
@@ -749,10 +751,6 @@ class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
     def var(self, *args, **kwargs):
         return self._cursor.var(*args, **kwargs)
 
-    def close(self):
-        self._rows.clear()
-        self._cursor.close()
-
     def setinputsizes(self, *args: Any, **kwargs: Any) -> Any:
         return self._cursor.setinputsizes(*args, **kwargs)
 
index 5b3073af351559adacdad140c8fd32b4b6219ab1..adba7abb67bee4140b979df1a73a220bc10b67dd 100644 (file)
@@ -490,6 +490,7 @@ class AsyncAdapt_asyncpg_cursor:
     )
 
     server_side = False
+    _awaitable_cursor_close: bool = False
 
     def __init__(self, adapt_connection):
         self._adapt_connection = adapt_connection
@@ -501,6 +502,9 @@ class AsyncAdapt_asyncpg_cursor:
         self.rowcount = -1
         self._invalidate_schema_cache_asof = 0
 
+    async def _async_soft_close(self) -> None:
+        return
+
     def close(self):
         self._rows.clear()
 
index 0554048c2bf629d01fd3bbbe4db28df96a1369e2..200bf4a020ac4f3041737b90c577f3d622ca854f 100644 (file)
@@ -585,6 +585,9 @@ class AsyncAdapt_psycopg_cursor:
     def arraysize(self, value):
         self._cursor.arraysize = value
 
+    async def _async_soft_close(self) -> None:
+        return
+
     def close(self):
         self._rows.clear()
         # Normal cursor just call _close() in a non-sync way.
index 3f39d4dbc7db56bbd27408f527f601695773ecd9..63cf8190b7c100824c651f351509109a3eccac53 100644 (file)
@@ -141,6 +141,9 @@ class AsyncAdapt_aiosqlite_cursor:
         self.description: Optional[_DBAPICursorDescription] = None
         self._rows: Deque[Any] = deque()
 
+    async def _async_soft_close(self) -> None:
+        return
+
     def close(self) -> None:
         self._rows.clear()
 
index f1df53bc13c8dee20830283b4848f2106f3ac5a7..88495206ea20e366bac67edc5256b044ad655d38 100644 (file)
@@ -959,4 +959,7 @@ async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT:
                 calling_method.__self__.__class__.__name__,
             )
         )
+
+    if is_cursor and cursor_result.cursor is not None:
+        await cursor_result.cursor._async_soft_close()
     return result
index 05941a79a2ab383347cb8d2d5837093d0b2ee1c4..6b7a7afd028c57599a614e0d135aae45d55188d0 100644 (file)
@@ -893,6 +893,20 @@ class AsyncEngineTest(EngineFixture):
                     else:
                         testing.fail(method)
 
+    @async_test
+    @testing.requires.async_dialect_with_await_close
+    async def test_active_await_close(self, async_engine):
+        select_one_sql = select(1).compile(async_engine.sync_engine).string
+
+        async with async_engine.connect() as conn:
+            result = await conn.exec_driver_sql(select_one_sql)
+            eq_(result.scalar_one(), 1)
+            driver_cursor = result.context.cursor._cursor
+
+            with expect_raises(Exception):
+                # because the cursor should be closed
+                await driver_cursor.execute(select_one_sql)
+
 
 class AsyncCreatePoolTest(fixtures.TestBase):
     @config.fixture
index ba481fb4efaaaaed270e2863ef1446c39b2e76d9..66c0fb174f5b3bf95e7cd9db80b020dd2c960371 100644 (file)
@@ -1600,6 +1600,12 @@ class DefaultRequirements(SuiteRequirements):
             )
         )
 
+    @property
+    def async_dialect_with_await_close(self):
+        """dialect's cursor has a close() method called with await"""
+
+        return only_on(["+aioodbc", "+aiomysql", "+asyncmy"])
+
     def _has_oracle_test_dblink(self, key):
         def check(config):
             assert config.db.dialect.name == "oracle"