]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
cleanup implementation of asyncpg connection
authorFederico Caselli <cfederico87@gmail.com>
Mon, 24 Mar 2025 21:44:05 +0000 (22:44 +0100)
committerFederico Caselli <cfederico87@gmail.com>
Sat, 26 Jul 2025 16:11:38 +0000 (16:11 +0000)
Change-Id: Ic7ecbeb6341145544b2a501d287e3e1d30fb1cbc

lib/sqlalchemy/dialects/postgresql/asyncpg.py

index a5a5dc08fa3ddb572c01052aa48ff25a3b821656..fb35595016cb5516865b5dc049f2f46f7d676936 100644 (file)
@@ -185,6 +185,8 @@ import json as _py_json
 import re
 import time
 from typing import Any
+from typing import Awaitable
+from typing import Callable
 from typing import NoReturn
 from typing import Optional
 from typing import Protocol
@@ -510,6 +512,12 @@ class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
     pass
 
 
+class _AsyncpgTransaction(Protocol):
+    async def start(self) -> None: ...
+    async def commit(self) -> None: ...
+    async def rollback(self) -> None: ...
+
+
 class _AsyncpgConnection(Protocol):
     async def executemany(
         self, operation: Any, seq_of_parameters: Sequence[Tuple[Any, ...]]
@@ -529,11 +537,11 @@ class _AsyncpgConnection(Protocol):
         isolation: Optional[str] = None,
         readonly: bool = False,
         deferrable: bool = False,
-    ) -> Any: ...
+    ) -> _AsyncpgTransaction: ...
 
     def fetchrow(self, operation: str) -> Any: ...
 
-    async def close(self) -> None: ...
+    async def close(self, timeout: int = ...) -> None: ...
 
     def terminate(self) -> None: ...
 
@@ -571,7 +579,7 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor):
         adapt_connection = self._adapt_connection
 
         async with adapt_connection._execute_mutex:
-            if not adapt_connection._started:
+            if adapt_connection._transaction is None:
                 await adapt_connection._start_transaction()
 
             if parameters is None:
@@ -642,7 +650,7 @@ class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor):
                 self._invalidate_schema_cache_asof
             )
 
-            if not adapt_connection._started:
+            if adapt_connection._transaction is None:
                 await adapt_connection._start_transaction()
 
             try:
@@ -747,6 +755,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
     _ss_cursor_cls = AsyncAdapt_asyncpg_ss_cursor
 
     _connection: _AsyncpgConnection
+    _transaction: Optional[_AsyncpgTransaction]
 
     __slots__ = (
         "isolation_level",
@@ -754,7 +763,6 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
         "readonly",
         "deferrable",
         "_transaction",
-        "_started",
         "_prepared_statement_cache",
         "_prepared_statement_name_func",
         "_invalidate_schema_cache_asof",
@@ -772,7 +780,6 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
         self.readonly = False
         self.deferrable = False
         self._transaction = None
-        self._started = False
         self._invalidate_schema_cache_asof = time.time()
 
         if prepared_statement_cache_size:
@@ -826,7 +833,6 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
     def _handle_exception(self, error: Exception) -> NoReturn:
         if self._connection.is_closed():
             self._transaction = None
-            self._started = False
 
         if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
             exception_mapping = self.dbapi._asyncpg_error_translate
@@ -876,14 +882,14 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
             await self._connection.fetchrow(";")
 
     def set_isolation_level(self, level):
-        if self._started:
-            self.rollback()
+        self.rollback()
         self.isolation_level = self._isolation_setting = level
 
     async def _start_transaction(self):
         if self.isolation_level == "autocommit":
             return
 
+        assert self._transaction is None
         try:
             self._transaction = self._connection.transaction(
                 isolation=self.isolation_level,
@@ -893,46 +899,28 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
             await self._transaction.start()
         except Exception as error:
             self._handle_exception(error)
-        else:
-            self._started = True
 
-    async def _rollback_and_discard(self):
+    async def _call_and_discard(self, fn: Callable[[], Awaitable[Any]]):
         try:
-            await self._transaction.rollback()
+            await fn()
         finally:
-            # if asyncpg .rollback() was actually called, then whether or
-            # not it raised or succeeded, the transation is done, discard it
+            # if asyncpg fn was actually called, then whether or
+            # not it raised or succeeded, the transaction is done, discard it
             self._transaction = None
-            self._started = False
-
-    async def _commit_and_discard(self):
-        try:
-            await self._transaction.commit()
-        finally:
-            # if asyncpg .commit() was actually called, then whether or
-            # not it raised or succeeded, the transation is done, discard it
-            self._transaction = None
-            self._started = False
 
     def rollback(self):
-        if self._started:
-            assert self._transaction is not None
+        if self._transaction is not None:
             try:
-                await_(self._rollback_and_discard())
-                self._transaction = None
-                self._started = False
+                await_(self._call_and_discard(self._transaction.rollback))
             except Exception as error:
                 # don't dereference asyncpg transaction if we didn't
                 # actually try to call rollback() on it
                 self._handle_exception(error)
 
     def commit(self):
-        if self._started:
-            assert self._transaction is not None
+        if self._transaction is not None:
             try:
-                await_(self._commit_and_discard())
-                self._transaction = None
-                self._started = False
+                await_(self._call_and_discard(self._transaction.commit))
             except Exception as error:
                 # don't dereference asyncpg transaction if we didn't
                 # actually try to call commit() on it
@@ -969,7 +957,7 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection):
         else:
             # not in a greenlet; this is the gc cleanup case
             self._connection.terminate()
-        self._started = False
+        self._transaction = None
 
     @staticmethod
     def _default_name_func():