]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Create terminate mixin
authorFederico Caselli <cfederico87@gmail.com>
Thu, 23 Jan 2025 21:42:14 +0000 (22:42 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 19 Aug 2025 17:18:26 +0000 (19:18 +0200)
Generalize the terminate logic employed by the asyncpg dialect to reuse
it in the aiomysql and asyncmy dialect implementation.

Fixes: #12273
Change-Id: Iddb658b7118de774f169e31e888a8aae1c7c6ec2
(cherry picked from commit dddfa96736dd905be59c8601ae3e09c8bc52299c)

doc/build/changelog/unreleased_20/12273.rst [new file with mode: 0644]
lib/sqlalchemy/connectors/asyncio.py
lib/sqlalchemy/dialects/mysql/aiomysql.py
lib/sqlalchemy/dialects/mysql/asyncmy.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py

diff --git a/doc/build/changelog/unreleased_20/12273.rst b/doc/build/changelog/unreleased_20/12273.rst
new file mode 100644 (file)
index 0000000..754677a
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+    :tags: usecase, asyncio
+    :tickets: 12273
+
+    Generalize the terminate logic employed by the asyncpg dialect to reuse
+    it in the aiomysql and asyncmy dialect implementation.
index 68819a1f3b485231d38f999344349bcdd1eee3c2..335ddd221ea6909e140eb0d362d2f2a16ad29d86 100644 (file)
@@ -19,6 +19,8 @@ from typing import Iterator
 from typing import NoReturn
 from typing import Optional
 from typing import Sequence
+from typing import Tuple
+from typing import Type
 from typing import TYPE_CHECKING
 
 from ..engine import AdaptedConnection
@@ -385,3 +387,43 @@ class AsyncAdaptFallback_dbapi_connection(AsyncAdapt_dbapi_connection):
     __slots__ = ()
 
     await_ = staticmethod(await_fallback)
+
+
+class AsyncAdapt_terminate:
+    """Mixin for a AsyncAdapt_dbapi_connection to add terminate support."""
+
+    __slots__ = ()
+
+    def terminate(self) -> None:
+        if in_greenlet():
+            # in a greenlet; this is the connection was invalidated case.
+            try:
+                # try to gracefully close; see #10717
+                self.await_(asyncio.shield(self._terminate_graceful_close()))  # type: ignore[attr-defined] # noqa: E501
+            except self._terminate_handled_exceptions() as e:
+                # in the case where we are recycling an old connection
+                # that may have already been disconnected, close() will
+                # fail.  In this case, terminate
+                # the connection without any further waiting.
+                # see issue #8419
+                self._terminate_force_close()
+                if isinstance(e, asyncio.CancelledError):
+                    # re-raise CancelledError if we were cancelled
+                    raise
+        else:
+            # not in a greenlet; this is the gc cleanup case
+            self._terminate_force_close()
+
+    def _terminate_handled_exceptions(self) -> Tuple[Type[BaseException], ...]:
+        """Returns the exceptions that should be handled when
+        calling _graceful_close.
+        """
+        return (asyncio.TimeoutError, asyncio.CancelledError, OSError)
+
+    async def _terminate_graceful_close(self) -> None:
+        """Try to close connection gracefully"""
+        raise NotImplementedError
+
+    def _terminate_force_close(self) -> None:
+        """Terminate the connection"""
+        raise NotImplementedError
index af1ac2f3346446ba4de33fe72233df0e1c7c8609..77b2960aabf858fe8170801e2037cc42a700106c 100644 (file)
@@ -45,6 +45,7 @@ from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...connectors.asyncio import AsyncAdapt_terminate
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
@@ -82,7 +83,9 @@ class AsyncAdapt_aiomysql_ss_cursor(
         )
 
 
-class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
+class AsyncAdapt_aiomysql_connection(
+    AsyncAdapt_terminate, AsyncAdapt_dbapi_connection
+):
     __slots__ = ()
 
     _cursor_cls = AsyncAdapt_aiomysql_cursor
@@ -101,13 +104,16 @@ class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection):
     def get_autocommit(self) -> bool:
         return self._connection.get_autocommit()  # type: ignore
 
-    def terminate(self) -> None:
-        # it's not awaitable.
-        self._connection.close()
-
     def close(self) -> None:
         self.await_(self._connection.ensure_closed())
 
+    async def _terminate_graceful_close(self) -> None:
+        await self._connection.ensure_closed()
+
+    def _terminate_force_close(self) -> None:
+        # it's not awaitable.
+        self._connection.close()
+
 
 class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
     __slots__ = ()
index 61157facd340041d6df442053eceb829ff7b1500..d36a7eaeed4e9444a25bf31fc92a5e04880f6070 100644 (file)
@@ -42,6 +42,7 @@ from ...connectors.asyncio import AsyncAdapt_dbapi_connection
 from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
 from ...connectors.asyncio import AsyncAdapt_dbapi_module
 from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor
+from ...connectors.asyncio import AsyncAdapt_terminate
 from ...util.concurrency import await_fallback
 from ...util.concurrency import await_only
 
@@ -73,7 +74,9 @@ class AsyncAdapt_asyncmy_ss_cursor(
         )
 
 
-class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
+class AsyncAdapt_asyncmy_connection(
+    AsyncAdapt_terminate, AsyncAdapt_dbapi_connection
+):
     __slots__ = ()
 
     _cursor_cls = AsyncAdapt_asyncmy_cursor
@@ -107,13 +110,16 @@ class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection):
     def get_autocommit(self) -> bool:
         return self._connection.get_autocommit()  # type: ignore
 
-    def terminate(self) -> None:
-        # it's not awaitable.
-        self._connection.close()
-
     def close(self) -> None:
         self.await_(self._connection.ensure_closed())
 
+    async def _terminate_graceful_close(self) -> None:
+        await self._connection.ensure_closed()
+
+    def _terminate_force_close(self) -> None:
+        # it's not awaitable.
+        self._connection.close()
+
 
 class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
     __slots__ = ()
index adba7abb67bee4140b979df1a73a220bc10b67dd..5702f2bc1c8b4b42d7c29bc5d0a5551c1632659f 100644 (file)
@@ -205,6 +205,7 @@ from .types import CITEXT
 from ... import exc
 from ... import pool
 from ... import util
+from ...connectors.asyncio import AsyncAdapt_terminate
 from ...engine import AdaptedConnection
 from ...engine import processors
 from ...sql import sqltypes
@@ -695,7 +696,7 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
         )
 
 
-class AsyncAdapt_asyncpg_connection(AdaptedConnection):
+class AsyncAdapt_asyncpg_connection(AsyncAdapt_terminate, AdaptedConnection):
     __slots__ = (
         "dbapi",
         "isolation_level",
@@ -901,32 +902,18 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
 
         self.await_(self._connection.close())
 
-    def terminate(self):
-        if util.concurrency.in_greenlet():
-            # in a greenlet; this is the connection was invalidated
-            # case.
-            try:
-                # try to gracefully close; see #10717
-                # timeout added in asyncpg 0.14.0 December 2017
-                self.await_(asyncio.shield(self._connection.close(timeout=2)))
-            except (
-                asyncio.TimeoutError,
-                asyncio.CancelledError,
-                OSError,
-                self.dbapi.asyncpg.PostgresError,
-            ) as e:
-                # in the case where we are recycling an old connection
-                # that may have already been disconnected, close() will
-                # fail with the above timeout.  in this case, terminate
-                # the connection without any further waiting.
-                # see issue #8419
-                self._connection.terminate()
-                if isinstance(e, asyncio.CancelledError):
-                    # re-raise CancelledError if we were cancelled
-                    raise
-        else:
-            # not in a greenlet; this is the gc cleanup case
-            self._connection.terminate()
+    def _terminate_handled_exceptions(self):
+        return super()._terminate_handled_exceptions() + (
+            self.dbapi.asyncpg.PostgresError,
+        )
+
+    async def _terminate_graceful_close(self) -> None:
+        # timeout added in asyncpg 0.14.0 December 2017
+        await self._connection.close(timeout=2)
+        self._started = False
+
+    def _terminate_force_close(self) -> None:
+        self._connection.terminate()
         self._started = False
 
     @staticmethod