]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
close aio cursors etc. that require await close
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Aug 2025 19:25:15 +0000 (15:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Aug 2025 03:02:03 +0000 (23:02 -0400)
Improved the base implementation of the asyncio cursor such that it
includes the option for the underlying driver's cursor to be actively
closed in those cases where it requires ``await`` in order to complete the
close sequence, rather than relying on garbage collection to "close" it,
when a plain :class:`.Result` is returned that does not use ``await`` for
any of its methods.  The previous approach of relying on gc was fine for
MySQL and SQLite dialects but has caused problems with the aioodbc
implementation on top of SQL Server.   The new option is enabled
for those dialects which have an "awaitable" ``cursor.close()``, which
includes the aioodbc, aiomysql, and asyncmy dialects (aiosqlite is also
modified for 2.1 only).

Fixes: #12798
Change-Id: Ib17d611201fedf9780dfe3d760760ace99a8835c

doc/build/changelog/unreleased_20/12798.rst [new file with mode: 0644]
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/dialects/oracle/oracledb.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/ext/asyncio/result.py
test/ext/asyncio/test_engine_py3k.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/12798.rst b/doc/build/changelog/unreleased_20/12798.rst
new file mode 100644 (file)
index 0000000..0161026
--- /dev/null
@@ -0,0 +1,15 @@
+.. change::
+    :tags: bug, mssql
+    :tickets: 12798
+
+    Improved the base implementation of the asyncio cursor such that it
+    includes the option for the underlying driver's cursor to be actively
+    closed in those cases where it requires ``await`` in order to complete the
+    close sequence, rather than relying on garbage collection to "close" it,
+    when a plain :class:`.Result` is returned that does not use ``await`` for
+    any of its methods.  The previous approach of relying on gc was fine for
+    MySQL and SQLite dialects but has caused problems with the aioodbc
+    implementation on top of SQL Server.   The new option is enabled
+    for those dialects which have an "awaitable" ``cursor.close()``, which
+    includes the aioodbc, aiomysql, and asyncmy dialects (aiosqlite is also
+    modified for 2.1 only).
index 2037c248efc4b14ab1058f80900c9275d3636cc8..87548e510bc65228895237a731081500104d5eef 100644 (file)
@@ -23,7 +23,9 @@ from typing import Sequence
 from typing import TYPE_CHECKING
 
 from ..engine import AdaptedConnection
+from ..util import EMPTY_DICT
 from ..util.concurrency import await_
+from ..util.concurrency import in_greenlet
 
 if TYPE_CHECKING:
     from ..engine.interfaces import _DBAPICursorDescription
@@ -127,8 +129,11 @@ class AsyncAdapt_dbapi_cursor:
         "_connection",
         "_cursor",
         "_rows",
+        "_soft_closed_memoized",
     )
 
+    _awaitable_cursor_close: bool = True
+
     _cursor: AsyncIODBAPICursor
     _adapt_connection: AsyncAdapt_dbapi_connection
     _connection: AsyncIODBAPIConnection
@@ -140,7 +145,7 @@ class AsyncAdapt_dbapi_cursor:
 
         cursor = self._make_new_cursor(self._connection)
         self._cursor = self._aenter_cursor(cursor)
-
+        self._soft_closed_memoized = EMPTY_DICT
         if not self.server_side:
             self._rows = collections.deque()
 
@@ -157,6 +162,8 @@ class AsyncAdapt_dbapi_cursor:
 
     @property
     def description(self) -> Optional[_DBAPICursorDescription]:
+        if "description" in self._soft_closed_memoized:
+            return self._soft_closed_memoized["description"]  # type: ignore[no-any-return]  # noqa: E501
         return self._cursor.description
 
     @property
@@ -175,11 +182,40 @@ class AsyncAdapt_dbapi_cursor:
     def lastrowid(self) -> int:
         return self._cursor.lastrowid
 
+    async def _async_soft_close(self) -> None:
+        """close the cursor but keep the results pending, and memoize the
+        description.
+
+        .. versionadded:: 2.0.44
+
+        """
+
+        if not self._awaitable_cursor_close or self.server_side:
+            return
+
+        self._soft_closed_memoized = self._soft_closed_memoized.union(
+            {
+                "description": self._cursor.description,
+            }
+        )
+        await self._cursor.close()
+
     def close(self) -> None:
-        # note we aren't actually closing the cursor here,
-        # we are just letting GC do it.  see notes in aiomysql dialect
         self._rows.clear()
 
+        # updated as of 2.0.44
+        # try to "close" the cursor based on what we know about the driver
+        # and if we are able to.  otherwise, hope that the asyncio
+        # extension called _async_soft_close() if the cursor is going into
+        # a sync context
+        if self._cursor is None or bool(self._soft_closed_memoized):
+            return
+
+        if not self._awaitable_cursor_close:
+            self._cursor.close()  # type: ignore[unused-coroutine]
+        elif in_greenlet():
+            await_(self._cursor.close())
+
     def execute(
         self,
         operation: Any,
index d4fb99befa50570488d59d8c3905c08043718f22..a35fa9255c4ef419f9c049742fc68e387b834455 100644 (file)
@@ -716,6 +716,8 @@ class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle):
 
 class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
     _cursor: AsyncCursor
+    _awaitable_cursor_close: bool = False
+
     __slots__ = ()
 
     @property
@@ -729,10 +731,6 @@ class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
     def var(self, *args, **kwargs):
         return self._cursor.var(*args, **kwargs)
 
-    def close(self):
-        self._rows.clear()
-        self._cursor.close()
-
     def setinputsizes(self, *args: Any, **kwargs: Any) -> Any:
         return self._cursor.setinputsizes(*args, **kwargs)
 
index 09ede9a7e74ffd5e51ce242f1492d795ff6d0d89..09ff9f48c087a55c7f60b71566acd91cc753b16c 100644 (file)
@@ -561,6 +561,7 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor):
     _adapt_connection: AsyncAdapt_asyncpg_connection
     _connection: _AsyncpgConnection
     _cursor: Optional[_AsyncpgCursor]
+    _awaitable_cursor_close: bool = False
 
     def __init__(self, adapt_connection: AsyncAdapt_asyncpg_connection):
         self._adapt_connection = adapt_connection
index 4df6f8a4fa26f8c9b778b832611f693c38144853..7ad63a2fd3c96936cc17f0c0455cb7ef7898dac5 100644 (file)
@@ -568,6 +568,8 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
 class AsyncAdapt_psycopg_cursor(AsyncAdapt_dbapi_cursor):
     __slots__ = ()
 
+    _awaitable_cursor_close: bool = False
+
     def close(self):
         self._rows.clear()
         # Normal cursor just call _close() in a non-sync way.
index 002bb7e03c36a3b826afb36fdbf5a64501d59047..970bb791bca93f2d5f3515fa67459d61902576c1 100644 (file)
@@ -988,4 +988,7 @@ async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT:
                 calling_method.__self__.__class__.__name__,
             )
         )
+
+    if is_cursor and cursor_result.cursor is not None:
+        await cursor_result.cursor._async_soft_close()
     return result
index 48226aa27bd577fb50d6953aa0e49004079a5402..901b47f42ffd7d2f78cf5f1bcc6923f28ec829b4 100644 (file)
@@ -896,6 +896,20 @@ class AsyncEngineTest(EngineFixture):
                     else:
                         testing.fail(method)
 
+    @async_test
+    @testing.requires.async_dialect_with_await_close
+    async def test_active_await_close(self, async_engine):
+        select_one_sql = select(1).compile(async_engine.sync_engine).string
+
+        async with async_engine.connect() as conn:
+            result = await conn.exec_driver_sql(select_one_sql)
+            eq_(result.scalar_one(), 1)
+            driver_cursor = result.context.cursor._cursor
+
+            with expect_raises(Exception):
+                # because the cursor should be closed
+                await driver_cursor.execute(select_one_sql)
+
 
 class AsyncCreatePoolTest(fixtures.TestBase):
     @config.fixture
index 34bd3386494fcc85b3bca74acd58980738e773fa..8ba81f389aba78ae83714c42f9c4518981860601 100644 (file)
@@ -1600,6 +1600,12 @@ class DefaultRequirements(SuiteRequirements):
             )
         )
 
+    @property
+    def async_dialect_with_await_close(self):
+        """dialect's cursor has a close() method called with await"""
+
+        return only_on(["+aioodbc", "+aiosqlite", "+aiomysql", "+asyncmy"])
+
     def _has_oracle_test_dblink(self, key):
         def check(config):
             assert config.db.dialect.name == "oracle"