From 70fcb4faa02af840e068110eb791f52444652514 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 15 Nov 2020 21:23:43 +0000 Subject: [PATCH] Added AsyncConnection.transaction() --- psycopg3/psycopg3/connection.py | 60 ++- psycopg3/psycopg3/transaction.py | 230 +++++++---- tests/test_transaction.py | 288 +++++++------- tests/test_transaction_async.py | 628 +++++++++++++++++++++++++++++++ 4 files changed, 993 insertions(+), 213 deletions(-) create mode 100644 tests/test_transaction_async.py diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 45d786d63..016af7f01 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -4,8 +4,9 @@ psycopg3 connection objects # Copyright (C) 2020 The Psycopg Team -import logging +import sys import asyncio +import logging import threading from types import TracebackType from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple @@ -14,6 +15,11 @@ from weakref import ref, ReferenceType from functools import partial from contextlib import contextmanager +if sys.version_info >= (3, 7): + from contextlib import asynccontextmanager +else: + from .utils.context import asynccontextmanager + from . import pq from . import cursor from . import errors as e @@ -24,7 +30,7 @@ from .proto import DumpersMap, LoadersMap, PQGen, RV, Query from .waiting import wait, wait_async from .conninfo import make_conninfo from .generators import notifies -from .transaction import Transaction +from .transaction import Transaction, AsyncTransaction logger = logging.getLogger(__name__) package_logger = logging.getLogger("psycopg3") @@ -335,7 +341,7 @@ class Connection(BaseConnection): savepoint_name: Optional[str] = None, force_rollback: bool = False, ) -> Iterator[Transaction]: - with Transaction(self, savepoint_name or "", force_rollback) as tx: + with Transaction(self, savepoint_name, force_rollback) as tx: yield tx @classmethod @@ -435,36 +441,58 @@ class AsyncConnection(BaseConnection): if self.pgconn.transaction_status != TransactionStatus.IDLE: return - self.pgconn.send_query(b"begin") - (pgres,) = await self.wait(execute(self.pgconn)) - if pgres.status != ExecStatus.COMMAND_OK: - raise e.OperationalError( - "error on begin:" - f" {pq.error_message(pgres, encoding=self.client_encoding)}" - ) + await self._exec_command(b"begin") async def commit(self) -> None: async with self.lock: if self.pgconn.transaction_status == TransactionStatus.IDLE: return - await self._exec(b"commit") + if self._savepoints is not None: + raise e.ProgrammingError( + "Explicit commit() forbidden within a Transaction " + "context. (Transaction will be automatically committed " + "on successful exit from context.)" + ) + await self._exec_command(b"commit") async def rollback(self) -> None: async with self.lock: if self.pgconn.transaction_status == TransactionStatus.IDLE: return - await self._exec(b"rollback") + if self._savepoints is not None: + raise e.ProgrammingError( + "Explicit rollback() forbidden within a Transaction " + "context. (Either raise Transaction.Rollback() or allow " + "an exception to propagate out of the context.)" + ) + await self._exec_command(b"rollback") - async def _exec(self, command: bytes) -> None: + async def _exec_command(self, command: Query) -> None: # Caller must hold self.lock + + if isinstance(command, str): + command = command.encode(self.client_encoding) + elif isinstance(command, Composable): + command = command.as_string(self).encode(self.client_encoding) + self.pgconn.send_query(command) - (pgres,) = await self.wait(execute(self.pgconn)) - if pgres.status != ExecStatus.COMMAND_OK: + results = await self.wait(execute(self.pgconn)) + if results[-1].status != ExecStatus.COMMAND_OK: raise e.OperationalError( f"error on {command.decode('utf8')}:" - f" {pq.error_message(pgres, encoding=self.client_encoding)}" + f" {pq.error_message(results[-1], encoding=self.client_encoding)}" ) + @asynccontextmanager + async def transaction( + self, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ) -> AsyncIterator[AsyncTransaction]: + tx = AsyncTransaction(self, savepoint_name, force_rollback) + async with tx: + yield tx + @classmethod async def wait(cls, gen: PQGen[RV]) -> RV: return await wait_async(gen) diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index 0f0bbfaa7..22364d7a0 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -7,14 +7,15 @@ Transaction context managers returned by Connection.transaction() import logging from types import TracebackType -from typing import Optional, Type, TYPE_CHECKING +from typing import Generic, Optional, Type, Union, TYPE_CHECKING from . import sql from .pq import TransactionStatus -from psycopg3.errors import ProgrammingError +from .proto import ConnectionType +from .errors import ProgrammingError if TYPE_CHECKING: - from .connection import Connection + from .connection import Connection, AsyncConnection # noqa: F401 _log = logging.getLogger(__name__) @@ -28,34 +29,59 @@ class Rollback(Exception): enclosing transactions contexts up to and including the one specified. """ - def __init__(self, transaction: Optional["Transaction"] = None): + def __init__( + self, + transaction: Union["Transaction", "AsyncTransaction", None] = None, + ): self.transaction = transaction def __repr__(self) -> str: return f"{self.__class__.__qualname__}({self.transaction!r})" -class Transaction: +class BaseTransaction(Generic[ConnectionType]): def __init__( self, - connection: "Connection", - savepoint_name: str = "", + connection: ConnectionType, + savepoint_name: Optional[str] = None, force_rollback: bool = False, ): self._conn = connection - self._savepoint_name = savepoint_name + self._savepoint_name = savepoint_name or "" self.force_rollback = force_rollback self._outer_transaction: Optional[bool] = None @property - def connection(self) -> "Connection": + def connection(self) -> ConnectionType: return self._conn @property def savepoint_name(self) -> str: return self._savepoint_name + def __repr__(self) -> str: + args = [f"connection={self.connection}"] + if not self.savepoint_name: + args.append(f"savepoint_name={self.savepoint_name!r}") + if self.force_rollback: + args.append("force_rollback=True") + return f"{self.__class__.__qualname__}({', '.join(args)})" + + _out_of_order_err = ProgrammingError( + "Out-of-order Transaction context exits. Are you " + "calling __exit__() manually and getting it wrong?" + ) + + def _pop_savepoint(self) -> None: + if self._conn._savepoints is None: + raise self._out_of_order_err + actual = self._conn._savepoints.pop() + if actual != self._savepoint_name: + raise self._out_of_order_err + + +class Transaction(BaseTransaction["Connection"]): def __enter__(self) -> "Transaction": with self._conn.lock: if self._conn.pgconn.transaction_status == TransactionStatus.IDLE: @@ -87,69 +113,137 @@ class Transaction: exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> bool: - out_of_order_err = ProgrammingError( - "Out-of-order Transaction context exits. Are you " - "calling __exit__() manually and getting it wrong?" - ) if self._outer_transaction is None: - raise out_of_order_err + raise self._out_of_order_err with self._conn.lock: - if exc_type is None and not self.force_rollback: - # Commit changes made in the transaction context - if self._savepoint_name: - if self._conn._savepoints is None: - raise out_of_order_err - actual = self._conn._savepoints.pop() - if actual != self._savepoint_name: - raise out_of_order_err - self._conn._exec_command( - sql.SQL("release savepoint {}").format( - sql.Identifier(self._savepoint_name) - ) - ) - if self._outer_transaction: - if self._conn._savepoints is None: - raise out_of_order_err - if len(self._conn._savepoints) != 0: - raise out_of_order_err - self._conn._exec_command(b"commit") - self._conn._savepoints = None + if not exc_val and not self.force_rollback: + return self._commit() + else: + return self._rollback(exc_val) + + def _commit(self) -> bool: + """Commit changes made in the transaction context.""" + if self._savepoint_name: + self._pop_savepoint() + self._conn._exec_command( + sql.SQL("release savepoint {}").format( + sql.Identifier(self._savepoint_name) + ) + ) + if self._outer_transaction: + if self._conn._savepoints is None or self._conn._savepoints: + raise self._out_of_order_err + self._conn._exec_command(b"commit") + self._conn._savepoints = None + + return False # discarded + + def _rollback(self, exc_val: Optional[BaseException]) -> bool: + # Rollback changes made in the transaction context + if isinstance(exc_val, Rollback): + _log.debug( + f"{self._conn}: Explicit rollback from: ", exc_info=True + ) + + if self._savepoint_name: + self._pop_savepoint() + self._conn._exec_command( + sql.SQL( + "rollback to savepoint {n}; release savepoint {n}" + ).format(n=sql.Identifier(self._savepoint_name)) + ) + if self._outer_transaction: + if self._conn._savepoints is None or self._conn._savepoints: + raise self._out_of_order_err + self._conn._exec_command(b"rollback") + self._conn._savepoints = None + + if isinstance(exc_val, Rollback): + if exc_val.transaction in (self, None): + return True # Swallow the exception + + return False + + +class AsyncTransaction(BaseTransaction["AsyncConnection"]): + async def __aenter__(self) -> "AsyncTransaction": + async with self._conn.lock: + if self._conn.pgconn.transaction_status == TransactionStatus.IDLE: + assert self._conn._savepoints is None, self._conn._savepoints + self._conn._savepoints = [] + self._outer_transaction = True + await self._conn._exec_command(b"begin") else: - # Rollback changes made in the transaction context - if isinstance(exc_val, Rollback): - _log.debug( - f"{self._conn}: Explicit rollback from: ", - exc_info=True, + if self._conn._savepoints is None: + self._conn._savepoints = [] + self._outer_transaction = False + if not self._savepoint_name: + self._savepoint_name = ( + f"s{len(self._conn._savepoints) + 1}" ) - if self._savepoint_name: - if self._conn._savepoints is None: - raise out_of_order_err - actual = self._conn._savepoints.pop() - if actual != self._savepoint_name: - raise out_of_order_err - self._conn._exec_command( - sql.SQL( - "rollback to savepoint {n}; release savepoint {n}" - ).format(n=sql.Identifier(self._savepoint_name)) + if self._savepoint_name: + await self._conn._exec_command( + sql.SQL("savepoint {}").format( + sql.Identifier(self._savepoint_name) ) - if self._outer_transaction: - if self._conn._savepoints is None: - raise out_of_order_err - if len(self._conn._savepoints) != 0: - raise out_of_order_err - self._conn._exec_command(b"rollback") - self._conn._savepoints = None - - if isinstance(exc_val, Rollback): - if exc_val.transaction in (self, None): - return True # Swallow the exception - return False + ) + self._conn._savepoints.append(self._savepoint_name) + return self - def __repr__(self) -> str: - args = [f"connection={self.connection}"] - if not self.savepoint_name: - args.append(f"savepoint_name={self.savepoint_name!r}") - if self.force_rollback: - args.append("force_rollback=True") - return f"{self.__class__.__qualname__}({', '.join(args)})" + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + if self._outer_transaction is None: + raise self._out_of_order_err + async with self._conn.lock: + if not exc_val and not self.force_rollback: + return await self._commit() + else: + return await self._rollback(exc_val) + + async def _commit(self) -> bool: + """Commit changes made in the transaction context.""" + if self._savepoint_name: + self._pop_savepoint() + await self._conn._exec_command( + sql.SQL("release savepoint {}").format( + sql.Identifier(self._savepoint_name) + ) + ) + if self._outer_transaction: + if self._conn._savepoints is None or self._conn._savepoints: + raise self._out_of_order_err + await self._conn._exec_command(b"commit") + self._conn._savepoints = None + + return False # discarded + + async def _rollback(self, exc_val: Optional[BaseException]) -> bool: + # Rollback changes made in the transaction context + if isinstance(exc_val, Rollback): + _log.debug( + f"{self._conn}: Explicit rollback from: ", exc_info=True + ) + + if self._savepoint_name: + self._pop_savepoint() + await self._conn._exec_command( + sql.SQL( + "rollback to savepoint {n}; release savepoint {n}" + ).format(n=sql.Identifier(self._savepoint_name)) + ) + if self._outer_transaction: + if self._conn._savepoints is None or self._conn._savepoints: + raise self._out_of_order_err + await self._conn._exec_command(b"rollback") + self._conn._savepoints = None + + if isinstance(exc_val, Rollback): + if exc_val.transaction in (self, None): + return True # Swallow the exception + + return False diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 026a7fa6e..04e6f8143 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,15 +1,14 @@ import sys -from contextlib import contextmanager import pytest -from psycopg3 import ProgrammingError, Rollback +from psycopg3 import Connection, ProgrammingError, Rollback from psycopg3.sql import Composable from psycopg3.transaction import Transaction @pytest.fixture(autouse=True) -def test_table(svcconn): +def create_test_table(svcconn): """ Creates a table called 'test_table' for use in tests. """ @@ -24,40 +23,52 @@ def insert_row(conn, value): conn.cursor().execute("INSERT INTO test_table VALUES (%s)", (value,)) -def assert_rows(conn, expected): - rows = conn.cursor().execute("SELECT * FROM test_table").fetchall() - assert set(v for (v,) in rows) == expected +def inserted(conn): + sql = "SELECT * FROM test_table" + if isinstance(conn, Connection): + rows = conn.cursor().execute(sql).fetchall() + return set(v for (v,) in rows) + else: + async def f(): + cur = await conn.cursor() + await cur.execute(sql) + rows = await cur.fetchall() + return set(v for (v,) in rows) -def assert_not_in_transaction(conn): - assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + return f() -def assert_in_transaction(conn): - assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS +@pytest.fixture +def commands(monkeypatch): + """The queue of commands issued internally by connections. + Not concurrency safe as it mokkeypatches a class method, but good enough + for tests. + """ + _orig_exec_command = Connection._exec_command + L = [] -@contextmanager -def assert_commands_issued(conn, *commands): - commands_actual = [] - real_exec_command = conn._exec_command - - def _exec_command(command): + def _exec_command(self, command): if isinstance(command, bytes): - command = command.decode(conn.client_encoding) + command = command.decode(self.client_encoding) elif isinstance(command, Composable): - command = command.as_string(conn) + command = command.as_string(self) - commands_actual.append(command) - real_exec_command(command) + L.insert(0, command) + _orig_exec_command(self, command) + + monkeypatch.setattr(Connection, "_exec_command", _exec_command) + yield L - try: - conn._exec_command = _exec_command - yield - finally: - conn._exec_command = real_exec_command - assert commands_actual == list(commands) +def in_transaction(conn): + if conn.pgconn.transaction_status == conn.TransactionStatus.IDLE: + return False + elif conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS: + return True + else: + assert False, conn.pgconn.transaction_status class ExpectedException(Exception): @@ -73,10 +84,10 @@ def some_exc_info(): def test_basic(conn): """Basic use of transaction() to BEGIN and COMMIT a transaction.""" - assert_not_in_transaction(conn) + assert not in_transaction(conn) with conn.transaction(): - assert_in_transaction(conn) - assert_not_in_transaction(conn) + assert in_transaction(conn) + assert not in_transaction(conn) def test_exposes_associated_connection(conn): @@ -98,10 +109,10 @@ def test_exposes_savepoint_name(conn): def test_begins_on_enter(conn): """Transaction does not begin until __enter__() is called.""" tx = conn.transaction() - assert_not_in_transaction(conn) + assert not in_transaction(conn) with tx: - assert_in_transaction(conn) - assert_not_in_transaction(conn) + assert in_transaction(conn) + assert not in_transaction(conn) def test_commit_on_successful_exit(conn): @@ -109,8 +120,8 @@ def test_commit_on_successful_exit(conn): with conn.transaction(): insert_row(conn, "foo") - assert_not_in_transaction(conn) - assert_rows(conn, {"foo"}) + assert not in_transaction(conn) + assert inserted(conn) == {"foo"} def test_rollback_on_exception_exit(conn): @@ -120,8 +131,8 @@ def test_rollback_on_exception_exit(conn): insert_row(conn, "foo") raise ExpectedException("This discards the insert") - assert_not_in_transaction(conn) - assert_rows(conn, set()) + assert not in_transaction(conn) + assert not inserted(conn) def test_prohibits_use_of_commit_rollback_autocommit(conn): @@ -169,14 +180,14 @@ def test_autocommit_off_but_no_tx_started_successful_exit(conn, svcconn): * Changes made within Transaction context are committed """ conn.autocommit = False - assert_not_in_transaction(conn) + assert not in_transaction(conn) with conn.transaction(): insert_row(conn, "new") - assert_not_in_transaction(conn) + assert not in_transaction(conn) # Changes committed - assert_rows(conn, {"new"}) - assert_rows(svcconn, {"new"}) + assert inserted(conn) == {"new"} + assert inserted(svcconn) == {"new"} def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn): @@ -190,16 +201,16 @@ def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn): * Changes made within Transaction context are discarded """ conn.autocommit = False - assert_not_in_transaction(conn) + assert not in_transaction(conn) with pytest.raises(ExpectedException): with conn.transaction(): insert_row(conn, "new") raise ExpectedException() - assert_not_in_transaction(conn) + assert not in_transaction(conn) # Changes discarded - assert_rows(conn, set()) - assert_rows(svcconn, set()) + assert not inserted(conn) + assert not inserted(svcconn) def test_autocommit_off_and_tx_in_progress_successful_exit(conn, svcconn): @@ -216,13 +227,13 @@ def test_autocommit_off_and_tx_in_progress_successful_exit(conn, svcconn): """ conn.autocommit = False insert_row(conn, "prior") - assert_in_transaction(conn) + assert in_transaction(conn) with conn.transaction(): insert_row(conn, "new") - assert_in_transaction(conn) - assert_rows(conn, {"prior", "new"}) + assert in_transaction(conn) + assert inserted(conn) == {"prior", "new"} # Nothing committed yet; changes not visible on another connection - assert_rows(svcconn, set()) + assert not inserted(svcconn) def test_autocommit_off_and_tx_in_progress_exception_exit(conn, svcconn): @@ -240,15 +251,15 @@ def test_autocommit_off_and_tx_in_progress_exception_exit(conn, svcconn): """ conn.autocommit = False insert_row(conn, "prior") - assert_in_transaction(conn) + assert in_transaction(conn) with pytest.raises(ExpectedException): with conn.transaction(): insert_row(conn, "new") raise ExpectedException() - assert_in_transaction(conn) - assert_rows(conn, {"prior"}) + assert in_transaction(conn) + assert inserted(conn) == {"prior"} # Nothing committed yet; changes not visible on another connection - assert_rows(svcconn, set()) + assert not inserted(svcconn) def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn): @@ -258,9 +269,9 @@ def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn): with conn.transaction(): insert_row(conn, "inner") insert_row(conn, "outer-after") - assert_not_in_transaction(conn) - assert_rows(conn, {"outer-before", "inner", "outer-after"}) - assert_rows(svcconn, {"outer-before", "inner", "outer-after"}) + assert not in_transaction(conn) + assert inserted(conn) == {"outer-before", "inner", "outer-after"} + assert inserted(svcconn) == {"outer-before", "inner", "outer-after"} def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn): @@ -274,9 +285,9 @@ def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn): with conn.transaction(): insert_row(conn, "inner") raise ExpectedException() - assert_not_in_transaction(conn) - assert_rows(conn, set()) - assert_rows(svcconn, set()) + assert not in_transaction(conn) + assert not inserted(conn) + assert not inserted(svcconn) def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn): @@ -290,9 +301,9 @@ def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn): with conn.transaction(): insert_row(conn, "inner") raise ExpectedException() - assert_not_in_transaction(conn) - assert_rows(conn, set()) - assert_rows(svcconn, set()) + assert not in_transaction(conn) + assert not inserted(conn) + assert not inserted(svcconn) def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn): @@ -309,9 +320,9 @@ def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn): insert_row(conn, "inner") raise ExpectedException() insert_row(conn, "outer-after") - assert_not_in_transaction(conn) - assert_rows(conn, {"outer-before", "outer-after"}) - assert_rows(svcconn, {"outer-before", "outer-after"}) + assert not in_transaction(conn) + assert inserted(conn) == {"outer-before", "outer-after"} + assert inserted(svcconn) == {"outer-before", "outer-after"} def test_nested_three_levels_successful_exit(conn, svcconn): @@ -322,9 +333,9 @@ def test_nested_three_levels_successful_exit(conn, svcconn): insert_row(conn, "two") with conn.transaction(): # SAVEPOINT s2 insert_row(conn, "three") - assert_not_in_transaction(conn) - assert_rows(conn, {"one", "two", "three"}) - assert_rows(svcconn, {"one", "two", "three"}) + assert not in_transaction(conn) + assert inserted(conn) == {"one", "two", "three"} + assert inserted(svcconn) == {"one", "two", "three"} def test_named_savepoint_escapes_savepoint_name(conn): @@ -334,7 +345,7 @@ def test_named_savepoint_escapes_savepoint_name(conn): pass -def test_named_savepoints_successful_exit(conn): +def test_named_savepoints_successful_exit(conn, commands): """ Entering a transaction context will do one of these these things: 1. Begin an outer transaction (if one isn't already in progress) @@ -347,40 +358,49 @@ def test_named_savepoints_successful_exit(conn): # Case 1 # Using Transaction explicitly becase conn.transaction() enters the contetx tx = Transaction(conn) - with assert_commands_issued(conn, "begin"): - tx.__enter__() + assert not commands + tx.__enter__() + assert commands.pop() == "begin" assert not tx.savepoint_name - with assert_commands_issued(conn, "commit"): - tx.__exit__(None, None, None) + tx.__exit__(None, None, None) + assert commands.pop() == "commit" # Case 2 tx = Transaction(conn, savepoint_name="foo") - with assert_commands_issued(conn, "begin", 'savepoint "foo"'): - tx.__enter__() + tx.__enter__() + assert commands.pop() == "begin" + assert commands.pop() == 'savepoint "foo"' assert tx.savepoint_name == "foo" - with assert_commands_issued(conn, 'release savepoint "foo"', "commit"): - tx.__exit__(None, None, None) + tx.__exit__(None, None, None) + assert commands.pop() == 'release savepoint "foo"' + assert commands.pop() == "commit" # Case 3 (with savepoint name provided) - with Transaction(conn): + with conn.transaction(): + assert commands.pop() == "begin" tx = Transaction(conn, savepoint_name="bar") - with assert_commands_issued(conn, 'savepoint "bar"'): - tx.__enter__() + tx.__enter__() + assert commands.pop() == 'savepoint "bar"' assert tx.savepoint_name == "bar" - with assert_commands_issued(conn, 'release savepoint "bar"'): - tx.__exit__(None, None, None) + tx.__exit__(None, None, None) + assert commands.pop() == 'release savepoint "bar"' + assert commands.pop() == "commit" # Case 3 (with savepoint name auto-generated) with conn.transaction(): + assert commands.pop() == "begin" tx = Transaction(conn) - with assert_commands_issued(conn, 'savepoint "s1"'): - tx.__enter__() + tx.__enter__() + assert commands.pop() == 'savepoint "s1"' assert tx.savepoint_name == "s1" - with assert_commands_issued(conn, 'release savepoint "s1"'): - tx.__exit__(None, None, None) + tx.__exit__(None, None, None) + assert commands.pop() == 'release savepoint "s1"' + assert commands.pop() == "commit" + + assert not commands -def test_named_savepoints_exception_exit(conn): +def test_named_savepoints_exception_exit(conn, commands): """ Same as the previous test but checks that when exiting the context with an exception, whatever transaction and/or savepoint was started on enter will @@ -388,45 +408,54 @@ def test_named_savepoints_exception_exit(conn): """ # Case 1 tx = Transaction(conn) - with assert_commands_issued(conn, "begin"): - tx.__enter__() + tx.__enter__() + assert commands.pop() == "begin" assert not tx.savepoint_name - with assert_commands_issued(conn, "rollback"): - tx.__exit__(*some_exc_info()) + tx.__exit__(*some_exc_info()) + assert commands.pop() == "rollback" # Case 2 tx = Transaction(conn, savepoint_name="foo") - with assert_commands_issued(conn, "begin", 'savepoint "foo"'): - tx.__enter__() + tx.__enter__() + assert commands.pop() == "begin" + assert commands.pop() == 'savepoint "foo"' assert tx.savepoint_name == "foo" - with assert_commands_issued( - conn, - 'rollback to savepoint "foo"; release savepoint "foo"', - "rollback", - ): - tx.__exit__(*some_exc_info()) + tx.__exit__(*some_exc_info()) + assert ( + commands.pop() + == 'rollback to savepoint "foo"; release savepoint "foo"' + ) + assert commands.pop() == "rollback" # Case 3 (with savepoint name provided) with conn.transaction(): + assert commands.pop() == "begin" tx = Transaction(conn, savepoint_name="bar") - with assert_commands_issued(conn, 'savepoint "bar"'): - tx.__enter__() + tx.__enter__() + assert commands.pop() == 'savepoint "bar"' assert tx.savepoint_name == "bar" - with assert_commands_issued( - conn, 'rollback to savepoint "bar"; release savepoint "bar"' - ): - tx.__exit__(*some_exc_info()) + tx.__exit__(*some_exc_info()) + assert ( + commands.pop() + == 'rollback to savepoint "bar"; release savepoint "bar"' + ) + assert commands.pop() == "commit" # Case 3 (with savepoint name auto-generated) with conn.transaction(): + assert commands.pop() == "begin" tx = Transaction(conn) - with assert_commands_issued(conn, 'savepoint "s1"'): - tx.__enter__() + tx.__enter__() + assert commands.pop() == 'savepoint "s1"' assert tx.savepoint_name == "s1" - with assert_commands_issued( - conn, 'rollback to savepoint "s1"; release savepoint "s1"' - ): - tx.__exit__(*some_exc_info()) + tx.__exit__(*some_exc_info()) + assert ( + commands.pop() + == 'rollback to savepoint "s1"; release savepoint "s1"' + ) + assert commands.pop() == "commit" + + assert not commands def test_named_savepoints_with_repeated_names_works(conn): @@ -442,7 +471,7 @@ def test_named_savepoints_with_repeated_names_works(conn): insert_row(conn, "tx2") with conn.transaction("sp"): insert_row(conn, "tx3") - assert_rows(conn, {"tx1", "tx2", "tx3"}) + assert inserted(conn) == {"tx1", "tx2", "tx3"} # Works correctly if one level of inner transaction is rolled back with conn.transaction(force_rollback=True): @@ -452,8 +481,8 @@ def test_named_savepoints_with_repeated_names_works(conn): insert_row(conn, "tx2") with conn.transaction("s1"): insert_row(conn, "tx3") - assert_rows(conn, {"tx1"}) - assert_rows(conn, {"tx1"}) + assert inserted(conn) == {"tx1"} + assert inserted(conn) == {"tx1"} # Works correctly if multiple inner transactions are rolled back # (This scenario mandates releasing savepoints after rolling back to them.) @@ -465,8 +494,8 @@ def test_named_savepoints_with_repeated_names_works(conn): with conn.transaction("s1"): insert_row(conn, "tx3") raise Rollback(tx2) - assert_rows(conn, {"tx1"}) - assert_rows(conn, {"tx1"}) + assert inserted(conn) == {"tx1"} + assert inserted(conn) == {"tx1"} # Will not (always) catch out-of-order exits with conn.transaction(force_rollback=True): @@ -485,8 +514,8 @@ def test_force_rollback_successful_exit(conn, svcconn): """ with conn.transaction(force_rollback=True): insert_row(conn, "foo") - assert_rows(conn, set()) - assert_rows(svcconn, set()) + assert not inserted(conn) + assert not inserted(svcconn) def test_force_rollback_exception_exit(conn, svcconn): @@ -498,8 +527,8 @@ def test_force_rollback_exception_exit(conn, svcconn): with conn.transaction(force_rollback=True): insert_row(conn, "foo") raise ExpectedException() - assert_rows(conn, set()) - assert_rows(svcconn, set()) + assert not inserted(conn) + assert not inserted(svcconn) def test_explicit_rollback_discards_changes(conn, svcconn): @@ -515,8 +544,8 @@ def test_explicit_rollback_discards_changes(conn, svcconn): """ def assert_no_rows(): - assert_rows(conn, set()) - assert_rows(svcconn, set()) + assert not inserted(conn) + assert not inserted(svcconn) with conn.transaction(): insert_row(conn, "foo") @@ -544,11 +573,11 @@ def test_explicit_rollback_outer_tx_unaffected(conn, svcconn): with conn.transaction(): insert_row(conn, "during") raise Rollback - assert_in_transaction(conn) - assert_rows(svcconn, set()) + assert in_transaction(conn) + assert not inserted(svcconn) insert_row(conn, "after") - assert_rows(conn, {"before", "after"}) - assert_rows(svcconn, {"before", "after"}) + assert inserted(conn) == {"before", "after"} + assert inserted(svcconn) == {"before", "after"} def test_explicit_rollback_of_outer_transaction(conn): @@ -562,7 +591,7 @@ def test_explicit_rollback_of_outer_transaction(conn): insert_row(conn, "inner") raise Rollback(outer_tx) assert False, "This line of code should be unreachable." - assert_rows(conn, set()) + assert not inserted(conn) def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn): @@ -578,9 +607,10 @@ def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn): raise Rollback(tx_enclosing) insert_row(conn, "outer-after") - assert_rows(conn, {"outer-before", "outer-after"}) - assert_rows(svcconn, set()) # Not yet committed - assert_rows(svcconn, {"outer-before", "outer-after"}) # Changes committed + assert inserted(conn) == {"outer-before", "outer-after"} + assert not inserted(svcconn) # Not yet committed + # Changes committed + assert inserted(svcconn) == {"outer-before", "outer-after"} @pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py new file mode 100644 index 000000000..ceac50eef --- /dev/null +++ b/tests/test_transaction_async.py @@ -0,0 +1,628 @@ +import pytest + +from psycopg3 import AsyncConnection, ProgrammingError, Rollback +from psycopg3.sql import Composable +from psycopg3.transaction import AsyncTransaction + +from .test_transaction import ( + in_transaction, + ExpectedException, + some_exc_info, + inserted, +) +from .test_transaction import create_test_table # noqa # autouse fixture + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +async def commands(monkeypatch): + """The queue of commands issued internally by connections. + + Not concurrency safe as it mokkeypatches a class method, but good enough + for tests. + """ + _orig_exec_command = AsyncConnection._exec_command + L = [] + + async def _exec_command(self, command): + if isinstance(command, bytes): + command = command.decode(self.client_encoding) + elif isinstance(command, Composable): + command = command.as_string(self) + + L.insert(0, command) + await _orig_exec_command(self, command) + + monkeypatch.setattr(AsyncConnection, "_exec_command", _exec_command) + yield L + + +async def insert_row(aconn, value): + await (await aconn.cursor()).execute( + "INSERT INTO test_table VALUES (%s)", (value,) + ) + + +async def test_basic(aconn): + """Basic use of transaction() to BEGIN and COMMIT a transaction.""" + assert not in_transaction(aconn) + async with aconn.transaction(): + assert in_transaction(aconn) + assert not in_transaction(aconn) + + +async def test_exposes_associated_connection(aconn): + """Transaction exposes its connection as a read-only property.""" + async with aconn.transaction() as tx: + assert tx.connection is aconn + with pytest.raises(AttributeError): + tx.connection = aconn + + +async def test_exposes_savepoint_name(aconn): + """Transaction exposes its savepoint name as a read-only property.""" + async with aconn.transaction(savepoint_name="foo") as tx: + assert tx.savepoint_name == "foo" + with pytest.raises(AttributeError): + tx.savepoint_name = "bar" + + +async def test_begins_on_enter(aconn): + """Transaction does not begin until __enter__() is called.""" + tx = aconn.transaction() + assert not in_transaction(aconn) + async with tx: + assert in_transaction(aconn) + assert not in_transaction(aconn) + + +async def test_commit_on_successful_exit(aconn): + """Changes are committed on successful exit from the `with` block.""" + async with aconn.transaction(): + await insert_row(aconn, "foo") + + assert not in_transaction(aconn) + assert await inserted(aconn) == {"foo"} + + +async def test_rollback_on_exception_exit(aconn): + """Changes are rolled back if an exception escapes the `with` block.""" + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "foo") + raise ExpectedException("This discards the insert") + + assert not in_transaction(aconn) + assert not await inserted(aconn) + + +async def test_prohibits_use_of_commit_rollback_autocommit(aconn): + """ + Within a Transaction block, it is forbidden to touch commit, rollback, + or the autocommit setting on the connection, as this would interfere + with the transaction scope being managed by the Transaction block. + """ + await aconn.set_autocommit(False) + await aconn.commit() + await aconn.rollback() + + async with aconn.transaction(): + with pytest.raises(ProgrammingError): + await aconn.set_autocommit(False) + with pytest.raises(ProgrammingError): + await aconn.commit() + with pytest.raises(ProgrammingError): + await aconn.rollback() + + await aconn.set_autocommit(False) + await aconn.commit() + await aconn.rollback() + + +@pytest.mark.parametrize("autocommit", [False, True]) +async def test_preserves_autocommit(aconn, autocommit): + """ + Connection.autocommit is unchanged both during and after Transaction block. + """ + await aconn.set_autocommit(autocommit) + async with aconn.transaction(): + assert aconn.autocommit is autocommit + assert aconn.autocommit is autocommit + + +async def test_autocommit_off_but_no_tx_started_successful_exit( + aconn, svcconn +): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are committed + """ + await aconn.set_autocommit(False) + assert not in_transaction(aconn) + async with aconn.transaction(): + await insert_row(aconn, "new") + assert not in_transaction(aconn) + + # Changes committed + assert await inserted(aconn) == {"new"} + assert inserted(svcconn) == {"new"} + + +async def test_autocommit_off_but_no_tx_started_exception_exit(aconn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made within Transaction context are discarded + """ + await aconn.set_autocommit(False) + assert not in_transaction(aconn) + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "new") + raise ExpectedException() + assert not in_transaction(aconn) + + # Changes discarded + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_autocommit_off_and_tx_in_progress_successful_exit( + aconn, svcconn +): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are left intact + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. + """ + await aconn.set_autocommit(False) + await insert_row(aconn, "prior") + assert in_transaction(aconn) + async with aconn.transaction(): + await insert_row(aconn, "new") + assert in_transaction(aconn) + assert await inserted(aconn) == {"prior", "new"} + # Nothing committed yet; changes not visible on another connection + assert not inserted(svcconn) + + +async def test_autocommit_off_and_tx_in_progress_exception_exit( + aconn, svcconn +): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made before the Transaction context are left intact + * Changes made within Transaction context are discarded + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. + """ + await aconn.set_autocommit(False) + await insert_row(aconn, "prior") + assert in_transaction(aconn) + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "new") + raise ExpectedException() + assert in_transaction(aconn) + assert await inserted(aconn) == {"prior"} + # Nothing committed yet; changes not visible on another connection + assert not inserted(svcconn) + + +async def test_nested_all_changes_persisted_on_successful_exit(aconn, svcconn): + """Changes from nested transaction contexts are all persisted on exit.""" + async with aconn.transaction(): + await insert_row(aconn, "outer-before") + async with aconn.transaction(): + await insert_row(aconn, "inner") + await insert_row(aconn, "outer-after") + assert not in_transaction(aconn) + assert await inserted(aconn) == {"outer-before", "inner", "outer-after"} + assert inserted(svcconn) == {"outer-before", "inner", "outer-after"} + + +async def test_nested_all_changes_discarded_on_outer_exception(aconn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in outer context escapes. + """ + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "outer") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise ExpectedException() + assert not in_transaction(aconn) + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_nested_all_changes_discarded_on_inner_exception(aconn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in inner context escapes the outer context. + """ + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "outer") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise ExpectedException() + assert not in_transaction(aconn) + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_nested_inner_scope_exception_handled_in_outer_scope( + aconn, svcconn +): + """ + An exception escaping the inner transaction context causes changes made + within that inner context to be discarded, but the error can then be + handled in the outer context, allowing changes made in the outer context + (both before, and after, the inner context) to be successfully committed. + """ + async with aconn.transaction(): + await insert_row(aconn, "outer-before") + with pytest.raises(ExpectedException): + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise ExpectedException() + await insert_row(aconn, "outer-after") + assert not in_transaction(aconn) + assert await inserted(aconn) == {"outer-before", "outer-after"} + assert inserted(svcconn) == {"outer-before", "outer-after"} + + +async def test_nested_three_levels_successful_exit(aconn, svcconn): + """Exercise management of more than one savepoint.""" + async with aconn.transaction(): # BEGIN + await insert_row(aconn, "one") + async with aconn.transaction(): # SAVEPOINT s1 + await insert_row(aconn, "two") + async with aconn.transaction(): # SAVEPOINT s2 + await insert_row(aconn, "three") + assert not in_transaction(aconn) + assert await inserted(aconn) == {"one", "two", "three"} + assert inserted(svcconn) == {"one", "two", "three"} + + +async def test_named_savepoint_escapes_savepoint_name(aconn): + async with aconn.transaction("s-1"): + pass + async with aconn.transaction("s1; drop table students"): + pass + + +async def test_named_savepoints_successful_exit(aconn, commands): + """ + Entering a transaction context will do one of these these things: + 1. Begin an outer transaction (if one isn't already in progress) + 2. Begin an outer transaction and create a savepoint (if one is named) + 3. Create a savepoint (if a transaction is already in progress) + either using the name provided, or auto-generating a savepoint name. + + ...and exiting the context successfully will "commit" the same. + """ + # Case 1 + # Using Transaction explicitly becase conn.transaction() enters the contetx + tx = AsyncTransaction(aconn) + await tx.__aenter__() + assert commands.pop() == "begin" + assert not tx.savepoint_name + await tx.__aexit__(None, None, None) + assert commands.pop() == "commit" + + # Case 2 + tx = AsyncTransaction(aconn, savepoint_name="foo") + await tx.__aenter__() + assert commands.pop() == "begin" + assert commands.pop() == 'savepoint "foo"' + assert tx.savepoint_name == "foo" + await tx.__aexit__(None, None, None) + assert commands.pop() == 'release savepoint "foo"' + assert commands.pop() == "commit" + + # Case 3 (with savepoint name provided) + async with aconn.transaction(): + assert commands.pop() == "begin" + tx = AsyncTransaction(aconn, savepoint_name="bar") + await tx.__aenter__() + assert commands.pop() == 'savepoint "bar"' + assert tx.savepoint_name == "bar" + await tx.__aexit__(None, None, None) + assert commands.pop() == 'release savepoint "bar"' + assert commands.pop() == "commit" + + # Case 3 (with savepoint name auto-generated) + async with aconn.transaction(): + assert commands.pop() == "begin" + tx = AsyncTransaction(aconn) + await tx.__aenter__() + assert commands.pop() == 'savepoint "s1"' + assert tx.savepoint_name == "s1" + await tx.__aexit__(None, None, None) + assert commands.pop() == 'release savepoint "s1"' + assert commands.pop() == "commit" + + assert not commands + + +async def test_named_savepoints_exception_exit(aconn, commands): + """ + Same as the previous test but checks that when exiting the context with an + exception, whatever transaction and/or savepoint was started on enter will + be rolled-back as appropriate. + """ + # Case 1 + tx = AsyncTransaction(aconn) + await tx.__aenter__() + assert commands.pop() == "begin" + assert not tx.savepoint_name + await tx.__aexit__(*some_exc_info()) + assert commands.pop() == "rollback" + + # Case 2 + tx = AsyncTransaction(aconn, savepoint_name="foo") + await tx.__aenter__() + assert commands.pop() == "begin" + assert commands.pop() == 'savepoint "foo"' + assert tx.savepoint_name == "foo" + await tx.__aexit__(*some_exc_info()) + assert ( + commands.pop() + == 'rollback to savepoint "foo"; release savepoint "foo"' + ) + assert commands.pop() == "rollback" + + # Case 3 (with savepoint name provided) + async with aconn.transaction(): + assert commands.pop() == "begin" + tx = AsyncTransaction(aconn, savepoint_name="bar") + await tx.__aenter__() + assert commands.pop() == 'savepoint "bar"' + assert tx.savepoint_name == "bar" + await tx.__aexit__(*some_exc_info()) + assert ( + commands.pop() + == 'rollback to savepoint "bar"; release savepoint "bar"' + ) + assert commands.pop() == "commit" + + # Case 3 (with savepoint name auto-generated) + async with aconn.transaction(): + assert commands.pop() == "begin" + tx = AsyncTransaction(aconn) + await tx.__aenter__() + assert commands.pop() == 'savepoint "s1"' + assert tx.savepoint_name == "s1" + await tx.__aexit__(*some_exc_info()) + assert ( + commands.pop() + == 'rollback to savepoint "s1"; release savepoint "s1"' + ) + assert commands.pop() == "commit" + + assert not commands + + +async def test_named_savepoints_with_repeated_names_works(aconn): + """ + Using the same savepoint name repeatedly works correctly, but bypasses + some sanity checks. + """ + # Works correctly if no inner transactions are rolled back + async with aconn.transaction(force_rollback=True): + async with aconn.transaction("sp"): + await insert_row(aconn, "tx1") + async with aconn.transaction("sp"): + await insert_row(aconn, "tx2") + async with aconn.transaction("sp"): + await insert_row(aconn, "tx3") + assert await inserted(aconn) == {"tx1", "tx2", "tx3"} + + # Works correctly if one level of inner transaction is rolled back + async with aconn.transaction(force_rollback=True): + async with aconn.transaction("s1"): + await insert_row(aconn, "tx1") + async with aconn.transaction("s1", force_rollback=True): + await insert_row(aconn, "tx2") + async with aconn.transaction("s1"): + await insert_row(aconn, "tx3") + assert await inserted(aconn) == {"tx1"} + assert await inserted(aconn) == {"tx1"} + + # Works correctly if multiple inner transactions are rolled back + # (This scenario mandates releasing savepoints after rolling back to them.) + async with aconn.transaction(force_rollback=True): + async with aconn.transaction("s1"): + await insert_row(aconn, "tx1") + async with aconn.transaction("s1") as tx2: + await insert_row(aconn, "tx2") + async with aconn.transaction("s1"): + await insert_row(aconn, "tx3") + raise Rollback(tx2) + assert await inserted(aconn) == {"tx1"} + assert await inserted(aconn) == {"tx1"} + + # Will not (always) catch out-of-order exits + async with aconn.transaction(force_rollback=True): + tx1 = aconn.transaction("s1") + tx2 = aconn.transaction("s1") + await tx1.__aenter__() + await tx2.__aenter__() + await tx1.__aexit__(None, None, None) + await tx2.__aexit__(None, None, None) + + +async def test_force_rollback_successful_exit(aconn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + async with aconn.transaction(force_rollback=True): + await insert_row(aconn, "foo") + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_force_rollback_exception_exit(aconn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + with pytest.raises(ExpectedException): + async with aconn.transaction(force_rollback=True): + await insert_row(aconn, "foo") + raise ExpectedException() + assert not await inserted(aconn) + assert not inserted(svcconn) + + +async def test_explicit_rollback_discards_changes(aconn, svcconn): + """ + Raising a Rollback exception in the middle of a block exits the block and + discards all changes made within that block. + + You can raise any of the following: + - Rollback (type) + - Rollback() (instance) + - Rollback(tx) (instance initialised with reference to the transaction) + All of these are equivalent. + """ + + async def assert_no_rows(): + assert not await inserted(aconn) + assert not inserted(svcconn) + + async with aconn.transaction(): + await insert_row(aconn, "foo") + raise Rollback + await assert_no_rows() + + async with aconn.transaction(): + await insert_row(aconn, "foo") + raise Rollback() + await assert_no_rows() + + async with aconn.transaction() as tx: + await insert_row(aconn, "foo") + raise Rollback(tx) + await assert_no_rows() + + +async def test_explicit_rollback_outer_tx_unaffected(aconn, svcconn): + """ + Raising a Rollback exception in the middle of a block does not impact an + enclosing transaction block. + """ + async with aconn.transaction(): + await insert_row(aconn, "before") + async with aconn.transaction(): + await insert_row(aconn, "during") + raise Rollback + assert in_transaction(aconn) + assert not inserted(svcconn) + await insert_row(aconn, "after") + assert await inserted(aconn) == {"before", "after"} + assert inserted(svcconn) == {"before", "after"} + + +async def test_explicit_rollback_of_outer_transaction(aconn): + """ + Raising a Rollback exception that references an outer transaction will + discard all changes from both inner and outer transaction blocks. + """ + async with aconn.transaction() as outer_tx: + await insert_row(aconn, "outer") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise Rollback(outer_tx) + assert False, "This line of code should be unreachable." + assert not await inserted(aconn) + + +async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected( + aconn, svcconn +): + """ + Rolling-back an enclosing transaction does not impact an outer transaction. + """ + async with aconn.transaction(): + await insert_row(aconn, "outer-before") + async with aconn.transaction() as tx_enclosing: + await insert_row(aconn, "enclosing") + async with aconn.transaction(): + await insert_row(aconn, "inner") + raise Rollback(tx_enclosing) + await insert_row(aconn, "outer-after") + + assert await inserted(aconn) == {"outer-before", "outer-after"} + assert not inserted(svcconn) # Not yet committed + # Changes committed + assert inserted(svcconn) == {"outer-before", "outer-after"} + + +@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) +@pytest.mark.parametrize("name", [None, "s1"]) +async def test_manual_enter_and_exit_out_of_order_exit_asserts( + aconn, name, exc_info +): + """ + When user is calling __enter__() and __exit__() manually for some reason, + provide a helpful error message if they call __exit__() in the wrong order + for nested transactions. + """ + tx1, tx2 = AsyncTransaction(aconn, name), AsyncTransaction(aconn) + await tx1.__aenter__() + await tx2.__aenter__() + with pytest.raises(ProgrammingError, match="Out-of-order"): + await tx1.__aexit__(*exc_info) + + +@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) +@pytest.mark.parametrize("name", [None, "s1"]) +async def test_manual_exit_without_enter_asserts(aconn, name, exc_info): + """ + When user is calling __enter__() and __exit__() manually for some reason, + provide a helpful error message if they call __exit__() without first + having called __enter__() + """ + tx = AsyncTransaction(aconn, name) + with pytest.raises(ProgrammingError, match="Out-of-order"): + await tx.__aexit__(*exc_info) + + +@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) +@pytest.mark.parametrize("name", [None, "s1"]) +async def test_manual_exit_twice_asserts(aconn, name, exc_info): + """ + When user is calling __enter__() and __exit__() manually for some reason, + provide a helpful error message if they accidentally call __exit__() twice. + """ + tx = AsyncTransaction(aconn, name) + await tx.__aenter__() + await tx.__aexit__(*exc_info) + with pytest.raises(ProgrammingError, match="Out-of-order"): + await tx.__aexit__(*exc_info) -- 2.47.3