]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Create terminate mixin
authorFederico Caselli <cfederico87@gmail.com>
Thu, 23 Jan 2025 21:42:14 +0000 (22:42 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Mon, 18 Aug 2025 22:09:29 +0000 (00:09 +0200)
Generalize the terminate logic employed by the asyncpg dialect to reuse
it in the aiomysql and asyncmy dialect implementation.

Fixes: #12273
Change-Id: Iddb658b7118de774f169e31e888a8aae1c7c6ec2

doc/build/changelog/unreleased_20/12273.rst [new file with mode: 0644]
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py

diff --git a/doc/build/changelog/unreleased_20/12273.rst b/doc/build/changelog/unreleased_20/12273.rst
new file mode 100644 (file)
index 0000000..754677a
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: usecase, asyncio
+    :tickets: 12273
+
+    Generalize the terminate logic employed by the asyncpg dialect to reuse
+    it in the aiomysql and asyncmy dialect implementation.
index 87548e510bc65228895237a731081500104d5eef..c29aa3f69dd2bc511f63f59698f4867d4acc8147 100644 (file)
@@ -20,6 +20,8 @@ from typing import NoReturn
 from typing import Optional
 from typing import Protocol
 from typing import Sequence
+from typing import Tuple
+from typing import Type
 from typing import TYPE_CHECKING
 
 from ..engine import AdaptedConnection
@@ -374,3 +376,43 @@ class AsyncAdapt_dbapi_connection(AdaptedConnection):
 
     def close(self) -> None:
         await_(self._connection.close())
+
+
+class AsyncAdapt_terminate:
+    """Mixin for a AsyncAdapt_dbapi_connection to add terminate support."""
+
+    __slots__ = ()
+
+    def terminate(self) -> None:
+        if in_greenlet():
+            # in a greenlet; this is the connection was invalidated case.
+            try:
+                # try to gracefully close; see #10717
+                await_(asyncio.shield(self._terminate_graceful_close()))
+            except self._terminate_handled_exceptions() as e:
+                # in the case where we are recycling an old connection
+                # that may have already been disconnected, close() will
+                # fail.  In this case, terminate
+                # the connection without any further waiting.
+                # see issue #8419
+                self._terminate_force_close()
+                if isinstance(e, asyncio.CancelledError):
+                    # re-raise CancelledError if we were cancelled
+                    raise
+        else:
+            # not in a greenlet; this is the gc cleanup case
+            self._terminate_force_close()
+
+    def _terminate_handled_exceptions(self) -> Tuple[Type[BaseException], ...]:
+        """Returns the exceptions that should be handled when
+        calling _graceful_close.
+        """
+        return (asyncio.TimeoutError, asyncio.CancelledError, OSError)
+
+    async def _terminate_graceful_close(self) -> None:
+        """Try to close connection gracefully"""
+        raise NotImplementedError
+
+    def _terminate_force_close(self) -> None:
+        """Terminate the connection"""
+        raise NotImplementedError
index b23dbcf8f6235034e23333c371b0ec4c2201ea38..9c043850e452d65a7194f714521224aa4f2157e5 100644 (file)
@@ -41,6 +41,7 @@ from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...connectors.asyncio import AsyncAdapt_terminate
 from ...util.concurrency import await_
 
 if TYPE_CHECKING:
@@ -77,7 +78,9 @@ class AsyncAdapt_aiomysql_ss_cursor(
         )
 
 
-class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
+class AsyncAdapt_aiomysql_connection(
+    AsyncAdapt_terminate, AsyncAdapt_dbapi_connection
+):
     __slots__ = ()
 
     _cursor_cls = AsyncAdapt_aiomysql_cursor
@@ -96,13 +99,16 @@ class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
     def get_autocommit(self) -> bool:
         return self._connection.get_autocommit()  # type: ignore
 
-    def terminate(self) -> None:
-        # it's not awaitable.
-        self._connection.close()
-
     def close(self) -> None:
         await_(self._connection.ensure_closed())
 
+    async def _terminate_graceful_close(self) -> None:
+        await self._connection.ensure_closed()
+
+    def _terminate_force_close(self) -> None:
+        # it's not awaitable.
+        self._connection.close()
+
 
 class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module):
     def __init__(self, aiomysql: ModuleType, pymysql: ModuleType):
index b183f11c553cd5e3ec7019d84ccbebc1c259ba4e..22a60a099ab9e6097428b09a3983cd09d4f738c8 100644 (file)
@@ -41,6 +41,7 @@ from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...connectors.asyncio import AsyncAdapt_terminate
 from ...util.concurrency import await_
 
 if TYPE_CHECKING:
@@ -72,7 +73,9 @@ class AsyncAdapt_asyncmy_ss_cursor(
         )
 
 
-class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
+class AsyncAdapt_asyncmy_connection(
+    AsyncAdapt_terminate, AsyncAdapt_dbapi_connection
+):
     __slots__ = ()
 
     _cursor_cls = AsyncAdapt_asyncmy_cursor
@@ -106,13 +109,16 @@ class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
     def get_autocommit(self) -> bool:
         return self._connection.get_autocommit()  # type: ignore
 
-    def terminate(self) -> None:
-        # it's not awaitable.
-        self._connection.close()
-
     def close(self) -> None:
         await_(self._connection.ensure_closed())
 
+    async def _terminate_graceful_close(self) -> None:
+        await self._connection.ensure_closed()
+
+    def _terminate_force_close(self) -> None:
+        # it's not awaitable.
+        self._connection.close()
+
 
 class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module):
     def __init__(self, asyncmy: ModuleType):
index 09ff9f48c087a55c7f60b71566acd91cc753b16c..7409a2023e638f7295a4cbc6b4cc50b457b1d5b7 100644 (file)
@@ -178,7 +178,6 @@ client using this setting passed to :func:`_asyncio.create_async_engine`::
 
 from __future__ import annotations
 
-import asyncio
 from collections import deque
 import decimal
 import json as _py_json
@@ -218,6 +217,7 @@ from ... import util
 from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...connectors.asyncio import AsyncAdapt_terminate
 from ...engine import processors
 from ...sql import sqltypes
 from ...util.concurrency import await_
@@ -751,7 +751,9 @@ class AsyncAdapt_asyncpg_ss_cursor(
         )
 
 
-class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
+class AsyncAdapt_asyncpg_connection(
+    AsyncAdapt_terminate, AsyncAdapt_dbapi_connection
+):
     _cursor_cls = AsyncAdapt_asyncpg_cursor
     _ss_cursor_cls = AsyncAdapt_asyncpg_ss_cursor
 
@@ -932,32 +934,18 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
 
         await_(self._connection.close())
 
-    def terminate(self):
-        if util.concurrency.in_greenlet():
-            # in a greenlet; this is the connection was invalidated
-            # case.
-            try:
-                # try to gracefully close; see #10717
-                # timeout added in asyncpg 0.14.0 December 2017
-                await_(asyncio.shield(self._connection.close(timeout=2)))
-            except (
-                asyncio.TimeoutError,
-                asyncio.CancelledError,
-                OSError,
-                self.dbapi.asyncpg.PostgresError,
-            ) as e:
-                # in the case where we are recycling an old connection
-                # that may have already been disconnected, close() will
-                # fail with the above timeout.  in this case, terminate
-                # the connection without any further waiting.
-                # see issue #8419
-                self._connection.terminate()
-                if isinstance(e, asyncio.CancelledError):
-                    # re-raise CancelledError if we were cancelled
-                    raise
-        else:
-            # not in a greenlet; this is the gc cleanup case
-            self._connection.terminate()
+    def _terminate_handled_exceptions(self):
+        return super()._terminate_handled_exceptions() + (
+            self.dbapi.asyncpg.PostgresError,
+        )
+
+    async def _terminate_graceful_close(self) -> None:
+        # timeout added in asyncpg 0.14.0 December 2017
+        await self._connection.close(timeout=2)
+        self._transaction = None
+
+    def _terminate_force_close(self) -> None:
+        self._connection.terminate()
         self._transaction = None
 
     @staticmethod