]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
First-cut implementation of connection.transaction()
authorDaniel Fortunov <github@danielfortunov.com>
Sat, 25 Jul 2020 11:27:31 +0000 (12:27 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 14 Nov 2020 22:18:46 +0000 (22:18 +0000)
psycopg3/psycopg3/__init__.py
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/transaction.py [new file with mode: 0644]
tests/test_transaction.py [new file with mode: 0644]

index b411dabb010467c4b879ca1e67618ea08570a486..6f6a907e279c8ee7671fe61ebafcf9605edcbc5a 100644 (file)
@@ -5,13 +5,13 @@ psycopg3 -- PostgreSQL database adapter for Python
 # Copyright (C) 2020 The Psycopg Team
 
 from . import pq
-from .connection import AsyncConnection, Connection, Notify
-from .cursor import AsyncCursor, Cursor, Column
 from .copy import Copy, AsyncCopy
-
+from .cursor import AsyncCursor, Cursor, Column
 from .errors import Warning, Error, InterfaceError, DatabaseError
 from .errors import DataError, OperationalError, IntegrityError
 from .errors import InternalError, ProgrammingError, NotSupportedError
+from .connection import AsyncConnection, Connection, Notify
+from .transaction import Rollback
 
 from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
 from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
index 0ebbf00c404ef47522dd56fda72684dd74de5c7a..f186d46cb0eaa0d8f9d4d2d1ae4b5971b65b9b1e 100644 (file)
@@ -22,6 +22,7 @@ from .proto import DumpersMap, LoadersMap, PQGen, RV
 from .waiting import wait, wait_async
 from .conninfo import make_conninfo
 from .generators import notifies
+from .transaction import Transaction
 
 logger = logging.getLogger(__name__)
 package_logger = logging.getLogger("psycopg3")
@@ -97,6 +98,11 @@ class BaseConnection:
         self._notice_handlers: List[NoticeHandler] = []
         self._notify_handlers: List[NotifyHandler] = []
 
+        # stack of savepoint names managed by active Transaction() blocks
+        self._savepoints: Optional[List[bytes]] = None
+        # (None when there no active Transaction blocks; [] when there is only
+        # one Transaction block, with a top-level transaction and no savepoint)
+
         wself = ref(self)
 
         pgconn.notice_handler = partial(BaseConnection._notice_handler, wself)
@@ -121,10 +127,18 @@ class BaseConnection:
         # subclasses must call it holding a lock
         status = self.pgconn.transaction_status
         if status != TransactionStatus.IDLE:
-            raise e.ProgrammingError(
-                "couldn't change autocommit state: connection in"
-                f" transaction status {TransactionStatus(status).name}"
-            )
+            if self._savepoints is not None:
+                raise e.ProgrammingError(
+                    "couldn't change autocommit state: "
+                    "connection.transaction() context in progress"
+                )
+            else:
+                raise e.ProgrammingError(
+                    "couldn't change autocommit state: "
+                    "connection in transaction status "
+                    f"{TransactionStatus(status).name}"
+                )
+
         self._autocommit = value
 
     @property
@@ -268,30 +282,37 @@ class Connection(BaseConnection):
         if self.pgconn.transaction_status != TransactionStatus.IDLE:
             return
 
-        self.pgconn.send_query(b"begin")
-        (pgres,) = self.wait(execute(self.pgconn))
-        if pgres.status != ExecStatus.COMMAND_OK:
-            raise e.OperationalError(
-                "error on begin:"
-                f" {pq.error_message(pgres, encoding=self.client_encoding)}"
-            )
+        self._exec_command(b"begin")
 
     def commit(self) -> None:
         """Commit any pending transaction to the database."""
         with self.lock:
-            self._exec_commit_rollback(b"commit")
+            if self._savepoints is not None:
+                raise e.ProgrammingError(
+                    "Explicit commit() forbidden within a Transaction "
+                    "context. (Transaction will be automatically committed "
+                    "on successful exit from context.)"
+                )
+            if self.pgconn.transaction_status == TransactionStatus.IDLE:
+                return
+            self._exec_command(b"commit")
 
     def rollback(self) -> None:
         """Roll back to the start of any pending transaction."""
         with self.lock:
-            self._exec_commit_rollback(b"rollback")
+            if self._savepoints is not None:
+                raise e.ProgrammingError(
+                    "Explicit rollback() forbidden within a Transaction "
+                    "context. (Either raise Transaction.Rollback() or allow "
+                    "an exception to propagate out of the context.)"
+                )
+            if self.pgconn.transaction_status == TransactionStatus.IDLE:
+                return
+            self._exec_command(b"rollback")
 
-    def _exec_commit_rollback(self, command: bytes) -> None:
+    def _exec_command(self, command: bytes) -> None:
         # Caller must hold self.lock
-        status = self.pgconn.transaction_status
-        if status == TransactionStatus.IDLE:
-            return
-
+        logger.debug(f"{self}: {command!r}")
         self.pgconn.send_query(command)
         results = self.wait(execute(self.pgconn))
         if results[-1].status != ExecStatus.COMMAND_OK:
@@ -300,6 +321,13 @@ class Connection(BaseConnection):
                 f" {pq.error_message(results[-1], encoding=self.client_encoding)}"
             )
 
+    def transaction(
+        self,
+        savepoint_name: Optional[str] = None,
+        force_rollback: bool = False,
+    ) -> Transaction:
+        return Transaction(self, savepoint_name, force_rollback)
+
     @classmethod
     def wait(cls, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
         return wait(gen, timeout=timeout)
@@ -407,18 +435,18 @@ class AsyncConnection(BaseConnection):
 
     async def commit(self) -> None:
         async with self.lock:
-            await self._exec_commit_rollback(b"commit")
+            if self.pgconn.transaction_status == TransactionStatus.IDLE:
+                return
+            await self._exec(b"commit")
 
     async def rollback(self) -> None:
         async with self.lock:
-            await self._exec_commit_rollback(b"rollback")
+            if self.pgconn.transaction_status == TransactionStatus.IDLE:
+                return
+            await self._exec(b"rollback")
 
-    async def _exec_commit_rollback(self, command: bytes) -> None:
+    async def _exec(self, command: bytes) -> None:
         # Caller must hold self.lock
-        status = self.pgconn.transaction_status
-        if status == TransactionStatus.IDLE:
-            return
-
         self.pgconn.send_query(command)
         (pgres,) = await self.wait(execute(self.pgconn))
         if pgres.status != ExecStatus.COMMAND_OK:
diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py
new file mode 100644 (file)
index 0000000..ee9921f
--- /dev/null
@@ -0,0 +1,153 @@
+"""
+Transaction context managers returned by Connection.transaction()
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import logging
+
+from psycopg3.errors import ProgrammingError
+from types import TracebackType
+from typing import Optional, Type, TYPE_CHECKING
+
+from .pq import TransactionStatus
+
+if TYPE_CHECKING:
+    from .connection import Connection
+
+_log = logging.getLogger(__name__)
+
+
+class Rollback(Exception):
+    """
+    Exit the current Transaction context immediately and rollback any changes
+    made within this context.
+
+    If a transaction context is specified in the constructor, rollback
+    enclosing transactions contexts up to and including the one specified.
+    """
+
+    def __init__(self, transaction: Optional["Transaction"] = None):
+        self.transaction = transaction
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__qualname__}({self.transaction!r})"
+
+
+class Transaction:
+    def __init__(
+        self,
+        connection: "Connection",
+        savepoint_name: Optional[str],
+        force_rollback: bool,
+    ):
+        self._conn = connection
+        self._savepoint_name: Optional[bytes] = None
+        if savepoint_name is not None:
+            if len(savepoint_name) == 0:
+                raise ValueError("savepoint_name must be a non-empty string")
+            self._savepoint_name = connection.codec.encode(savepoint_name)[0]
+        self.force_rollback = force_rollback
+
+        self._outer_transaction: Optional[bool] = None
+
+    @property
+    def connection(self) -> "Connection":
+        return self._conn
+
+    @property
+    def savepoint_name(self) -> Optional[str]:
+        if self._savepoint_name is None:
+            return None
+        return self._conn.codec.decode(self._savepoint_name)[0]
+
+    def __enter__(self) -> "Transaction":
+        with self._conn.lock:
+            if self._conn.pgconn.transaction_status == TransactionStatus.IDLE:
+                assert self._conn._savepoints is None, self._conn._savepoints
+                self._conn._savepoints = []
+                self._outer_transaction = True
+                self._conn._exec_command(b"begin")
+            else:
+                if self._conn._savepoints is None:
+                    self._conn._savepoints = []
+                self._outer_transaction = False
+                if self._savepoint_name is None:
+                    self._savepoint_name = b"s%i" % (
+                        len(self._conn._savepoints) + 1
+                    )
+
+            if self._savepoint_name is not None:
+                self._conn._exec_command(b"savepoint " + self._savepoint_name)
+                self._conn._savepoints.append(self._savepoint_name)
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> bool:
+        out_of_order_err = ProgrammingError(
+            "Out-of-order Transaction context exits. Are you "
+            "calling __exit__() manually and getting it wrong?"
+        )
+        if self._outer_transaction is None:
+            raise out_of_order_err
+        with self._conn.lock:
+            if exc_type is None and not self.force_rollback:
+                # Commit changes made in the transaction context
+                if self._savepoint_name:
+                    if self._conn._savepoints is None:
+                        raise out_of_order_err
+                    actual = self._conn._savepoints.pop()
+                    if actual != self._savepoint_name:
+                        raise out_of_order_err
+                    self._conn._exec_command(
+                        b"release savepoint " + self._savepoint_name
+                    )
+                if self._outer_transaction:
+                    if self._conn._savepoints is None:
+                        raise out_of_order_err
+                    if len(self._conn._savepoints) != 0:
+                        raise out_of_order_err
+                    self._conn._exec_command(b"commit")
+                    self._conn._savepoints = None
+            else:
+                # Rollback changes made in the transaction context
+                if isinstance(exc_val, Rollback):
+                    _log.debug(
+                        f"{self._conn}: Explicit rollback from: ",
+                        exc_info=True,
+                    )
+
+                if self._savepoint_name:
+                    if self._conn._savepoints is None:
+                        raise out_of_order_err
+                    actual = self._conn._savepoints.pop()
+                    if actual != self._savepoint_name:
+                        raise out_of_order_err
+                    self._conn._exec_command(
+                        b"rollback to savepoint " + self._savepoint_name + b";"
+                        b"release savepoint " + self._savepoint_name
+                    )
+                if self._outer_transaction:
+                    if self._conn._savepoints is None:
+                        raise out_of_order_err
+                    if len(self._conn._savepoints) != 0:
+                        raise out_of_order_err
+                    self._conn._exec_command(b"rollback")
+                    self._conn._savepoints = None
+
+                if isinstance(exc_val, Rollback):
+                    if exc_val.transaction in (self, None):
+                        return True  # Swallow the exception
+        return False
+
+    def __repr__(self) -> str:
+        args = [f"connection={self.connection}"]
+        if self.savepoint_name is not None:
+            args.append(f"savepoint_name={self.savepoint_name!r}")
+        if self.force_rollback:
+            args.append("force_rollback=True")
+        return f"{self.__class__.__qualname__}({', '.join(args)})"
diff --git a/tests/test_transaction.py b/tests/test_transaction.py
new file mode 100644 (file)
index 0000000..2afa831
--- /dev/null
@@ -0,0 +1,618 @@
+import sys
+from contextlib import contextmanager
+
+import pytest
+
+from psycopg3 import OperationalError, ProgrammingError, Rollback
+
+
+@pytest.fixture(autouse=True)
+def test_table(svcconn):
+    """
+    Creates a table called 'test_table' for use in tests.
+    """
+    cur = svcconn.cursor()
+    cur.execute("drop table if exists test_table")
+    cur.execute("create table test_table (id text primary key)")
+    yield
+    cur.execute("drop table test_table")
+
+
+def insert_row(conn, value):
+    conn.cursor().execute("INSERT INTO test_table VALUES (%s)", (value,))
+
+
+def assert_rows(conn, expected):
+    rows = conn.cursor().execute("SELECT * FROM test_table").fetchall()
+    assert set(v for (v,) in rows) == expected
+
+
+def assert_not_in_transaction(conn):
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+
+
+def assert_in_transaction(conn):
+    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+
+
+@contextmanager
+def assert_commands_issued(conn, *commands):
+    commands_actual = []
+    real_exec_command = conn._exec_command
+
+    def _exec_command(command):
+        commands_actual.append(command)
+        real_exec_command(command)
+
+    try:
+        conn._exec_command = _exec_command
+        yield
+    finally:
+        conn._exec_command = real_exec_command
+    commands_expected = [cmd.encode("ascii") for cmd in commands]
+    assert commands_actual == commands_expected
+
+
+class ExpectedException(Exception):
+    pass
+
+
+def some_exc_info():
+    try:
+        raise ExpectedException()
+    except ExpectedException:
+        return sys.exc_info()
+
+
+def test_basic(conn):
+    """Basic use of transaction() to BEGIN and COMMIT a transaction."""
+    assert_not_in_transaction(conn)
+    with conn.transaction():
+        assert_in_transaction(conn)
+    assert_not_in_transaction(conn)
+
+
+def test_exposes_associated_connection(conn):
+    """Transaction exposes its connection as a read-only property."""
+    with conn.transaction() as tx:
+        assert tx.connection is conn
+        with pytest.raises(AttributeError):
+            tx.connection = conn
+
+
+def test_exposes_savepoint_name(conn):
+    """Transaction exposes its savepoint name as a read-only property."""
+    with conn.transaction(savepoint_name="foo") as tx:
+        assert tx.savepoint_name == "foo"
+        with pytest.raises(AttributeError):
+            tx.savepoint_name = "bar"
+
+
+def test_begins_on_enter(conn):
+    """Transaction does not begin until __enter__() is called."""
+    tx = conn.transaction()
+    assert_not_in_transaction(conn)
+    with tx:
+        assert_in_transaction(conn)
+    assert_not_in_transaction(conn)
+
+
+def test_commit_on_successful_exit(conn):
+    """Changes are committed on successful exit from the `with` block."""
+    with conn.transaction():
+        insert_row(conn, "foo")
+
+    assert_not_in_transaction(conn)
+    assert_rows(conn, {"foo"})
+
+
+def test_rollback_on_exception_exit(conn):
+    """Changes are rolled back if an exception escapes the `with` block."""
+    with pytest.raises(ExpectedException):
+        with conn.transaction():
+            insert_row(conn, "foo")
+            raise ExpectedException("This discards the insert")
+
+    assert_not_in_transaction(conn)
+    assert_rows(conn, set())
+
+
+def test_prohibits_use_of_commit_rollback_autocommit(conn):
+    """
+    Within a Transaction block, it is forbidden to touch commit, rollback,
+    or the autocommit setting on the connection, as this would interfere
+    with the transaction scope being managed by the Transaction block.
+    """
+    conn.autocommit = False
+    conn.commit()
+    conn.rollback()
+
+    with conn.transaction():
+        with pytest.raises(ProgrammingError):
+            conn.autocommit = False
+        with pytest.raises(ProgrammingError):
+            conn.commit()
+        with pytest.raises(ProgrammingError):
+            conn.rollback()
+
+    conn.autocommit = False
+    conn.commit()
+    conn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+def test_preserves_autocommit(conn, autocommit):
+    """
+    Connection.autocommit is unchanged both during and after Transaction block.
+    """
+    conn.autocommit = autocommit
+    with conn.transaction():
+        assert conn.autocommit is autocommit
+    assert conn.autocommit is autocommit
+
+
+def test_autocommit_off_but_no_tx_started_successful_exit(conn, svcconn):
+    """
+    Scenario:
+    * Connection has autocommit off but no transaction has been initiated
+      before entering the Transaction context
+    * Code exits Transaction context successfully
+
+    Outcome:
+    * Changes made within Transaction context are committed
+    """
+    conn.autocommit = False
+    assert_not_in_transaction(conn)
+    with conn.transaction():
+        insert_row(conn, "new")
+    assert_not_in_transaction(conn)
+
+    # Changes committed
+    assert_rows(conn, {"new"})
+    assert_rows(svcconn, {"new"})
+
+
+def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn):
+    """
+    Scenario:
+    * Connection has autocommit off but no transaction has been initiated
+      before entering the Transaction context
+    * Code exits Transaction context with an exception
+
+    Outcome:
+    * Changes made within Transaction context are discarded
+    """
+    conn.autocommit = False
+    assert_not_in_transaction(conn)
+    with pytest.raises(ExpectedException):
+        with conn.transaction():
+            insert_row(conn, "new")
+            raise ExpectedException()
+    assert_not_in_transaction(conn)
+
+    # Changes discarded
+    assert_rows(conn, set())
+    assert_rows(svcconn, set())
+
+
+def test_autocommit_off_and_tx_in_progress_successful_exit(conn, svcconn):
+    """
+    Scenario:
+    * Connection has autocommit off but and a transaction is already in
+      progress before entering the Transaction context
+    * Code exits Transaction context successfully
+
+    Outcome:
+    * Changes made within Transaction context are left intact
+    * Outer transaction is left running, and no changes are visible to an
+      outside observer from another connection.
+    """
+    conn.autocommit = False
+    insert_row(conn, "prior")
+    assert_in_transaction(conn)
+    with conn.transaction():
+        insert_row(conn, "new")
+    assert_in_transaction(conn)
+    assert_rows(conn, {"prior", "new"})
+    # Nothing committed yet; changes not visible on another connection
+    assert_rows(svcconn, set())
+
+
+def test_autocommit_off_and_tx_in_progress_exception_exit(conn, svcconn):
+    """
+    Scenario:
+    * Connection has autocommit off but and a transaction is already in
+      progress before entering the Transaction context
+    * Code exits Transaction context with an exception
+
+    Outcome:
+    * Changes made before the Transaction context are left intact
+    * Changes made within Transaction context are discarded
+    * Outer transaction is left running, and no changes are visible to an
+      outside observer from another connection.
+    """
+    conn.autocommit = False
+    insert_row(conn, "prior")
+    assert_in_transaction(conn)
+    with pytest.raises(ExpectedException):
+        with conn.transaction():
+            insert_row(conn, "new")
+            raise ExpectedException()
+    assert_in_transaction(conn)
+    assert_rows(conn, {"prior"})
+    # Nothing committed yet; changes not visible on another connection
+    assert_rows(svcconn, set())
+
+
+def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn):
+    """Changes from nested transaction contexts are all persisted on exit."""
+    with conn.transaction():
+        insert_row(conn, "outer-before")
+        with conn.transaction():
+            insert_row(conn, "inner")
+        insert_row(conn, "outer-after")
+    assert_not_in_transaction(conn)
+    assert_rows(conn, {"outer-before", "inner", "outer-after"})
+    assert_rows(svcconn, {"outer-before", "inner", "outer-after"})
+
+
+def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn):
+    """
+    Changes from nested transaction contexts are discarded when an exception
+    raised in outer context escapes.
+    """
+    with pytest.raises(ExpectedException):
+        with conn.transaction():
+            insert_row(conn, "outer")
+            with conn.transaction():
+                insert_row(conn, "inner")
+            raise ExpectedException()
+    assert_not_in_transaction(conn)
+    assert_rows(conn, set())
+    assert_rows(svcconn, set())
+
+
+def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn):
+    """
+    Changes from nested transaction contexts are discarded when an exception
+    raised in inner context escapes the outer context.
+    """
+    with pytest.raises(ExpectedException):
+        with conn.transaction():
+            insert_row(conn, "outer")
+            with conn.transaction():
+                insert_row(conn, "inner")
+                raise ExpectedException()
+    assert_not_in_transaction(conn)
+    assert_rows(conn, set())
+    assert_rows(svcconn, set())
+
+
+def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn):
+    """
+    An exception escaping the inner transaction context causes changes made
+    within that inner context to be discarded, but the error can then be
+    handled in the outer context, allowing changes made in the outer context
+    (both before, and after, the inner context) to be successfully committed.
+    """
+    with conn.transaction():
+        insert_row(conn, "outer-before")
+        with pytest.raises(ExpectedException):
+            with conn.transaction():
+                insert_row(conn, "inner")
+                raise ExpectedException()
+        insert_row(conn, "outer-after")
+    assert_not_in_transaction(conn)
+    assert_rows(conn, {"outer-before", "outer-after"})
+    assert_rows(svcconn, {"outer-before", "outer-after"})
+
+
+def test_nested_three_levels_successful_exit(conn, svcconn):
+    """Exercise management of more than one savepoint."""
+    with conn.transaction():  # BEGIN
+        insert_row(conn, "one")
+        with conn.transaction():  # SAVEPOINT s1
+            insert_row(conn, "two")
+            with conn.transaction():  # SAVEPOINT s2
+                insert_row(conn, "three")
+    assert_not_in_transaction(conn)
+    assert_rows(conn, {"one", "two", "three"})
+    assert_rows(svcconn, {"one", "two", "three"})
+
+
+def test_named_savepoint_empty_string_invalid(conn):
+    """
+    Raise validate savepoint_name up-front (rather than later constructing an
+    invalid SQL command and having that fail with an OperationalError).
+    """
+    with pytest.raises(ValueError):
+        conn.transaction(savepoint_name="")
+
+
+@pytest.mark.xfail(raises=OperationalError, reason="TODO: Escape sp names")
+def test_named_savepoint_escapes_savepoint_name(conn):
+    with conn.transaction("s-1"):
+        pass
+    with conn.transaction("s1; drop table students"):
+        pass
+
+
+def test_named_savepoints_successful_exit(conn):
+    """
+    Entering a transaction context will do one of these these things:
+    1. Begin an outer transaction (if one isn't already in progress)
+    2. Begin an outer transaction and create a savepoint (if one is named)
+    3. Create a savepoint (if a transaction is already in progress)
+       either using the name provided, or auto-generating a savepoint name.
+
+    ...and exiting the context successfully will "commit" the same.
+    """
+    # Case 1
+    tx = conn.transaction()
+    with assert_commands_issued(conn, "begin"):
+        tx.__enter__()
+    assert tx.savepoint_name is None
+    with assert_commands_issued(conn, "commit"):
+        tx.__exit__(None, None, None)
+
+    # Case 2
+    tx = conn.transaction(savepoint_name="foo")
+    with assert_commands_issued(conn, "begin", "savepoint foo"):
+        tx.__enter__()
+    assert tx.savepoint_name == "foo"
+    with assert_commands_issued(conn, "release savepoint foo", "commit"):
+        tx.__exit__(None, None, None)
+
+    # Case 3 (with savepoint name provided)
+    with conn.transaction():
+        tx = conn.transaction(savepoint_name="bar")
+        with assert_commands_issued(conn, "savepoint bar"):
+            tx.__enter__()
+        assert tx.savepoint_name == "bar"
+        with assert_commands_issued(conn, "release savepoint bar"):
+            tx.__exit__(None, None, None)
+
+    # Case 3 (with savepoint name auto-generated)
+    with conn.transaction():
+        tx = conn.transaction()
+        with assert_commands_issued(conn, "savepoint s1"):
+            tx.__enter__()
+        assert tx.savepoint_name == "s1"
+        with assert_commands_issued(conn, "release savepoint s1"):
+            tx.__exit__(None, None, None)
+
+
+def test_named_savepoints_exception_exit(conn):
+    """
+    Same as the previous test but checks that when exiting the context with an
+    exception, whatever transaction and/or savepoint was started on enter will
+    be rolled-back as appropriate.
+    """
+    # Case 1
+    tx = conn.transaction()
+    with assert_commands_issued(conn, "begin"):
+        tx.__enter__()
+    assert tx.savepoint_name is None
+    with assert_commands_issued(conn, "rollback"):
+        tx.__exit__(*some_exc_info())
+
+    # Case 2
+    tx = conn.transaction(savepoint_name="foo")
+    with assert_commands_issued(conn, "begin", "savepoint foo"):
+        tx.__enter__()
+    assert tx.savepoint_name == "foo"
+    with assert_commands_issued(
+        conn, "rollback to savepoint foo;release savepoint foo", "rollback"
+    ):
+        tx.__exit__(*some_exc_info())
+
+    # Case 3 (with savepoint name provided)
+    with conn.transaction():
+        tx = conn.transaction(savepoint_name="bar")
+        with assert_commands_issued(conn, "savepoint bar"):
+            tx.__enter__()
+        assert tx.savepoint_name == "bar"
+        with assert_commands_issued(
+            conn, "rollback to savepoint bar;release savepoint bar"
+        ):
+            tx.__exit__(*some_exc_info())
+
+    # Case 3 (with savepoint name auto-generated)
+    with conn.transaction():
+        tx = conn.transaction()
+        with assert_commands_issued(conn, "savepoint s1"):
+            tx.__enter__()
+        assert tx.savepoint_name == "s1"
+        with assert_commands_issued(
+            conn, "rollback to savepoint s1;release savepoint s1"
+        ):
+            tx.__exit__(*some_exc_info())
+
+
+def test_named_savepoints_with_repeated_names_works(conn):
+    """
+    Using the same savepoint name repeatedly works correctly, but bypasses
+    some sanity checks.
+    """
+    # Works correctly if no inner transactions are rolled back
+    with conn.transaction(force_rollback=True):
+        with conn.transaction("sp"):
+            insert_row(conn, "tx1")
+            with conn.transaction("sp"):
+                insert_row(conn, "tx2")
+                with conn.transaction("sp"):
+                    insert_row(conn, "tx3")
+        assert_rows(conn, {"tx1", "tx2", "tx3"})
+
+    # Works correctly if one level of inner transaction is rolled back
+    with conn.transaction(force_rollback=True):
+        with conn.transaction("s1"):
+            insert_row(conn, "tx1")
+            with conn.transaction("s1", force_rollback=True):
+                insert_row(conn, "tx2")
+                with conn.transaction("s1"):
+                    insert_row(conn, "tx3")
+            assert_rows(conn, {"tx1"})
+        assert_rows(conn, {"tx1"})
+
+    # Works correctly if multiple inner transactions are rolled back
+    # (This scenario mandates releasing savepoints after rolling back to them.)
+    with conn.transaction(force_rollback=True):
+        with conn.transaction("s1"):
+            insert_row(conn, "tx1")
+            with conn.transaction("s1") as tx2:
+                insert_row(conn, "tx2")
+                with conn.transaction("s1"):
+                    insert_row(conn, "tx3")
+                    raise Rollback(tx2)
+            assert_rows(conn, {"tx1"})
+        assert_rows(conn, {"tx1"})
+
+    # Will not (always) catch out-of-order exits
+    with conn.transaction(force_rollback=True):
+        tx1 = conn.transaction("s1")
+        tx2 = conn.transaction("s1")
+        tx1.__enter__()
+        tx2.__enter__()
+        tx1.__exit__(None, None, None)
+        tx2.__exit__(None, None, None)
+
+
+def test_force_rollback_successful_exit(conn, svcconn):
+    """
+    Transaction started with the force_rollback option enabled discards all
+    changes at the end of the context.
+    """
+    with conn.transaction(force_rollback=True):
+        insert_row(conn, "foo")
+    assert_rows(conn, set())
+    assert_rows(svcconn, set())
+
+
+def test_force_rollback_exception_exit(conn, svcconn):
+    """
+    Transaction started with the force_rollback option enabled discards all
+    changes at the end of the context.
+    """
+    with pytest.raises(ExpectedException):
+        with conn.transaction(force_rollback=True):
+            insert_row(conn, "foo")
+            raise ExpectedException()
+    assert_rows(conn, set())
+    assert_rows(svcconn, set())
+
+
+def test_explicit_rollback_discards_changes(conn, svcconn):
+    """
+    Raising a Rollback exception in the middle of a block exits the block and
+    discards all changes made within that block.
+
+    You can raise any of the following:
+     - Rollback (type)
+     - Rollback() (instance)
+     - Rollback(tx) (instance initialised with reference to the transaction)
+    All of these are equivalent.
+    """
+    tx = conn.transaction()
+    for to_raise in (
+        Rollback,
+        Rollback(),
+        Rollback(tx),
+    ):
+        with tx:
+            insert_row(conn, "foo")
+            raise to_raise
+        assert_rows(conn, set(""))
+        assert_rows(svcconn, set())
+
+
+def test_explicit_rollback_outer_tx_unaffected(conn, svcconn):
+    """
+    Raising a Rollback exception in the middle of a block does not impact an
+    enclosing transaction block.
+    """
+    with conn.transaction():
+        insert_row(conn, "before")
+        with conn.transaction():
+            insert_row(conn, "during")
+            raise Rollback
+        assert_in_transaction(conn)
+        assert_rows(svcconn, set())
+        insert_row(conn, "after")
+    assert_rows(conn, {"before", "after"})
+    assert_rows(svcconn, {"before", "after"})
+
+
+def test_explicit_rollback_of_outer_transaction(conn):
+    """
+    Raising a Rollback exception that references an outer transaction will
+    discard all changes from both inner and outer transaction blocks.
+    """
+    outer_tx = conn.transaction()
+    with outer_tx:
+        insert_row(conn, "outer")
+        with conn.transaction():
+            insert_row(conn, "inner")
+            raise Rollback(outer_tx)
+        assert False, "This line of code should be unreachable."
+    assert_rows(conn, set())
+
+
+def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn):
+    """
+    Rolling-back an enclosing transaction does not impact an outer transaction.
+    """
+    with conn.transaction():
+        insert_row(conn, "outer-before")
+        with conn.transaction() as tx_enclosing:
+            insert_row(conn, "enclosing")
+            with conn.transaction():
+                insert_row(conn, "inner")
+                raise Rollback(tx_enclosing)
+        insert_row(conn, "outer-after")
+
+        assert_rows(conn, {"outer-before", "outer-after"})
+        assert_rows(svcconn, set())  # Not yet committed
+    assert_rows(svcconn, {"outer-before", "outer-after"})  # Changes committed
+
+
+@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
+@pytest.mark.parametrize("name", [None, "s1"])
+def test_manual_enter_and_exit_out_of_order_exit_asserts(conn, name, exc_info):
+    """
+    When user is calling __enter__() and __exit__() manually for some reason,
+    provide a helpful error message if they call __exit__() in the wrong order
+    for nested transactions.
+    """
+    tx1, tx2 = conn.transaction(name), conn.transaction()
+    tx1.__enter__()
+    tx2.__enter__()
+    with pytest.raises(ProgrammingError, match="Out-of-order"):
+        tx1.__exit__(*exc_info)
+
+
+@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
+@pytest.mark.parametrize("name", [None, "s1"])
+def test_manual_exit_without_enter_asserts(conn, name, exc_info):
+    """
+    When user is calling __enter__() and __exit__() manually for some reason,
+    provide a helpful error message if they call __exit__() without first
+    having called __enter__()
+    """
+    tx = conn.transaction(name)
+    with pytest.raises(ProgrammingError, match="Out-of-order"):
+        tx.__exit__(*exc_info)
+
+
+@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
+@pytest.mark.parametrize("name", [None, "s1"])
+def test_manual_exit_twice_asserts(conn, name, exc_info):
+    """
+    When user is calling __enter__() and __exit__() manually for some reason,
+    provide a helpful error message if they accidentally call __exit__() twice.
+    """
+    tx = conn.transaction(name)
+    tx.__enter__()
+    tx.__exit__(*exc_info)
+    with pytest.raises(ProgrammingError, match="Out-of-order"):
+        tx.__exit__(*exc_info)