From: Daniel Fortunov Date: Sat, 25 Jul 2020 11:27:31 +0000 (+0100) Subject: First-cut implementation of connection.transaction() X-Git-Tag: 3.0.dev0~351^2~18 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=ff04f7d9db4ffc74762a3942ebe02d248ba8b0f4;p=thirdparty%2Fpsycopg.git First-cut implementation of connection.transaction() --- diff --git a/psycopg3/psycopg3/__init__.py b/psycopg3/psycopg3/__init__.py index b411dabb0..6f6a907e2 100644 --- a/psycopg3/psycopg3/__init__.py +++ b/psycopg3/psycopg3/__init__.py @@ -5,13 +5,13 @@ psycopg3 -- PostgreSQL database adapter for Python # Copyright (C) 2020 The Psycopg Team from . import pq -from .connection import AsyncConnection, Connection, Notify -from .cursor import AsyncCursor, Cursor, Column from .copy import Copy, AsyncCopy - +from .cursor import AsyncCursor, Cursor, Column from .errors import Warning, Error, InterfaceError, DatabaseError from .errors import DataError, OperationalError, IntegrityError from .errors import InternalError, ProgrammingError, NotSupportedError +from .connection import AsyncConnection, Connection, Notify +from .transaction import Rollback from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 0ebbf00c4..f186d46cb 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -22,6 +22,7 @@ from .proto import DumpersMap, LoadersMap, PQGen, RV from .waiting import wait, wait_async from .conninfo import make_conninfo from .generators import notifies +from .transaction import Transaction logger = logging.getLogger(__name__) package_logger = logging.getLogger("psycopg3") @@ -97,6 +98,11 @@ class BaseConnection: self._notice_handlers: List[NoticeHandler] = [] self._notify_handlers: List[NotifyHandler] = [] + # stack of savepoint names managed by active Transaction() blocks + self._savepoints: Optional[List[bytes]] = None + # (None when there no active Transaction blocks; [] when there is only + # one Transaction block, with a top-level transaction and no savepoint) + wself = ref(self) pgconn.notice_handler = partial(BaseConnection._notice_handler, wself) @@ -121,10 +127,18 @@ class BaseConnection: # subclasses must call it holding a lock status = self.pgconn.transaction_status if status != TransactionStatus.IDLE: - raise e.ProgrammingError( - "couldn't change autocommit state: connection in" - f" transaction status {TransactionStatus(status).name}" - ) + if self._savepoints is not None: + raise e.ProgrammingError( + "couldn't change autocommit state: " + "connection.transaction() context in progress" + ) + else: + raise e.ProgrammingError( + "couldn't change autocommit state: " + "connection in transaction status " + f"{TransactionStatus(status).name}" + ) + self._autocommit = value @property @@ -268,30 +282,37 @@ class Connection(BaseConnection): if self.pgconn.transaction_status != TransactionStatus.IDLE: return - self.pgconn.send_query(b"begin") - (pgres,) = self.wait(execute(self.pgconn)) - if pgres.status != ExecStatus.COMMAND_OK: - raise e.OperationalError( - "error on begin:" - f" {pq.error_message(pgres, encoding=self.client_encoding)}" - ) + self._exec_command(b"begin") def commit(self) -> None: """Commit any pending transaction to the database.""" with self.lock: - self._exec_commit_rollback(b"commit") + if self._savepoints is not None: + raise e.ProgrammingError( + "Explicit commit() forbidden within a Transaction " + "context. (Transaction will be automatically committed " + "on successful exit from context.)" + ) + if self.pgconn.transaction_status == TransactionStatus.IDLE: + return + self._exec_command(b"commit") def rollback(self) -> None: """Roll back to the start of any pending transaction.""" with self.lock: - self._exec_commit_rollback(b"rollback") + if self._savepoints is not None: + raise e.ProgrammingError( + "Explicit rollback() forbidden within a Transaction " + "context. (Either raise Transaction.Rollback() or allow " + "an exception to propagate out of the context.)" + ) + if self.pgconn.transaction_status == TransactionStatus.IDLE: + return + self._exec_command(b"rollback") - def _exec_commit_rollback(self, command: bytes) -> None: + def _exec_command(self, command: bytes) -> None: # Caller must hold self.lock - status = self.pgconn.transaction_status - if status == TransactionStatus.IDLE: - return - + logger.debug(f"{self}: {command!r}") self.pgconn.send_query(command) results = self.wait(execute(self.pgconn)) if results[-1].status != ExecStatus.COMMAND_OK: @@ -300,6 +321,13 @@ class Connection(BaseConnection): f" {pq.error_message(results[-1], encoding=self.client_encoding)}" ) + def transaction( + self, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, + ) -> Transaction: + return Transaction(self, savepoint_name, force_rollback) + @classmethod def wait(cls, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: return wait(gen, timeout=timeout) @@ -407,18 +435,18 @@ class AsyncConnection(BaseConnection): async def commit(self) -> None: async with self.lock: - await self._exec_commit_rollback(b"commit") + if self.pgconn.transaction_status == TransactionStatus.IDLE: + return + await self._exec(b"commit") async def rollback(self) -> None: async with self.lock: - await self._exec_commit_rollback(b"rollback") + if self.pgconn.transaction_status == TransactionStatus.IDLE: + return + await self._exec(b"rollback") - async def _exec_commit_rollback(self, command: bytes) -> None: + async def _exec(self, command: bytes) -> None: # Caller must hold self.lock - status = self.pgconn.transaction_status - if status == TransactionStatus.IDLE: - return - self.pgconn.send_query(command) (pgres,) = await self.wait(execute(self.pgconn)) if pgres.status != ExecStatus.COMMAND_OK: diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py new file mode 100644 index 000000000..ee9921f96 --- /dev/null +++ b/psycopg3/psycopg3/transaction.py @@ -0,0 +1,153 @@ +""" +Transaction context managers returned by Connection.transaction() +""" + +# Copyright (C) 2020 The Psycopg Team + +import logging + +from psycopg3.errors import ProgrammingError +from types import TracebackType +from typing import Optional, Type, TYPE_CHECKING + +from .pq import TransactionStatus + +if TYPE_CHECKING: + from .connection import Connection + +_log = logging.getLogger(__name__) + + +class Rollback(Exception): + """ + Exit the current Transaction context immediately and rollback any changes + made within this context. + + If a transaction context is specified in the constructor, rollback + enclosing transactions contexts up to and including the one specified. + """ + + def __init__(self, transaction: Optional["Transaction"] = None): + self.transaction = transaction + + def __repr__(self) -> str: + return f"{self.__class__.__qualname__}({self.transaction!r})" + + +class Transaction: + def __init__( + self, + connection: "Connection", + savepoint_name: Optional[str], + force_rollback: bool, + ): + self._conn = connection + self._savepoint_name: Optional[bytes] = None + if savepoint_name is not None: + if len(savepoint_name) == 0: + raise ValueError("savepoint_name must be a non-empty string") + self._savepoint_name = connection.codec.encode(savepoint_name)[0] + self.force_rollback = force_rollback + + self._outer_transaction: Optional[bool] = None + + @property + def connection(self) -> "Connection": + return self._conn + + @property + def savepoint_name(self) -> Optional[str]: + if self._savepoint_name is None: + return None + return self._conn.codec.decode(self._savepoint_name)[0] + + def __enter__(self) -> "Transaction": + with self._conn.lock: + if self._conn.pgconn.transaction_status == TransactionStatus.IDLE: + assert self._conn._savepoints is None, self._conn._savepoints + self._conn._savepoints = [] + self._outer_transaction = True + self._conn._exec_command(b"begin") + else: + if self._conn._savepoints is None: + self._conn._savepoints = [] + self._outer_transaction = False + if self._savepoint_name is None: + self._savepoint_name = b"s%i" % ( + len(self._conn._savepoints) + 1 + ) + + if self._savepoint_name is not None: + self._conn._exec_command(b"savepoint " + self._savepoint_name) + self._conn._savepoints.append(self._savepoint_name) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + out_of_order_err = ProgrammingError( + "Out-of-order Transaction context exits. Are you " + "calling __exit__() manually and getting it wrong?" + ) + if self._outer_transaction is None: + raise out_of_order_err + with self._conn.lock: + if exc_type is None and not self.force_rollback: + # Commit changes made in the transaction context + if self._savepoint_name: + if self._conn._savepoints is None: + raise out_of_order_err + actual = self._conn._savepoints.pop() + if actual != self._savepoint_name: + raise out_of_order_err + self._conn._exec_command( + b"release savepoint " + self._savepoint_name + ) + if self._outer_transaction: + if self._conn._savepoints is None: + raise out_of_order_err + if len(self._conn._savepoints) != 0: + raise out_of_order_err + self._conn._exec_command(b"commit") + self._conn._savepoints = None + else: + # Rollback changes made in the transaction context + if isinstance(exc_val, Rollback): + _log.debug( + f"{self._conn}: Explicit rollback from: ", + exc_info=True, + ) + + if self._savepoint_name: + if self._conn._savepoints is None: + raise out_of_order_err + actual = self._conn._savepoints.pop() + if actual != self._savepoint_name: + raise out_of_order_err + self._conn._exec_command( + b"rollback to savepoint " + self._savepoint_name + b";" + b"release savepoint " + self._savepoint_name + ) + if self._outer_transaction: + if self._conn._savepoints is None: + raise out_of_order_err + if len(self._conn._savepoints) != 0: + raise out_of_order_err + self._conn._exec_command(b"rollback") + self._conn._savepoints = None + + if isinstance(exc_val, Rollback): + if exc_val.transaction in (self, None): + return True # Swallow the exception + return False + + def __repr__(self) -> str: + args = [f"connection={self.connection}"] + if self.savepoint_name is not None: + args.append(f"savepoint_name={self.savepoint_name!r}") + if self.force_rollback: + args.append("force_rollback=True") + return f"{self.__class__.__qualname__}({', '.join(args)})" diff --git a/tests/test_transaction.py b/tests/test_transaction.py new file mode 100644 index 000000000..2afa831e6 --- /dev/null +++ b/tests/test_transaction.py @@ -0,0 +1,618 @@ +import sys +from contextlib import contextmanager + +import pytest + +from psycopg3 import OperationalError, ProgrammingError, Rollback + + +@pytest.fixture(autouse=True) +def test_table(svcconn): + """ + Creates a table called 'test_table' for use in tests. + """ + cur = svcconn.cursor() + cur.execute("drop table if exists test_table") + cur.execute("create table test_table (id text primary key)") + yield + cur.execute("drop table test_table") + + +def insert_row(conn, value): + conn.cursor().execute("INSERT INTO test_table VALUES (%s)", (value,)) + + +def assert_rows(conn, expected): + rows = conn.cursor().execute("SELECT * FROM test_table").fetchall() + assert set(v for (v,) in rows) == expected + + +def assert_not_in_transaction(conn): + assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE + + +def assert_in_transaction(conn): + assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS + + +@contextmanager +def assert_commands_issued(conn, *commands): + commands_actual = [] + real_exec_command = conn._exec_command + + def _exec_command(command): + commands_actual.append(command) + real_exec_command(command) + + try: + conn._exec_command = _exec_command + yield + finally: + conn._exec_command = real_exec_command + commands_expected = [cmd.encode("ascii") for cmd in commands] + assert commands_actual == commands_expected + + +class ExpectedException(Exception): + pass + + +def some_exc_info(): + try: + raise ExpectedException() + except ExpectedException: + return sys.exc_info() + + +def test_basic(conn): + """Basic use of transaction() to BEGIN and COMMIT a transaction.""" + assert_not_in_transaction(conn) + with conn.transaction(): + assert_in_transaction(conn) + assert_not_in_transaction(conn) + + +def test_exposes_associated_connection(conn): + """Transaction exposes its connection as a read-only property.""" + with conn.transaction() as tx: + assert tx.connection is conn + with pytest.raises(AttributeError): + tx.connection = conn + + +def test_exposes_savepoint_name(conn): + """Transaction exposes its savepoint name as a read-only property.""" + with conn.transaction(savepoint_name="foo") as tx: + assert tx.savepoint_name == "foo" + with pytest.raises(AttributeError): + tx.savepoint_name = "bar" + + +def test_begins_on_enter(conn): + """Transaction does not begin until __enter__() is called.""" + tx = conn.transaction() + assert_not_in_transaction(conn) + with tx: + assert_in_transaction(conn) + assert_not_in_transaction(conn) + + +def test_commit_on_successful_exit(conn): + """Changes are committed on successful exit from the `with` block.""" + with conn.transaction(): + insert_row(conn, "foo") + + assert_not_in_transaction(conn) + assert_rows(conn, {"foo"}) + + +def test_rollback_on_exception_exit(conn): + """Changes are rolled back if an exception escapes the `with` block.""" + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "foo") + raise ExpectedException("This discards the insert") + + assert_not_in_transaction(conn) + assert_rows(conn, set()) + + +def test_prohibits_use_of_commit_rollback_autocommit(conn): + """ + Within a Transaction block, it is forbidden to touch commit, rollback, + or the autocommit setting on the connection, as this would interfere + with the transaction scope being managed by the Transaction block. + """ + conn.autocommit = False + conn.commit() + conn.rollback() + + with conn.transaction(): + with pytest.raises(ProgrammingError): + conn.autocommit = False + with pytest.raises(ProgrammingError): + conn.commit() + with pytest.raises(ProgrammingError): + conn.rollback() + + conn.autocommit = False + conn.commit() + conn.rollback() + + +@pytest.mark.parametrize("autocommit", [False, True]) +def test_preserves_autocommit(conn, autocommit): + """ + Connection.autocommit is unchanged both during and after Transaction block. + """ + conn.autocommit = autocommit + with conn.transaction(): + assert conn.autocommit is autocommit + assert conn.autocommit is autocommit + + +def test_autocommit_off_but_no_tx_started_successful_exit(conn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are committed + """ + conn.autocommit = False + assert_not_in_transaction(conn) + with conn.transaction(): + insert_row(conn, "new") + assert_not_in_transaction(conn) + + # Changes committed + assert_rows(conn, {"new"}) + assert_rows(svcconn, {"new"}) + + +def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn): + """ + Scenario: + * Connection has autocommit off but no transaction has been initiated + before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made within Transaction context are discarded + """ + conn.autocommit = False + assert_not_in_transaction(conn) + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "new") + raise ExpectedException() + assert_not_in_transaction(conn) + + # Changes discarded + assert_rows(conn, set()) + assert_rows(svcconn, set()) + + +def test_autocommit_off_and_tx_in_progress_successful_exit(conn, svcconn): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context successfully + + Outcome: + * Changes made within Transaction context are left intact + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. + """ + conn.autocommit = False + insert_row(conn, "prior") + assert_in_transaction(conn) + with conn.transaction(): + insert_row(conn, "new") + assert_in_transaction(conn) + assert_rows(conn, {"prior", "new"}) + # Nothing committed yet; changes not visible on another connection + assert_rows(svcconn, set()) + + +def test_autocommit_off_and_tx_in_progress_exception_exit(conn, svcconn): + """ + Scenario: + * Connection has autocommit off but and a transaction is already in + progress before entering the Transaction context + * Code exits Transaction context with an exception + + Outcome: + * Changes made before the Transaction context are left intact + * Changes made within Transaction context are discarded + * Outer transaction is left running, and no changes are visible to an + outside observer from another connection. + """ + conn.autocommit = False + insert_row(conn, "prior") + assert_in_transaction(conn) + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "new") + raise ExpectedException() + assert_in_transaction(conn) + assert_rows(conn, {"prior"}) + # Nothing committed yet; changes not visible on another connection + assert_rows(svcconn, set()) + + +def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn): + """Changes from nested transaction contexts are all persisted on exit.""" + with conn.transaction(): + insert_row(conn, "outer-before") + with conn.transaction(): + insert_row(conn, "inner") + insert_row(conn, "outer-after") + assert_not_in_transaction(conn) + assert_rows(conn, {"outer-before", "inner", "outer-after"}) + assert_rows(svcconn, {"outer-before", "inner", "outer-after"}) + + +def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in outer context escapes. + """ + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "outer") + with conn.transaction(): + insert_row(conn, "inner") + raise ExpectedException() + assert_not_in_transaction(conn) + assert_rows(conn, set()) + assert_rows(svcconn, set()) + + +def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn): + """ + Changes from nested transaction contexts are discarded when an exception + raised in inner context escapes the outer context. + """ + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "outer") + with conn.transaction(): + insert_row(conn, "inner") + raise ExpectedException() + assert_not_in_transaction(conn) + assert_rows(conn, set()) + assert_rows(svcconn, set()) + + +def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn): + """ + An exception escaping the inner transaction context causes changes made + within that inner context to be discarded, but the error can then be + handled in the outer context, allowing changes made in the outer context + (both before, and after, the inner context) to be successfully committed. + """ + with conn.transaction(): + insert_row(conn, "outer-before") + with pytest.raises(ExpectedException): + with conn.transaction(): + insert_row(conn, "inner") + raise ExpectedException() + insert_row(conn, "outer-after") + assert_not_in_transaction(conn) + assert_rows(conn, {"outer-before", "outer-after"}) + assert_rows(svcconn, {"outer-before", "outer-after"}) + + +def test_nested_three_levels_successful_exit(conn, svcconn): + """Exercise management of more than one savepoint.""" + with conn.transaction(): # BEGIN + insert_row(conn, "one") + with conn.transaction(): # SAVEPOINT s1 + insert_row(conn, "two") + with conn.transaction(): # SAVEPOINT s2 + insert_row(conn, "three") + assert_not_in_transaction(conn) + assert_rows(conn, {"one", "two", "three"}) + assert_rows(svcconn, {"one", "two", "three"}) + + +def test_named_savepoint_empty_string_invalid(conn): + """ + Raise validate savepoint_name up-front (rather than later constructing an + invalid SQL command and having that fail with an OperationalError). + """ + with pytest.raises(ValueError): + conn.transaction(savepoint_name="") + + +@pytest.mark.xfail(raises=OperationalError, reason="TODO: Escape sp names") +def test_named_savepoint_escapes_savepoint_name(conn): + with conn.transaction("s-1"): + pass + with conn.transaction("s1; drop table students"): + pass + + +def test_named_savepoints_successful_exit(conn): + """ + Entering a transaction context will do one of these these things: + 1. Begin an outer transaction (if one isn't already in progress) + 2. Begin an outer transaction and create a savepoint (if one is named) + 3. Create a savepoint (if a transaction is already in progress) + either using the name provided, or auto-generating a savepoint name. + + ...and exiting the context successfully will "commit" the same. + """ + # Case 1 + tx = conn.transaction() + with assert_commands_issued(conn, "begin"): + tx.__enter__() + assert tx.savepoint_name is None + with assert_commands_issued(conn, "commit"): + tx.__exit__(None, None, None) + + # Case 2 + tx = conn.transaction(savepoint_name="foo") + with assert_commands_issued(conn, "begin", "savepoint foo"): + tx.__enter__() + assert tx.savepoint_name == "foo" + with assert_commands_issued(conn, "release savepoint foo", "commit"): + tx.__exit__(None, None, None) + + # Case 3 (with savepoint name provided) + with conn.transaction(): + tx = conn.transaction(savepoint_name="bar") + with assert_commands_issued(conn, "savepoint bar"): + tx.__enter__() + assert tx.savepoint_name == "bar" + with assert_commands_issued(conn, "release savepoint bar"): + tx.__exit__(None, None, None) + + # Case 3 (with savepoint name auto-generated) + with conn.transaction(): + tx = conn.transaction() + with assert_commands_issued(conn, "savepoint s1"): + tx.__enter__() + assert tx.savepoint_name == "s1" + with assert_commands_issued(conn, "release savepoint s1"): + tx.__exit__(None, None, None) + + +def test_named_savepoints_exception_exit(conn): + """ + Same as the previous test but checks that when exiting the context with an + exception, whatever transaction and/or savepoint was started on enter will + be rolled-back as appropriate. + """ + # Case 1 + tx = conn.transaction() + with assert_commands_issued(conn, "begin"): + tx.__enter__() + assert tx.savepoint_name is None + with assert_commands_issued(conn, "rollback"): + tx.__exit__(*some_exc_info()) + + # Case 2 + tx = conn.transaction(savepoint_name="foo") + with assert_commands_issued(conn, "begin", "savepoint foo"): + tx.__enter__() + assert tx.savepoint_name == "foo" + with assert_commands_issued( + conn, "rollback to savepoint foo;release savepoint foo", "rollback" + ): + tx.__exit__(*some_exc_info()) + + # Case 3 (with savepoint name provided) + with conn.transaction(): + tx = conn.transaction(savepoint_name="bar") + with assert_commands_issued(conn, "savepoint bar"): + tx.__enter__() + assert tx.savepoint_name == "bar" + with assert_commands_issued( + conn, "rollback to savepoint bar;release savepoint bar" + ): + tx.__exit__(*some_exc_info()) + + # Case 3 (with savepoint name auto-generated) + with conn.transaction(): + tx = conn.transaction() + with assert_commands_issued(conn, "savepoint s1"): + tx.__enter__() + assert tx.savepoint_name == "s1" + with assert_commands_issued( + conn, "rollback to savepoint s1;release savepoint s1" + ): + tx.__exit__(*some_exc_info()) + + +def test_named_savepoints_with_repeated_names_works(conn): + """ + Using the same savepoint name repeatedly works correctly, but bypasses + some sanity checks. + """ + # Works correctly if no inner transactions are rolled back + with conn.transaction(force_rollback=True): + with conn.transaction("sp"): + insert_row(conn, "tx1") + with conn.transaction("sp"): + insert_row(conn, "tx2") + with conn.transaction("sp"): + insert_row(conn, "tx3") + assert_rows(conn, {"tx1", "tx2", "tx3"}) + + # Works correctly if one level of inner transaction is rolled back + with conn.transaction(force_rollback=True): + with conn.transaction("s1"): + insert_row(conn, "tx1") + with conn.transaction("s1", force_rollback=True): + insert_row(conn, "tx2") + with conn.transaction("s1"): + insert_row(conn, "tx3") + assert_rows(conn, {"tx1"}) + assert_rows(conn, {"tx1"}) + + # Works correctly if multiple inner transactions are rolled back + # (This scenario mandates releasing savepoints after rolling back to them.) + with conn.transaction(force_rollback=True): + with conn.transaction("s1"): + insert_row(conn, "tx1") + with conn.transaction("s1") as tx2: + insert_row(conn, "tx2") + with conn.transaction("s1"): + insert_row(conn, "tx3") + raise Rollback(tx2) + assert_rows(conn, {"tx1"}) + assert_rows(conn, {"tx1"}) + + # Will not (always) catch out-of-order exits + with conn.transaction(force_rollback=True): + tx1 = conn.transaction("s1") + tx2 = conn.transaction("s1") + tx1.__enter__() + tx2.__enter__() + tx1.__exit__(None, None, None) + tx2.__exit__(None, None, None) + + +def test_force_rollback_successful_exit(conn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + with conn.transaction(force_rollback=True): + insert_row(conn, "foo") + assert_rows(conn, set()) + assert_rows(svcconn, set()) + + +def test_force_rollback_exception_exit(conn, svcconn): + """ + Transaction started with the force_rollback option enabled discards all + changes at the end of the context. + """ + with pytest.raises(ExpectedException): + with conn.transaction(force_rollback=True): + insert_row(conn, "foo") + raise ExpectedException() + assert_rows(conn, set()) + assert_rows(svcconn, set()) + + +def test_explicit_rollback_discards_changes(conn, svcconn): + """ + Raising a Rollback exception in the middle of a block exits the block and + discards all changes made within that block. + + You can raise any of the following: + - Rollback (type) + - Rollback() (instance) + - Rollback(tx) (instance initialised with reference to the transaction) + All of these are equivalent. + """ + tx = conn.transaction() + for to_raise in ( + Rollback, + Rollback(), + Rollback(tx), + ): + with tx: + insert_row(conn, "foo") + raise to_raise + assert_rows(conn, set("")) + assert_rows(svcconn, set()) + + +def test_explicit_rollback_outer_tx_unaffected(conn, svcconn): + """ + Raising a Rollback exception in the middle of a block does not impact an + enclosing transaction block. + """ + with conn.transaction(): + insert_row(conn, "before") + with conn.transaction(): + insert_row(conn, "during") + raise Rollback + assert_in_transaction(conn) + assert_rows(svcconn, set()) + insert_row(conn, "after") + assert_rows(conn, {"before", "after"}) + assert_rows(svcconn, {"before", "after"}) + + +def test_explicit_rollback_of_outer_transaction(conn): + """ + Raising a Rollback exception that references an outer transaction will + discard all changes from both inner and outer transaction blocks. + """ + outer_tx = conn.transaction() + with outer_tx: + insert_row(conn, "outer") + with conn.transaction(): + insert_row(conn, "inner") + raise Rollback(outer_tx) + assert False, "This line of code should be unreachable." + assert_rows(conn, set()) + + +def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn): + """ + Rolling-back an enclosing transaction does not impact an outer transaction. + """ + with conn.transaction(): + insert_row(conn, "outer-before") + with conn.transaction() as tx_enclosing: + insert_row(conn, "enclosing") + with conn.transaction(): + insert_row(conn, "inner") + raise Rollback(tx_enclosing) + insert_row(conn, "outer-after") + + assert_rows(conn, {"outer-before", "outer-after"}) + assert_rows(svcconn, set()) # Not yet committed + assert_rows(svcconn, {"outer-before", "outer-after"}) # Changes committed + + +@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) +@pytest.mark.parametrize("name", [None, "s1"]) +def test_manual_enter_and_exit_out_of_order_exit_asserts(conn, name, exc_info): + """ + When user is calling __enter__() and __exit__() manually for some reason, + provide a helpful error message if they call __exit__() in the wrong order + for nested transactions. + """ + tx1, tx2 = conn.transaction(name), conn.transaction() + tx1.__enter__() + tx2.__enter__() + with pytest.raises(ProgrammingError, match="Out-of-order"): + tx1.__exit__(*exc_info) + + +@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) +@pytest.mark.parametrize("name", [None, "s1"]) +def test_manual_exit_without_enter_asserts(conn, name, exc_info): + """ + When user is calling __enter__() and __exit__() manually for some reason, + provide a helpful error message if they call __exit__() without first + having called __enter__() + """ + tx = conn.transaction(name) + with pytest.raises(ProgrammingError, match="Out-of-order"): + tx.__exit__(*exc_info) + + +@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) +@pytest.mark.parametrize("name", [None, "s1"]) +def test_manual_exit_twice_asserts(conn, name, exc_info): + """ + When user is calling __enter__() and __exit__() manually for some reason, + provide a helpful error message if they accidentally call __exit__() twice. + """ + tx = conn.transaction(name) + tx.__enter__() + tx.__exit__(*exc_info) + with pytest.raises(ProgrammingError, match="Out-of-order"): + tx.__exit__(*exc_info)