From dddfa96736dd905be59c8601ae3e09c8bc52299c Mon Sep 17 00:00:00 2001 From: Federico Caselli Date: Thu, 23 Jan 2025 22:42:14 +0100 Subject: [PATCH] Create terminate mixin Generalize the terminate logic employed by the asyncpg dialect to reuse it in the aiomysql and asyncmy dialect implementation. Fixes: #12273 Change-Id: Iddb658b7118de774f169e31e888a8aae1c7c6ec2 --- doc/build/changelog/unreleased_20/12273.rst | 6 +++ lib/sqlalchemy/connectors/asyncio.py | 42 ++++++++++++++++++ lib/sqlalchemy/dialects/mysql/aiomysql.py | 16 ++++--- lib/sqlalchemy/dialects/mysql/asyncmy.py | 16 ++++--- lib/sqlalchemy/dialects/postgresql/asyncpg.py | 44 +++++++------------ 5 files changed, 86 insertions(+), 38 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/12273.rst diff --git a/doc/build/changelog/unreleased_20/12273.rst b/doc/build/changelog/unreleased_20/12273.rst new file mode 100644 index 0000000000..754677afaa --- /dev/null +++ b/doc/build/changelog/unreleased_20/12273.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase, asyncio + :tickets: 12273 + + Generalize the terminate logic employed by the asyncpg dialect to reuse + it in the aiomysql and asyncmy dialect implementation. diff --git a/lib/sqlalchemy/connectors/asyncio.py b/lib/sqlalchemy/connectors/asyncio.py index 87548e510b..c29aa3f69d 100644 --- a/lib/sqlalchemy/connectors/asyncio.py +++ b/lib/sqlalchemy/connectors/asyncio.py @@ -20,6 +20,8 @@ from typing import NoReturn from typing import Optional from typing import Protocol from typing import Sequence +from typing import Tuple +from typing import Type from typing import TYPE_CHECKING from ..engine import AdaptedConnection @@ -374,3 +376,43 @@ class AsyncAdapt_dbapi_connection(AdaptedConnection): def close(self) -> None: await_(self._connection.close()) + + +class AsyncAdapt_terminate: + """Mixin for a AsyncAdapt_dbapi_connection to add terminate support.""" + + __slots__ = () + + def terminate(self) -> None: + if in_greenlet(): + # in a greenlet; this is the connection was invalidated case. + try: + # try to gracefully close; see #10717 + await_(asyncio.shield(self._terminate_graceful_close())) + except self._terminate_handled_exceptions() as e: + # in the case where we are recycling an old connection + # that may have already been disconnected, close() will + # fail. In this case, terminate + # the connection without any further waiting. + # see issue #8419 + self._terminate_force_close() + if isinstance(e, asyncio.CancelledError): + # re-raise CancelledError if we were cancelled + raise + else: + # not in a greenlet; this is the gc cleanup case + self._terminate_force_close() + + def _terminate_handled_exceptions(self) -> Tuple[Type[BaseException], ...]: + """Returns the exceptions that should be handled when + calling _graceful_close. + """ + return (asyncio.TimeoutError, asyncio.CancelledError, OSError) + + async def _terminate_graceful_close(self) -> None: + """Try to close connection gracefully""" + raise NotImplementedError + + def _terminate_force_close(self) -> None: + """Terminate the connection""" + raise NotImplementedError diff --git a/lib/sqlalchemy/dialects/mysql/aiomysql.py b/lib/sqlalchemy/dialects/mysql/aiomysql.py index b23dbcf8f6..9c043850e4 100644 --- a/lib/sqlalchemy/dialects/mysql/aiomysql.py +++ b/lib/sqlalchemy/dialects/mysql/aiomysql.py @@ -41,6 +41,7 @@ from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor from ...connectors.asyncio import AsyncAdapt_dbapi_module from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...connectors.asyncio import AsyncAdapt_terminate from ...util.concurrency import await_ if TYPE_CHECKING: @@ -77,7 +78,9 @@ class AsyncAdapt_aiomysql_ss_cursor( ) -class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection): +class AsyncAdapt_aiomysql_connection( + AsyncAdapt_terminate, AsyncAdapt_dbapi_connection +): __slots__ = () _cursor_cls = AsyncAdapt_aiomysql_cursor @@ -96,13 +99,16 @@ class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection): def get_autocommit(self) -> bool: return self._connection.get_autocommit() # type: ignore - def terminate(self) -> None: - # it's not awaitable. - self._connection.close() - def close(self) -> None: await_(self._connection.ensure_closed()) + async def _terminate_graceful_close(self) -> None: + await self._connection.ensure_closed() + + def _terminate_force_close(self) -> None: + # it's not awaitable. + self._connection.close() + class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module): def __init__(self, aiomysql: ModuleType, pymysql: ModuleType): diff --git a/lib/sqlalchemy/dialects/mysql/asyncmy.py b/lib/sqlalchemy/dialects/mysql/asyncmy.py index b183f11c55..22a60a099a 100644 --- a/lib/sqlalchemy/dialects/mysql/asyncmy.py +++ b/lib/sqlalchemy/dialects/mysql/asyncmy.py @@ -41,6 +41,7 @@ from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor from ...connectors.asyncio import AsyncAdapt_dbapi_module from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...connectors.asyncio import AsyncAdapt_terminate from ...util.concurrency import await_ if TYPE_CHECKING: @@ -72,7 +73,9 @@ class AsyncAdapt_asyncmy_ss_cursor( ) -class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection): +class AsyncAdapt_asyncmy_connection( + AsyncAdapt_terminate, AsyncAdapt_dbapi_connection +): __slots__ = () _cursor_cls = AsyncAdapt_asyncmy_cursor @@ -106,13 +109,16 @@ class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection): def get_autocommit(self) -> bool: return self._connection.get_autocommit() # type: ignore - def terminate(self) -> None: - # it's not awaitable. - self._connection.close() - def close(self) -> None: await_(self._connection.ensure_closed()) + async def _terminate_graceful_close(self) -> None: + await self._connection.ensure_closed() + + def _terminate_force_close(self) -> None: + # it's not awaitable. + self._connection.close() + class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module): def __init__(self, asyncmy: ModuleType): diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 09ff9f48c0..7409a2023e 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -178,7 +178,6 @@ client using this setting passed to :func:`_asyncio.create_async_engine`:: from __future__ import annotations -import asyncio from collections import deque import decimal import json as _py_json @@ -218,6 +217,7 @@ from ... import util from ...connectors.asyncio import AsyncAdapt_dbapi_connection from ...connectors.asyncio import AsyncAdapt_dbapi_cursor from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...connectors.asyncio import AsyncAdapt_terminate from ...engine import processors from ...sql import sqltypes from ...util.concurrency import await_ @@ -751,7 +751,9 @@ class AsyncAdapt_asyncpg_ss_cursor( ) -class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): +class AsyncAdapt_asyncpg_connection( + AsyncAdapt_terminate, AsyncAdapt_dbapi_connection +): _cursor_cls = AsyncAdapt_asyncpg_cursor _ss_cursor_cls = AsyncAdapt_asyncpg_ss_cursor @@ -932,32 +934,18 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): await_(self._connection.close()) - def terminate(self): - if util.concurrency.in_greenlet(): - # in a greenlet; this is the connection was invalidated - # case. - try: - # try to gracefully close; see #10717 - # timeout added in asyncpg 0.14.0 December 2017 - await_(asyncio.shield(self._connection.close(timeout=2))) - except ( - asyncio.TimeoutError, - asyncio.CancelledError, - OSError, - self.dbapi.asyncpg.PostgresError, - ) as e: - # in the case where we are recycling an old connection - # that may have already been disconnected, close() will - # fail with the above timeout. in this case, terminate - # the connection without any further waiting. - # see issue #8419 - self._connection.terminate() - if isinstance(e, asyncio.CancelledError): - # re-raise CancelledError if we were cancelled - raise - else: - # not in a greenlet; this is the gc cleanup case - self._connection.terminate() + def _terminate_handled_exceptions(self): + return super()._terminate_handled_exceptions() + ( + self.dbapi.asyncpg.PostgresError, + ) + + async def _terminate_graceful_close(self) -> None: + # timeout added in asyncpg 0.14.0 December 2017 + await self._connection.close(timeout=2) + self._transaction = None + + def _terminate_force_close(self) -> None: + self._connection.terminate() self._transaction = None @staticmethod -- 2.47.3