From: Federico Caselli Date: Mon, 24 Mar 2025 21:44:05 +0000 (+0100) Subject: cleanup implementation of asyncpg connection X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=a5f4ea7db112c244c6280475f9037abdfbdb93cf;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git cleanup implementation of asyncpg connection Change-Id: Ic7ecbeb6341145544b2a501d287e3e1d30fb1cbc --- diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index a5a5dc08fa..fb35595016 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -185,6 +185,8 @@ import json as _py_json import re import time from typing import Any +from typing import Awaitable +from typing import Callable from typing import NoReturn from typing import Optional from typing import Protocol @@ -510,6 +512,12 @@ class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer): pass +class _AsyncpgTransaction(Protocol): + async def start(self) -> None: ... + async def commit(self) -> None: ... + async def rollback(self) -> None: ... + + class _AsyncpgConnection(Protocol): async def executemany( self, operation: Any, seq_of_parameters: Sequence[Tuple[Any, ...]] @@ -529,11 +537,11 @@ class _AsyncpgConnection(Protocol): isolation: Optional[str] = None, readonly: bool = False, deferrable: bool = False, - ) -> Any: ... + ) -> _AsyncpgTransaction: ... def fetchrow(self, operation: str) -> Any: ... - async def close(self) -> None: ... + async def close(self, timeout: int = ...) -> None: ... def terminate(self) -> None: ... @@ -571,7 +579,7 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor): adapt_connection = self._adapt_connection async with adapt_connection._execute_mutex: - if not adapt_connection._started: + if adapt_connection._transaction is None: await adapt_connection._start_transaction() if parameters is None: @@ -642,7 +650,7 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor): self._invalidate_schema_cache_asof ) - if not adapt_connection._started: + if adapt_connection._transaction is None: await adapt_connection._start_transaction() try: @@ -747,6 +755,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): _ss_cursor_cls = AsyncAdapt_asyncpg_ss_cursor _connection: _AsyncpgConnection + _transaction: Optional[_AsyncpgTransaction] __slots__ = ( "isolation_level", @@ -754,7 +763,6 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): "readonly", "deferrable", "_transaction", - "_started", "_prepared_statement_cache", "_prepared_statement_name_func", "_invalidate_schema_cache_asof", @@ -772,7 +780,6 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): self.readonly = False self.deferrable = False self._transaction = None - self._started = False self._invalidate_schema_cache_asof = time.time() if prepared_statement_cache_size: @@ -826,7 +833,6 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): def _handle_exception(self, error: Exception) -> NoReturn: if self._connection.is_closed(): self._transaction = None - self._started = False if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error): exception_mapping = self.dbapi._asyncpg_error_translate @@ -876,14 +882,14 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): await self._connection.fetchrow(";") def set_isolation_level(self, level): - if self._started: - self.rollback() + self.rollback() self.isolation_level = self._isolation_setting = level async def _start_transaction(self): if self.isolation_level == "autocommit": return + assert self._transaction is None try: self._transaction = self._connection.transaction( isolation=self.isolation_level, @@ -893,46 +899,28 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): await self._transaction.start() except Exception as error: self._handle_exception(error) - else: - self._started = True - async def _rollback_and_discard(self): + async def _call_and_discard(self, fn: Callable[[], Awaitable[Any]]): try: - await self._transaction.rollback() + await fn() finally: - # if asyncpg .rollback() was actually called, then whether or - # not it raised or succeeded, the transation is done, discard it + # if asyncpg fn was actually called, then whether or + # not it raised or succeeded, the transaction is done, discard it self._transaction = None - self._started = False - - async def _commit_and_discard(self): - try: - await self._transaction.commit() - finally: - # if asyncpg .commit() was actually called, then whether or - # not it raised or succeeded, the transation is done, discard it - self._transaction = None - self._started = False def rollback(self): - if self._started: - assert self._transaction is not None + if self._transaction is not None: try: - await_(self._rollback_and_discard()) - self._transaction = None - self._started = False + await_(self._call_and_discard(self._transaction.rollback)) except Exception as error: # don't dereference asyncpg transaction if we didn't # actually try to call rollback() on it self._handle_exception(error) def commit(self): - if self._started: - assert self._transaction is not None + if self._transaction is not None: try: - await_(self._commit_and_discard()) - self._transaction = None - self._started = False + await_(self._call_and_discard(self._transaction.commit)) except Exception as error: # don't dereference asyncpg transaction if we didn't # actually try to call commit() on it @@ -969,7 +957,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): else: # not in a greenlet; this is the gc cleanup case self._connection.terminate() - self._started = False + self._transaction = None @staticmethod def _default_name_func():