import re
import time
from typing import Any
+from typing import Awaitable
+from typing import Callable
from typing import NoReturn
from typing import Optional
from typing import Protocol
pass
+class _AsyncpgTransaction(Protocol):
+ async def start(self) -> None: ...
+ async def commit(self) -> None: ...
+ async def rollback(self) -> None: ...
+
+
class _AsyncpgConnection(Protocol):
async def executemany(
self, operation: Any, seq_of_parameters: Sequence[Tuple[Any, ...]]
isolation: Optional[str] = None,
readonly: bool = False,
deferrable: bool = False,
- ) -> Any: ...
+ ) -> _AsyncpgTransaction: ...
def fetchrow(self, operation: str) -> Any: ...
- async def close(self) -> None: ...
+ async def close(self, timeout: int = ...) -> None: ...
def terminate(self) -> None: ...
adapt_connection = self._adapt_connection
async with adapt_connection._execute_mutex:
- if not adapt_connection._started:
+ if adapt_connection._transaction is None:
await adapt_connection._start_transaction()
if parameters is None:
self._invalidate_schema_cache_asof
)
- if not adapt_connection._started:
+ if adapt_connection._transaction is None:
await adapt_connection._start_transaction()
try:
_ss_cursor_cls = AsyncAdapt_asyncpg_ss_cursor
_connection: _AsyncpgConnection
+ _transaction: Optional[_AsyncpgTransaction]
__slots__ = (
"isolation_level",
"readonly",
"deferrable",
"_transaction",
- "_started",
"_prepared_statement_cache",
"_prepared_statement_name_func",
"_invalidate_schema_cache_asof",
self.readonly = False
self.deferrable = False
self._transaction = None
- self._started = False
self._invalidate_schema_cache_asof = time.time()
if prepared_statement_cache_size:
def _handle_exception(self, error: Exception) -> NoReturn:
if self._connection.is_closed():
self._transaction = None
- self._started = False
if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
exception_mapping = self.dbapi._asyncpg_error_translate
await self._connection.fetchrow(";")
def set_isolation_level(self, level):
- if self._started:
- self.rollback()
+ self.rollback()
self.isolation_level = self._isolation_setting = level
async def _start_transaction(self):
if self.isolation_level == "autocommit":
return
+ assert self._transaction is None
try:
self._transaction = self._connection.transaction(
isolation=self.isolation_level,
await self._transaction.start()
except Exception as error:
self._handle_exception(error)
- else:
- self._started = True
- async def _rollback_and_discard(self):
+ async def _call_and_discard(self, fn: Callable[[], Awaitable[Any]]):
try:
- await self._transaction.rollback()
+ await fn()
finally:
- # if asyncpg .rollback() was actually called, then whether or
- # not it raised or succeeded, the transation is done, discard it
+ # if asyncpg fn was actually called, then whether or
+ # not it raised or succeeded, the transaction is done, discard it
self._transaction = None
- self._started = False
-
- async def _commit_and_discard(self):
- try:
- await self._transaction.commit()
- finally:
- # if asyncpg .commit() was actually called, then whether or
- # not it raised or succeeded, the transation is done, discard it
- self._transaction = None
- self._started = False
def rollback(self):
- if self._started:
- assert self._transaction is not None
+ if self._transaction is not None:
try:
- await_(self._rollback_and_discard())
- self._transaction = None
- self._started = False
+ await_(self._call_and_discard(self._transaction.rollback))
except Exception as error:
# don't dereference asyncpg transaction if we didn't
# actually try to call rollback() on it
self._handle_exception(error)
def commit(self):
- if self._started:
- assert self._transaction is not None
+ if self._transaction is not None:
try:
- await_(self._commit_and_discard())
- self._transaction = None
- self._started = False
+ await_(self._call_and_discard(self._transaction.commit))
except Exception as error:
# don't dereference asyncpg transaction if we didn't
# actually try to call commit() on it
else:
# not in a greenlet; this is the gc cleanup case
self._connection.terminate()
- self._started = False
+ self._transaction = None
@staticmethod
def _default_name_func():