]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added AsyncConnection.transaction()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 15 Nov 2020 21:23:43 +0000 (21:23 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 04:02:26 +0000 (04:02 +0000)
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/transaction.py
tests/test_transaction.py
tests/test_transaction_async.py [new file with mode: 0644]

index 45d786d639c7fdf8db53b86e697674417cf3bd67..016af7f0114aab5d3e15a4d074b8a997a4ec58f6 100644 (file)
@@ -4,8 +4,9 @@ psycopg3 connection objects
 
 # Copyright (C) 2020 The Psycopg Team
 
-import logging
+import sys
 import asyncio
+import logging
 import threading
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
@@ -14,6 +15,11 @@ from weakref import ref, ReferenceType
 from functools import partial
 from contextlib import contextmanager
 
+if sys.version_info >= (3, 7):
+    from contextlib import asynccontextmanager
+else:
+    from .utils.context import asynccontextmanager
+
 from . import pq
 from . import cursor
 from . import errors as e
@@ -24,7 +30,7 @@ from .proto import DumpersMap, LoadersMap, PQGen, RV, Query
 from .waiting import wait, wait_async
 from .conninfo import make_conninfo
 from .generators import notifies
-from .transaction import Transaction
+from .transaction import Transaction, AsyncTransaction
 
 logger = logging.getLogger(__name__)
 package_logger = logging.getLogger("psycopg3")
@@ -335,7 +341,7 @@ class Connection(BaseConnection):
         savepoint_name: Optional[str] = None,
         force_rollback: bool = False,
     ) -> Iterator[Transaction]:
-        with Transaction(self, savepoint_name or "", force_rollback) as tx:
+        with Transaction(self, savepoint_name, force_rollback) as tx:
             yield tx
 
     @classmethod
@@ -435,36 +441,58 @@ class AsyncConnection(BaseConnection):
         if self.pgconn.transaction_status != TransactionStatus.IDLE:
             return
 
-        self.pgconn.send_query(b"begin")
-        (pgres,) = await self.wait(execute(self.pgconn))
-        if pgres.status != ExecStatus.COMMAND_OK:
-            raise e.OperationalError(
-                "error on begin:"
-                f" {pq.error_message(pgres, encoding=self.client_encoding)}"
-            )
+        await self._exec_command(b"begin")
 
     async def commit(self) -> None:
         async with self.lock:
             if self.pgconn.transaction_status == TransactionStatus.IDLE:
                 return
-            await self._exec(b"commit")
+            if self._savepoints is not None:
+                raise e.ProgrammingError(
+                    "Explicit commit() forbidden within a Transaction "
+                    "context. (Transaction will be automatically committed "
+                    "on successful exit from context.)"
+                )
+            await self._exec_command(b"commit")
 
     async def rollback(self) -> None:
         async with self.lock:
             if self.pgconn.transaction_status == TransactionStatus.IDLE:
                 return
-            await self._exec(b"rollback")
+            if self._savepoints is not None:
+                raise e.ProgrammingError(
+                    "Explicit rollback() forbidden within a Transaction "
+                    "context. (Either raise Transaction.Rollback() or allow "
+                    "an exception to propagate out of the context.)"
+                )
+            await self._exec_command(b"rollback")
 
-    async def _exec(self, command: bytes) -> None:
+    async def _exec_command(self, command: Query) -> None:
         # Caller must hold self.lock
+
+        if isinstance(command, str):
+            command = command.encode(self.client_encoding)
+        elif isinstance(command, Composable):
+            command = command.as_string(self).encode(self.client_encoding)
+
         self.pgconn.send_query(command)
-        (pgres,) = await self.wait(execute(self.pgconn))
-        if pgres.status != ExecStatus.COMMAND_OK:
+        results = await self.wait(execute(self.pgconn))
+        if results[-1].status != ExecStatus.COMMAND_OK:
             raise e.OperationalError(
                 f"error on {command.decode('utf8')}:"
-                f" {pq.error_message(pgres, encoding=self.client_encoding)}"
+                f" {pq.error_message(results[-1], encoding=self.client_encoding)}"
             )
 
+    @asynccontextmanager
+    async def transaction(
+        self,
+        savepoint_name: Optional[str] = None,
+        force_rollback: bool = False,
+    ) -> AsyncIterator[AsyncTransaction]:
+        tx = AsyncTransaction(self, savepoint_name, force_rollback)
+        async with tx:
+            yield tx
+
     @classmethod
     async def wait(cls, gen: PQGen[RV]) -> RV:
         return await wait_async(gen)
index 0f0bbfaa7472db1968a5aad713781e5c1f87b7fa..22364d7a0a8bd498c1f7e28bd4d1c22b94b094b0 100644 (file)
@@ -7,14 +7,15 @@ Transaction context managers returned by Connection.transaction()
 import logging
 
 from types import TracebackType
-from typing import Optional, Type, TYPE_CHECKING
+from typing import Generic, Optional, Type, Union, TYPE_CHECKING
 
 from . import sql
 from .pq import TransactionStatus
-from psycopg3.errors import ProgrammingError
+from .proto import ConnectionType
+from .errors import ProgrammingError
 
 if TYPE_CHECKING:
-    from .connection import Connection
+    from .connection import Connection, AsyncConnection  # noqa: F401
 
 _log = logging.getLogger(__name__)
 
@@ -28,34 +29,59 @@ class Rollback(Exception):
     enclosing transactions contexts up to and including the one specified.
     """
 
-    def __init__(self, transaction: Optional["Transaction"] = None):
+    def __init__(
+        self,
+        transaction: Union["Transaction", "AsyncTransaction", None] = None,
+    ):
         self.transaction = transaction
 
     def __repr__(self) -> str:
         return f"{self.__class__.__qualname__}({self.transaction!r})"
 
 
-class Transaction:
+class BaseTransaction(Generic[ConnectionType]):
     def __init__(
         self,
-        connection: "Connection",
-        savepoint_name: str = "",
+        connection: ConnectionType,
+        savepoint_name: Optional[str] = None,
         force_rollback: bool = False,
     ):
         self._conn = connection
-        self._savepoint_name = savepoint_name
+        self._savepoint_name = savepoint_name or ""
         self.force_rollback = force_rollback
 
         self._outer_transaction: Optional[bool] = None
 
     @property
-    def connection(self) -> "Connection":
+    def connection(self) -> ConnectionType:
         return self._conn
 
     @property
     def savepoint_name(self) -> str:
         return self._savepoint_name
 
+    def __repr__(self) -> str:
+        args = [f"connection={self.connection}"]
+        if not self.savepoint_name:
+            args.append(f"savepoint_name={self.savepoint_name!r}")
+        if self.force_rollback:
+            args.append("force_rollback=True")
+        return f"{self.__class__.__qualname__}({', '.join(args)})"
+
+    _out_of_order_err = ProgrammingError(
+        "Out-of-order Transaction context exits. Are you "
+        "calling __exit__() manually and getting it wrong?"
+    )
+
+    def _pop_savepoint(self) -> None:
+        if self._conn._savepoints is None:
+            raise self._out_of_order_err
+        actual = self._conn._savepoints.pop()
+        if actual != self._savepoint_name:
+            raise self._out_of_order_err
+
+
+class Transaction(BaseTransaction["Connection"]):
     def __enter__(self) -> "Transaction":
         with self._conn.lock:
             if self._conn.pgconn.transaction_status == TransactionStatus.IDLE:
@@ -87,69 +113,137 @@ class Transaction:
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> bool:
-        out_of_order_err = ProgrammingError(
-            "Out-of-order Transaction context exits. Are you "
-            "calling __exit__() manually and getting it wrong?"
-        )
         if self._outer_transaction is None:
-            raise out_of_order_err
+            raise self._out_of_order_err
         with self._conn.lock:
-            if exc_type is None and not self.force_rollback:
-                # Commit changes made in the transaction context
-                if self._savepoint_name:
-                    if self._conn._savepoints is None:
-                        raise out_of_order_err
-                    actual = self._conn._savepoints.pop()
-                    if actual != self._savepoint_name:
-                        raise out_of_order_err
-                    self._conn._exec_command(
-                        sql.SQL("release savepoint {}").format(
-                            sql.Identifier(self._savepoint_name)
-                        )
-                    )
-                if self._outer_transaction:
-                    if self._conn._savepoints is None:
-                        raise out_of_order_err
-                    if len(self._conn._savepoints) != 0:
-                        raise out_of_order_err
-                    self._conn._exec_command(b"commit")
-                    self._conn._savepoints = None
+            if not exc_val and not self.force_rollback:
+                return self._commit()
+            else:
+                return self._rollback(exc_val)
+
+    def _commit(self) -> bool:
+        """Commit changes made in the transaction context."""
+        if self._savepoint_name:
+            self._pop_savepoint()
+            self._conn._exec_command(
+                sql.SQL("release savepoint {}").format(
+                    sql.Identifier(self._savepoint_name)
+                )
+            )
+        if self._outer_transaction:
+            if self._conn._savepoints is None or self._conn._savepoints:
+                raise self._out_of_order_err
+            self._conn._exec_command(b"commit")
+            self._conn._savepoints = None
+
+        return False  # discarded
+
+    def _rollback(self, exc_val: Optional[BaseException]) -> bool:
+        # Rollback changes made in the transaction context
+        if isinstance(exc_val, Rollback):
+            _log.debug(
+                f"{self._conn}: Explicit rollback from: ", exc_info=True
+            )
+
+        if self._savepoint_name:
+            self._pop_savepoint()
+            self._conn._exec_command(
+                sql.SQL(
+                    "rollback to savepoint {n}; release savepoint {n}"
+                ).format(n=sql.Identifier(self._savepoint_name))
+            )
+        if self._outer_transaction:
+            if self._conn._savepoints is None or self._conn._savepoints:
+                raise self._out_of_order_err
+            self._conn._exec_command(b"rollback")
+            self._conn._savepoints = None
+
+        if isinstance(exc_val, Rollback):
+            if exc_val.transaction in (self, None):
+                return True  # Swallow the exception
+
+        return False
+
+
+class AsyncTransaction(BaseTransaction["AsyncConnection"]):
+    async def __aenter__(self) -> "AsyncTransaction":
+        async with self._conn.lock:
+            if self._conn.pgconn.transaction_status == TransactionStatus.IDLE:
+                assert self._conn._savepoints is None, self._conn._savepoints
+                self._conn._savepoints = []
+                self._outer_transaction = True
+                await self._conn._exec_command(b"begin")
             else:
-                # Rollback changes made in the transaction context
-                if isinstance(exc_val, Rollback):
-                    _log.debug(
-                        f"{self._conn}: Explicit rollback from: ",
-                        exc_info=True,
+                if self._conn._savepoints is None:
+                    self._conn._savepoints = []
+                self._outer_transaction = False
+                if not self._savepoint_name:
+                    self._savepoint_name = (
+                        f"s{len(self._conn._savepoints) + 1}"
                     )
 
-                if self._savepoint_name:
-                    if self._conn._savepoints is None:
-                        raise out_of_order_err
-                    actual = self._conn._savepoints.pop()
-                    if actual != self._savepoint_name:
-                        raise out_of_order_err
-                    self._conn._exec_command(
-                        sql.SQL(
-                            "rollback to savepoint {n}; release savepoint {n}"
-                        ).format(n=sql.Identifier(self._savepoint_name))
+            if self._savepoint_name:
+                await self._conn._exec_command(
+                    sql.SQL("savepoint {}").format(
+                        sql.Identifier(self._savepoint_name)
                     )
-                if self._outer_transaction:
-                    if self._conn._savepoints is None:
-                        raise out_of_order_err
-                    if len(self._conn._savepoints) != 0:
-                        raise out_of_order_err
-                    self._conn._exec_command(b"rollback")
-                    self._conn._savepoints = None
-
-                if isinstance(exc_val, Rollback):
-                    if exc_val.transaction in (self, None):
-                        return True  # Swallow the exception
-        return False
+                )
+                self._conn._savepoints.append(self._savepoint_name)
+        return self
 
-    def __repr__(self) -> str:
-        args = [f"connection={self.connection}"]
-        if not self.savepoint_name:
-            args.append(f"savepoint_name={self.savepoint_name!r}")
-        if self.force_rollback:
-            args.append("force_rollback=True")
-        return f"{self.__class__.__qualname__}({', '.join(args)})"
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> bool:
+        if self._outer_transaction is None:
+            raise self._out_of_order_err
+        async with self._conn.lock:
+            if not exc_val and not self.force_rollback:
+                return await self._commit()
+            else:
+                return await self._rollback(exc_val)
+
+    async def _commit(self) -> bool:
+        """Commit changes made in the transaction context."""
+        if self._savepoint_name:
+            self._pop_savepoint()
+            await self._conn._exec_command(
+                sql.SQL("release savepoint {}").format(
+                    sql.Identifier(self._savepoint_name)
+                )
+            )
+        if self._outer_transaction:
+            if self._conn._savepoints is None or self._conn._savepoints:
+                raise self._out_of_order_err
+            await self._conn._exec_command(b"commit")
+            self._conn._savepoints = None
+
+        return False  # discarded
+
+    async def _rollback(self, exc_val: Optional[BaseException]) -> bool:
+        # Rollback changes made in the transaction context
+        if isinstance(exc_val, Rollback):
+            _log.debug(
+                f"{self._conn}: Explicit rollback from: ", exc_info=True
+            )
+
+        if self._savepoint_name:
+            self._pop_savepoint()
+            await self._conn._exec_command(
+                sql.SQL(
+                    "rollback to savepoint {n}; release savepoint {n}"
+                ).format(n=sql.Identifier(self._savepoint_name))
+            )
+        if self._outer_transaction:
+            if self._conn._savepoints is None or self._conn._savepoints:
+                raise self._out_of_order_err
+            await self._conn._exec_command(b"rollback")
+            self._conn._savepoints = None
+
+        if isinstance(exc_val, Rollback):
+            if exc_val.transaction in (self, None):
+                return True  # Swallow the exception
+
+        return False
index 026a7fa6e7ebdc0f8419bb08ac3795bb0382d817..04e6f81436bc62dab5b5478647c1cc28a59d6c9b 100644 (file)
@@ -1,15 +1,14 @@
 import sys
-from contextlib import contextmanager
 
 import pytest
 
-from psycopg3 import ProgrammingError, Rollback
+from psycopg3 import Connection, ProgrammingError, Rollback
 from psycopg3.sql import Composable
 from psycopg3.transaction import Transaction
 
 
 @pytest.fixture(autouse=True)
-def test_table(svcconn):
+def create_test_table(svcconn):
     """
     Creates a table called 'test_table' for use in tests.
     """
@@ -24,40 +23,52 @@ def insert_row(conn, value):
     conn.cursor().execute("INSERT INTO test_table VALUES (%s)", (value,))
 
 
-def assert_rows(conn, expected):
-    rows = conn.cursor().execute("SELECT * FROM test_table").fetchall()
-    assert set(v for (v,) in rows) == expected
+def inserted(conn):
+    sql = "SELECT * FROM test_table"
+    if isinstance(conn, Connection):
+        rows = conn.cursor().execute(sql).fetchall()
+        return set(v for (v,) in rows)
+    else:
 
+        async def f():
+            cur = await conn.cursor()
+            await cur.execute(sql)
+            rows = await cur.fetchall()
+            return set(v for (v,) in rows)
 
-def assert_not_in_transaction(conn):
-    assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+        return f()
 
 
-def assert_in_transaction(conn):
-    assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+@pytest.fixture
+def commands(monkeypatch):
+    """The queue of commands issued internally by connections.
 
+    Not concurrency safe as it mokkeypatches a class method, but good enough
+    for tests.
+    """
+    _orig_exec_command = Connection._exec_command
+    L = []
 
-@contextmanager
-def assert_commands_issued(conn, *commands):
-    commands_actual = []
-    real_exec_command = conn._exec_command
-
-    def _exec_command(command):
+    def _exec_command(self, command):
         if isinstance(command, bytes):
-            command = command.decode(conn.client_encoding)
+            command = command.decode(self.client_encoding)
         elif isinstance(command, Composable):
-            command = command.as_string(conn)
+            command = command.as_string(self)
 
-        commands_actual.append(command)
-        real_exec_command(command)
+        L.insert(0, command)
+        _orig_exec_command(self, command)
+
+    monkeypatch.setattr(Connection, "_exec_command", _exec_command)
+    yield L
 
-    try:
-        conn._exec_command = _exec_command
-        yield
-    finally:
-        conn._exec_command = real_exec_command
 
-    assert commands_actual == list(commands)
+def in_transaction(conn):
+    if conn.pgconn.transaction_status == conn.TransactionStatus.IDLE:
+        return False
+    elif conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS:
+        return True
+    else:
+        assert False, conn.pgconn.transaction_status
 
 
 class ExpectedException(Exception):
@@ -73,10 +84,10 @@ def some_exc_info():
 
 def test_basic(conn):
     """Basic use of transaction() to BEGIN and COMMIT a transaction."""
-    assert_not_in_transaction(conn)
+    assert not in_transaction(conn)
     with conn.transaction():
-        assert_in_transaction(conn)
-    assert_not_in_transaction(conn)
+        assert in_transaction(conn)
+    assert not in_transaction(conn)
 
 
 def test_exposes_associated_connection(conn):
@@ -98,10 +109,10 @@ def test_exposes_savepoint_name(conn):
 def test_begins_on_enter(conn):
     """Transaction does not begin until __enter__() is called."""
     tx = conn.transaction()
-    assert_not_in_transaction(conn)
+    assert not in_transaction(conn)
     with tx:
-        assert_in_transaction(conn)
-    assert_not_in_transaction(conn)
+        assert in_transaction(conn)
+    assert not in_transaction(conn)
 
 
 def test_commit_on_successful_exit(conn):
@@ -109,8 +120,8 @@ def test_commit_on_successful_exit(conn):
     with conn.transaction():
         insert_row(conn, "foo")
 
-    assert_not_in_transaction(conn)
-    assert_rows(conn, {"foo"})
+    assert not in_transaction(conn)
+    assert inserted(conn) == {"foo"}
 
 
 def test_rollback_on_exception_exit(conn):
@@ -120,8 +131,8 @@ def test_rollback_on_exception_exit(conn):
             insert_row(conn, "foo")
             raise ExpectedException("This discards the insert")
 
-    assert_not_in_transaction(conn)
-    assert_rows(conn, set())
+    assert not in_transaction(conn)
+    assert not inserted(conn)
 
 
 def test_prohibits_use_of_commit_rollback_autocommit(conn):
@@ -169,14 +180,14 @@ def test_autocommit_off_but_no_tx_started_successful_exit(conn, svcconn):
     * Changes made within Transaction context are committed
     """
     conn.autocommit = False
-    assert_not_in_transaction(conn)
+    assert not in_transaction(conn)
     with conn.transaction():
         insert_row(conn, "new")
-    assert_not_in_transaction(conn)
+    assert not in_transaction(conn)
 
     # Changes committed
-    assert_rows(conn, {"new"})
-    assert_rows(svcconn, {"new"})
+    assert inserted(conn) == {"new"}
+    assert inserted(svcconn) == {"new"}
 
 
 def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn):
@@ -190,16 +201,16 @@ def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn):
     * Changes made within Transaction context are discarded
     """
     conn.autocommit = False
-    assert_not_in_transaction(conn)
+    assert not in_transaction(conn)
     with pytest.raises(ExpectedException):
         with conn.transaction():
             insert_row(conn, "new")
             raise ExpectedException()
-    assert_not_in_transaction(conn)
+    assert not in_transaction(conn)
 
     # Changes discarded
-    assert_rows(conn, set())
-    assert_rows(svcconn, set())
+    assert not inserted(conn)
+    assert not inserted(svcconn)
 
 
 def test_autocommit_off_and_tx_in_progress_successful_exit(conn, svcconn):
@@ -216,13 +227,13 @@ def test_autocommit_off_and_tx_in_progress_successful_exit(conn, svcconn):
     """
     conn.autocommit = False
     insert_row(conn, "prior")
-    assert_in_transaction(conn)
+    assert in_transaction(conn)
     with conn.transaction():
         insert_row(conn, "new")
-    assert_in_transaction(conn)
-    assert_rows(conn, {"prior", "new"})
+    assert in_transaction(conn)
+    assert inserted(conn) == {"prior", "new"}
     # Nothing committed yet; changes not visible on another connection
-    assert_rows(svcconn, set())
+    assert not inserted(svcconn)
 
 
 def test_autocommit_off_and_tx_in_progress_exception_exit(conn, svcconn):
@@ -240,15 +251,15 @@ def test_autocommit_off_and_tx_in_progress_exception_exit(conn, svcconn):
     """
     conn.autocommit = False
     insert_row(conn, "prior")
-    assert_in_transaction(conn)
+    assert in_transaction(conn)
     with pytest.raises(ExpectedException):
         with conn.transaction():
             insert_row(conn, "new")
             raise ExpectedException()
-    assert_in_transaction(conn)
-    assert_rows(conn, {"prior"})
+    assert in_transaction(conn)
+    assert inserted(conn) == {"prior"}
     # Nothing committed yet; changes not visible on another connection
-    assert_rows(svcconn, set())
+    assert not inserted(svcconn)
 
 
 def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn):
@@ -258,9 +269,9 @@ def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn):
         with conn.transaction():
             insert_row(conn, "inner")
         insert_row(conn, "outer-after")
-    assert_not_in_transaction(conn)
-    assert_rows(conn, {"outer-before", "inner", "outer-after"})
-    assert_rows(svcconn, {"outer-before", "inner", "outer-after"})
+    assert not in_transaction(conn)
+    assert inserted(conn) == {"outer-before", "inner", "outer-after"}
+    assert inserted(svcconn) == {"outer-before", "inner", "outer-after"}
 
 
 def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn):
@@ -274,9 +285,9 @@ def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn):
             with conn.transaction():
                 insert_row(conn, "inner")
             raise ExpectedException()
-    assert_not_in_transaction(conn)
-    assert_rows(conn, set())
-    assert_rows(svcconn, set())
+    assert not in_transaction(conn)
+    assert not inserted(conn)
+    assert not inserted(svcconn)
 
 
 def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn):
@@ -290,9 +301,9 @@ def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn):
             with conn.transaction():
                 insert_row(conn, "inner")
                 raise ExpectedException()
-    assert_not_in_transaction(conn)
-    assert_rows(conn, set())
-    assert_rows(svcconn, set())
+    assert not in_transaction(conn)
+    assert not inserted(conn)
+    assert not inserted(svcconn)
 
 
 def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn):
@@ -309,9 +320,9 @@ def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn):
                 insert_row(conn, "inner")
                 raise ExpectedException()
         insert_row(conn, "outer-after")
-    assert_not_in_transaction(conn)
-    assert_rows(conn, {"outer-before", "outer-after"})
-    assert_rows(svcconn, {"outer-before", "outer-after"})
+    assert not in_transaction(conn)
+    assert inserted(conn) == {"outer-before", "outer-after"}
+    assert inserted(svcconn) == {"outer-before", "outer-after"}
 
 
 def test_nested_three_levels_successful_exit(conn, svcconn):
@@ -322,9 +333,9 @@ def test_nested_three_levels_successful_exit(conn, svcconn):
             insert_row(conn, "two")
             with conn.transaction():  # SAVEPOINT s2
                 insert_row(conn, "three")
-    assert_not_in_transaction(conn)
-    assert_rows(conn, {"one", "two", "three"})
-    assert_rows(svcconn, {"one", "two", "three"})
+    assert not in_transaction(conn)
+    assert inserted(conn) == {"one", "two", "three"}
+    assert inserted(svcconn) == {"one", "two", "three"}
 
 
 def test_named_savepoint_escapes_savepoint_name(conn):
@@ -334,7 +345,7 @@ def test_named_savepoint_escapes_savepoint_name(conn):
         pass
 
 
-def test_named_savepoints_successful_exit(conn):
+def test_named_savepoints_successful_exit(conn, commands):
     """
     Entering a transaction context will do one of these these things:
     1. Begin an outer transaction (if one isn't already in progress)
@@ -347,40 +358,49 @@ def test_named_savepoints_successful_exit(conn):
     # Case 1
     # Using Transaction explicitly becase conn.transaction() enters the contetx
     tx = Transaction(conn)
-    with assert_commands_issued(conn, "begin"):
-        tx.__enter__()
+    assert not commands
+    tx.__enter__()
+    assert commands.pop() == "begin"
     assert not tx.savepoint_name
-    with assert_commands_issued(conn, "commit"):
-        tx.__exit__(None, None, None)
+    tx.__exit__(None, None, None)
+    assert commands.pop() == "commit"
 
     # Case 2
     tx = Transaction(conn, savepoint_name="foo")
-    with assert_commands_issued(conn, "begin", 'savepoint "foo"'):
-        tx.__enter__()
+    tx.__enter__()
+    assert commands.pop() == "begin"
+    assert commands.pop() == 'savepoint "foo"'
     assert tx.savepoint_name == "foo"
-    with assert_commands_issued(conn, 'release savepoint "foo"', "commit"):
-        tx.__exit__(None, None, None)
+    tx.__exit__(None, None, None)
+    assert commands.pop() == 'release savepoint "foo"'
+    assert commands.pop() == "commit"
 
     # Case 3 (with savepoint name provided)
-    with Transaction(conn):
+    with conn.transaction():
+        assert commands.pop() == "begin"
         tx = Transaction(conn, savepoint_name="bar")
-        with assert_commands_issued(conn, 'savepoint "bar"'):
-            tx.__enter__()
+        tx.__enter__()
+        assert commands.pop() == 'savepoint "bar"'
         assert tx.savepoint_name == "bar"
-        with assert_commands_issued(conn, 'release savepoint "bar"'):
-            tx.__exit__(None, None, None)
+        tx.__exit__(None, None, None)
+        assert commands.pop() == 'release savepoint "bar"'
+    assert commands.pop() == "commit"
 
     # Case 3 (with savepoint name auto-generated)
     with conn.transaction():
+        assert commands.pop() == "begin"
         tx = Transaction(conn)
-        with assert_commands_issued(conn, 'savepoint "s1"'):
-            tx.__enter__()
+        tx.__enter__()
+        assert commands.pop() == 'savepoint "s1"'
         assert tx.savepoint_name == "s1"
-        with assert_commands_issued(conn, 'release savepoint "s1"'):
-            tx.__exit__(None, None, None)
+        tx.__exit__(None, None, None)
+        assert commands.pop() == 'release savepoint "s1"'
+    assert commands.pop() == "commit"
+
+    assert not commands
 
 
-def test_named_savepoints_exception_exit(conn):
+def test_named_savepoints_exception_exit(conn, commands):
     """
     Same as the previous test but checks that when exiting the context with an
     exception, whatever transaction and/or savepoint was started on enter will
@@ -388,45 +408,54 @@ def test_named_savepoints_exception_exit(conn):
     """
     # Case 1
     tx = Transaction(conn)
-    with assert_commands_issued(conn, "begin"):
-        tx.__enter__()
+    tx.__enter__()
+    assert commands.pop() == "begin"
     assert not tx.savepoint_name
-    with assert_commands_issued(conn, "rollback"):
-        tx.__exit__(*some_exc_info())
+    tx.__exit__(*some_exc_info())
+    assert commands.pop() == "rollback"
 
     # Case 2
     tx = Transaction(conn, savepoint_name="foo")
-    with assert_commands_issued(conn, "begin", 'savepoint "foo"'):
-        tx.__enter__()
+    tx.__enter__()
+    assert commands.pop() == "begin"
+    assert commands.pop() == 'savepoint "foo"'
     assert tx.savepoint_name == "foo"
-    with assert_commands_issued(
-        conn,
-        'rollback to savepoint "foo"; release savepoint "foo"',
-        "rollback",
-    ):
-        tx.__exit__(*some_exc_info())
+    tx.__exit__(*some_exc_info())
+    assert (
+        commands.pop()
+        == 'rollback to savepoint "foo"; release savepoint "foo"'
+    )
+    assert commands.pop() == "rollback"
 
     # Case 3 (with savepoint name provided)
     with conn.transaction():
+        assert commands.pop() == "begin"
         tx = Transaction(conn, savepoint_name="bar")
-        with assert_commands_issued(conn, 'savepoint "bar"'):
-            tx.__enter__()
+        tx.__enter__()
+        assert commands.pop() == 'savepoint "bar"'
         assert tx.savepoint_name == "bar"
-        with assert_commands_issued(
-            conn, 'rollback to savepoint "bar"; release savepoint "bar"'
-        ):
-            tx.__exit__(*some_exc_info())
+        tx.__exit__(*some_exc_info())
+        assert (
+            commands.pop()
+            == 'rollback to savepoint "bar"; release savepoint "bar"'
+        )
+    assert commands.pop() == "commit"
 
     # Case 3 (with savepoint name auto-generated)
     with conn.transaction():
+        assert commands.pop() == "begin"
         tx = Transaction(conn)
-        with assert_commands_issued(conn, 'savepoint "s1"'):
-            tx.__enter__()
+        tx.__enter__()
+        assert commands.pop() == 'savepoint "s1"'
         assert tx.savepoint_name == "s1"
-        with assert_commands_issued(
-            conn, 'rollback to savepoint "s1"; release savepoint "s1"'
-        ):
-            tx.__exit__(*some_exc_info())
+        tx.__exit__(*some_exc_info())
+        assert (
+            commands.pop()
+            == 'rollback to savepoint "s1"; release savepoint "s1"'
+        )
+    assert commands.pop() == "commit"
+
+    assert not commands
 
 
 def test_named_savepoints_with_repeated_names_works(conn):
@@ -442,7 +471,7 @@ def test_named_savepoints_with_repeated_names_works(conn):
                 insert_row(conn, "tx2")
                 with conn.transaction("sp"):
                     insert_row(conn, "tx3")
-        assert_rows(conn, {"tx1", "tx2", "tx3"})
+        assert inserted(conn) == {"tx1", "tx2", "tx3"}
 
     # Works correctly if one level of inner transaction is rolled back
     with conn.transaction(force_rollback=True):
@@ -452,8 +481,8 @@ def test_named_savepoints_with_repeated_names_works(conn):
                 insert_row(conn, "tx2")
                 with conn.transaction("s1"):
                     insert_row(conn, "tx3")
-            assert_rows(conn, {"tx1"})
-        assert_rows(conn, {"tx1"})
+            assert inserted(conn) == {"tx1"}
+        assert inserted(conn) == {"tx1"}
 
     # Works correctly if multiple inner transactions are rolled back
     # (This scenario mandates releasing savepoints after rolling back to them.)
@@ -465,8 +494,8 @@ def test_named_savepoints_with_repeated_names_works(conn):
                 with conn.transaction("s1"):
                     insert_row(conn, "tx3")
                     raise Rollback(tx2)
-            assert_rows(conn, {"tx1"})
-        assert_rows(conn, {"tx1"})
+            assert inserted(conn) == {"tx1"}
+        assert inserted(conn) == {"tx1"}
 
     # Will not (always) catch out-of-order exits
     with conn.transaction(force_rollback=True):
@@ -485,8 +514,8 @@ def test_force_rollback_successful_exit(conn, svcconn):
     """
     with conn.transaction(force_rollback=True):
         insert_row(conn, "foo")
-    assert_rows(conn, set())
-    assert_rows(svcconn, set())
+    assert not inserted(conn)
+    assert not inserted(svcconn)
 
 
 def test_force_rollback_exception_exit(conn, svcconn):
@@ -498,8 +527,8 @@ def test_force_rollback_exception_exit(conn, svcconn):
         with conn.transaction(force_rollback=True):
             insert_row(conn, "foo")
             raise ExpectedException()
-    assert_rows(conn, set())
-    assert_rows(svcconn, set())
+    assert not inserted(conn)
+    assert not inserted(svcconn)
 
 
 def test_explicit_rollback_discards_changes(conn, svcconn):
@@ -515,8 +544,8 @@ def test_explicit_rollback_discards_changes(conn, svcconn):
     """
 
     def assert_no_rows():
-        assert_rows(conn, set())
-        assert_rows(svcconn, set())
+        assert not inserted(conn)
+        assert not inserted(svcconn)
 
     with conn.transaction():
         insert_row(conn, "foo")
@@ -544,11 +573,11 @@ def test_explicit_rollback_outer_tx_unaffected(conn, svcconn):
         with conn.transaction():
             insert_row(conn, "during")
             raise Rollback
-        assert_in_transaction(conn)
-        assert_rows(svcconn, set())
+        assert in_transaction(conn)
+        assert not inserted(svcconn)
         insert_row(conn, "after")
-    assert_rows(conn, {"before", "after"})
-    assert_rows(svcconn, {"before", "after"})
+    assert inserted(conn) == {"before", "after"}
+    assert inserted(svcconn) == {"before", "after"}
 
 
 def test_explicit_rollback_of_outer_transaction(conn):
@@ -562,7 +591,7 @@ def test_explicit_rollback_of_outer_transaction(conn):
             insert_row(conn, "inner")
             raise Rollback(outer_tx)
         assert False, "This line of code should be unreachable."
-    assert_rows(conn, set())
+    assert not inserted(conn)
 
 
 def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn):
@@ -578,9 +607,10 @@ def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn):
                 raise Rollback(tx_enclosing)
         insert_row(conn, "outer-after")
 
-        assert_rows(conn, {"outer-before", "outer-after"})
-        assert_rows(svcconn, set())  # Not yet committed
-    assert_rows(svcconn, {"outer-before", "outer-after"})  # Changes committed
+        assert inserted(conn) == {"outer-before", "outer-after"}
+        assert not inserted(svcconn)  # Not yet committed
+    # Changes committed
+    assert inserted(svcconn) == {"outer-before", "outer-after"}
 
 
 @pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py
new file mode 100644 (file)
index 0000000..ceac50e
--- /dev/null
@@ -0,0 +1,628 @@
+import pytest
+
+from psycopg3 import AsyncConnection, ProgrammingError, Rollback
+from psycopg3.sql import Composable
+from psycopg3.transaction import AsyncTransaction
+
+from .test_transaction import (
+    in_transaction,
+    ExpectedException,
+    some_exc_info,
+    inserted,
+)
+from .test_transaction import create_test_table  # noqa  # autouse fixture
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+async def commands(monkeypatch):
+    """The queue of commands issued internally by connections.
+
+    Not concurrency safe as it mokkeypatches a class method, but good enough
+    for tests.
+    """
+    _orig_exec_command = AsyncConnection._exec_command
+    L = []
+
+    async def _exec_command(self, command):
+        if isinstance(command, bytes):
+            command = command.decode(self.client_encoding)
+        elif isinstance(command, Composable):
+            command = command.as_string(self)
+
+        L.insert(0, command)
+        await _orig_exec_command(self, command)
+
+    monkeypatch.setattr(AsyncConnection, "_exec_command", _exec_command)
+    yield L
+
+
+async def insert_row(aconn, value):
+    await (await aconn.cursor()).execute(
+        "INSERT INTO test_table VALUES (%s)", (value,)
+    )
+
+
+async def test_basic(aconn):
+    """Basic use of transaction() to BEGIN and COMMIT a transaction."""
+    assert not in_transaction(aconn)
+    async with aconn.transaction():
+        assert in_transaction(aconn)
+    assert not in_transaction(aconn)
+
+
+async def test_exposes_associated_connection(aconn):
+    """Transaction exposes its connection as a read-only property."""
+    async with aconn.transaction() as tx:
+        assert tx.connection is aconn
+        with pytest.raises(AttributeError):
+            tx.connection = aconn
+
+
+async def test_exposes_savepoint_name(aconn):
+    """Transaction exposes its savepoint name as a read-only property."""
+    async with aconn.transaction(savepoint_name="foo") as tx:
+        assert tx.savepoint_name == "foo"
+        with pytest.raises(AttributeError):
+            tx.savepoint_name = "bar"
+
+
+async def test_begins_on_enter(aconn):
+    """Transaction does not begin until __enter__() is called."""
+    tx = aconn.transaction()
+    assert not in_transaction(aconn)
+    async with tx:
+        assert in_transaction(aconn)
+    assert not in_transaction(aconn)
+
+
+async def test_commit_on_successful_exit(aconn):
+    """Changes are committed on successful exit from the `with` block."""
+    async with aconn.transaction():
+        await insert_row(aconn, "foo")
+
+    assert not in_transaction(aconn)
+    assert await inserted(aconn) == {"foo"}
+
+
+async def test_rollback_on_exception_exit(aconn):
+    """Changes are rolled back if an exception escapes the `with` block."""
+    with pytest.raises(ExpectedException):
+        async with aconn.transaction():
+            await insert_row(aconn, "foo")
+            raise ExpectedException("This discards the insert")
+
+    assert not in_transaction(aconn)
+    assert not await inserted(aconn)
+
+
+async def test_prohibits_use_of_commit_rollback_autocommit(aconn):
+    """
+    Within a Transaction block, it is forbidden to touch commit, rollback,
+    or the autocommit setting on the connection, as this would interfere
+    with the transaction scope being managed by the Transaction block.
+    """
+    await aconn.set_autocommit(False)
+    await aconn.commit()
+    await aconn.rollback()
+
+    async with aconn.transaction():
+        with pytest.raises(ProgrammingError):
+            await aconn.set_autocommit(False)
+        with pytest.raises(ProgrammingError):
+            await aconn.commit()
+        with pytest.raises(ProgrammingError):
+            await aconn.rollback()
+
+    await aconn.set_autocommit(False)
+    await aconn.commit()
+    await aconn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+async def test_preserves_autocommit(aconn, autocommit):
+    """
+    Connection.autocommit is unchanged both during and after Transaction block.
+    """
+    await aconn.set_autocommit(autocommit)
+    async with aconn.transaction():
+        assert aconn.autocommit is autocommit
+    assert aconn.autocommit is autocommit
+
+
+async def test_autocommit_off_but_no_tx_started_successful_exit(
+    aconn, svcconn
+):
+    """
+    Scenario:
+    * Connection has autocommit off but no transaction has been initiated
+      before entering the Transaction context
+    * Code exits Transaction context successfully
+
+    Outcome:
+    * Changes made within Transaction context are committed
+    """
+    await aconn.set_autocommit(False)
+    assert not in_transaction(aconn)
+    async with aconn.transaction():
+        await insert_row(aconn, "new")
+    assert not in_transaction(aconn)
+
+    # Changes committed
+    assert await inserted(aconn) == {"new"}
+    assert inserted(svcconn) == {"new"}
+
+
+async def test_autocommit_off_but_no_tx_started_exception_exit(aconn, svcconn):
+    """
+    Scenario:
+    * Connection has autocommit off but no transaction has been initiated
+      before entering the Transaction context
+    * Code exits Transaction context with an exception
+
+    Outcome:
+    * Changes made within Transaction context are discarded
+    """
+    await aconn.set_autocommit(False)
+    assert not in_transaction(aconn)
+    with pytest.raises(ExpectedException):
+        async with aconn.transaction():
+            await insert_row(aconn, "new")
+            raise ExpectedException()
+    assert not in_transaction(aconn)
+
+    # Changes discarded
+    assert not await inserted(aconn)
+    assert not inserted(svcconn)
+
+
+async def test_autocommit_off_and_tx_in_progress_successful_exit(
+    aconn, svcconn
+):
+    """
+    Scenario:
+    * Connection has autocommit off but and a transaction is already in
+      progress before entering the Transaction context
+    * Code exits Transaction context successfully
+
+    Outcome:
+    * Changes made within Transaction context are left intact
+    * Outer transaction is left running, and no changes are visible to an
+      outside observer from another connection.
+    """
+    await aconn.set_autocommit(False)
+    await insert_row(aconn, "prior")
+    assert in_transaction(aconn)
+    async with aconn.transaction():
+        await insert_row(aconn, "new")
+    assert in_transaction(aconn)
+    assert await inserted(aconn) == {"prior", "new"}
+    # Nothing committed yet; changes not visible on another connection
+    assert not inserted(svcconn)
+
+
+async def test_autocommit_off_and_tx_in_progress_exception_exit(
+    aconn, svcconn
+):
+    """
+    Scenario:
+    * Connection has autocommit off but and a transaction is already in
+      progress before entering the Transaction context
+    * Code exits Transaction context with an exception
+
+    Outcome:
+    * Changes made before the Transaction context are left intact
+    * Changes made within Transaction context are discarded
+    * Outer transaction is left running, and no changes are visible to an
+      outside observer from another connection.
+    """
+    await aconn.set_autocommit(False)
+    await insert_row(aconn, "prior")
+    assert in_transaction(aconn)
+    with pytest.raises(ExpectedException):
+        async with aconn.transaction():
+            await insert_row(aconn, "new")
+            raise ExpectedException()
+    assert in_transaction(aconn)
+    assert await inserted(aconn) == {"prior"}
+    # Nothing committed yet; changes not visible on another connection
+    assert not inserted(svcconn)
+
+
+async def test_nested_all_changes_persisted_on_successful_exit(aconn, svcconn):
+    """Changes from nested transaction contexts are all persisted on exit."""
+    async with aconn.transaction():
+        await insert_row(aconn, "outer-before")
+        async with aconn.transaction():
+            await insert_row(aconn, "inner")
+        await insert_row(aconn, "outer-after")
+    assert not in_transaction(aconn)
+    assert await inserted(aconn) == {"outer-before", "inner", "outer-after"}
+    assert inserted(svcconn) == {"outer-before", "inner", "outer-after"}
+
+
+async def test_nested_all_changes_discarded_on_outer_exception(aconn, svcconn):
+    """
+    Changes from nested transaction contexts are discarded when an exception
+    raised in outer context escapes.
+    """
+    with pytest.raises(ExpectedException):
+        async with aconn.transaction():
+            await insert_row(aconn, "outer")
+            async with aconn.transaction():
+                await insert_row(aconn, "inner")
+            raise ExpectedException()
+    assert not in_transaction(aconn)
+    assert not await inserted(aconn)
+    assert not inserted(svcconn)
+
+
+async def test_nested_all_changes_discarded_on_inner_exception(aconn, svcconn):
+    """
+    Changes from nested transaction contexts are discarded when an exception
+    raised in inner context escapes the outer context.
+    """
+    with pytest.raises(ExpectedException):
+        async with aconn.transaction():
+            await insert_row(aconn, "outer")
+            async with aconn.transaction():
+                await insert_row(aconn, "inner")
+                raise ExpectedException()
+    assert not in_transaction(aconn)
+    assert not await inserted(aconn)
+    assert not inserted(svcconn)
+
+
+async def test_nested_inner_scope_exception_handled_in_outer_scope(
+    aconn, svcconn
+):
+    """
+    An exception escaping the inner transaction context causes changes made
+    within that inner context to be discarded, but the error can then be
+    handled in the outer context, allowing changes made in the outer context
+    (both before, and after, the inner context) to be successfully committed.
+    """
+    async with aconn.transaction():
+        await insert_row(aconn, "outer-before")
+        with pytest.raises(ExpectedException):
+            async with aconn.transaction():
+                await insert_row(aconn, "inner")
+                raise ExpectedException()
+        await insert_row(aconn, "outer-after")
+    assert not in_transaction(aconn)
+    assert await inserted(aconn) == {"outer-before", "outer-after"}
+    assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+async def test_nested_three_levels_successful_exit(aconn, svcconn):
+    """Exercise management of more than one savepoint."""
+    async with aconn.transaction():  # BEGIN
+        await insert_row(aconn, "one")
+        async with aconn.transaction():  # SAVEPOINT s1
+            await insert_row(aconn, "two")
+            async with aconn.transaction():  # SAVEPOINT s2
+                await insert_row(aconn, "three")
+    assert not in_transaction(aconn)
+    assert await inserted(aconn) == {"one", "two", "three"}
+    assert inserted(svcconn) == {"one", "two", "three"}
+
+
+async def test_named_savepoint_escapes_savepoint_name(aconn):
+    async with aconn.transaction("s-1"):
+        pass
+    async with aconn.transaction("s1; drop table students"):
+        pass
+
+
+async def test_named_savepoints_successful_exit(aconn, commands):
+    """
+    Entering a transaction context will do one of these these things:
+    1. Begin an outer transaction (if one isn't already in progress)
+    2. Begin an outer transaction and create a savepoint (if one is named)
+    3. Create a savepoint (if a transaction is already in progress)
+       either using the name provided, or auto-generating a savepoint name.
+
+    ...and exiting the context successfully will "commit" the same.
+    """
+    # Case 1
+    # Using Transaction explicitly becase conn.transaction() enters the contetx
+    tx = AsyncTransaction(aconn)
+    await tx.__aenter__()
+    assert commands.pop() == "begin"
+    assert not tx.savepoint_name
+    await tx.__aexit__(None, None, None)
+    assert commands.pop() == "commit"
+
+    # Case 2
+    tx = AsyncTransaction(aconn, savepoint_name="foo")
+    await tx.__aenter__()
+    assert commands.pop() == "begin"
+    assert commands.pop() == 'savepoint "foo"'
+    assert tx.savepoint_name == "foo"
+    await tx.__aexit__(None, None, None)
+    assert commands.pop() == 'release savepoint "foo"'
+    assert commands.pop() == "commit"
+
+    # Case 3 (with savepoint name provided)
+    async with aconn.transaction():
+        assert commands.pop() == "begin"
+        tx = AsyncTransaction(aconn, savepoint_name="bar")
+        await tx.__aenter__()
+        assert commands.pop() == 'savepoint "bar"'
+        assert tx.savepoint_name == "bar"
+        await tx.__aexit__(None, None, None)
+        assert commands.pop() == 'release savepoint "bar"'
+    assert commands.pop() == "commit"
+
+    # Case 3 (with savepoint name auto-generated)
+    async with aconn.transaction():
+        assert commands.pop() == "begin"
+        tx = AsyncTransaction(aconn)
+        await tx.__aenter__()
+        assert commands.pop() == 'savepoint "s1"'
+        assert tx.savepoint_name == "s1"
+        await tx.__aexit__(None, None, None)
+        assert commands.pop() == 'release savepoint "s1"'
+    assert commands.pop() == "commit"
+
+    assert not commands
+
+
+async def test_named_savepoints_exception_exit(aconn, commands):
+    """
+    Same as the previous test but checks that when exiting the context with an
+    exception, whatever transaction and/or savepoint was started on enter will
+    be rolled-back as appropriate.
+    """
+    # Case 1
+    tx = AsyncTransaction(aconn)
+    await tx.__aenter__()
+    assert commands.pop() == "begin"
+    assert not tx.savepoint_name
+    await tx.__aexit__(*some_exc_info())
+    assert commands.pop() == "rollback"
+
+    # Case 2
+    tx = AsyncTransaction(aconn, savepoint_name="foo")
+    await tx.__aenter__()
+    assert commands.pop() == "begin"
+    assert commands.pop() == 'savepoint "foo"'
+    assert tx.savepoint_name == "foo"
+    await tx.__aexit__(*some_exc_info())
+    assert (
+        commands.pop()
+        == 'rollback to savepoint "foo"; release savepoint "foo"'
+    )
+    assert commands.pop() == "rollback"
+
+    # Case 3 (with savepoint name provided)
+    async with aconn.transaction():
+        assert commands.pop() == "begin"
+        tx = AsyncTransaction(aconn, savepoint_name="bar")
+        await tx.__aenter__()
+        assert commands.pop() == 'savepoint "bar"'
+        assert tx.savepoint_name == "bar"
+        await tx.__aexit__(*some_exc_info())
+        assert (
+            commands.pop()
+            == 'rollback to savepoint "bar"; release savepoint "bar"'
+        )
+    assert commands.pop() == "commit"
+
+    # Case 3 (with savepoint name auto-generated)
+    async with aconn.transaction():
+        assert commands.pop() == "begin"
+        tx = AsyncTransaction(aconn)
+        await tx.__aenter__()
+        assert commands.pop() == 'savepoint "s1"'
+        assert tx.savepoint_name == "s1"
+        await tx.__aexit__(*some_exc_info())
+        assert (
+            commands.pop()
+            == 'rollback to savepoint "s1"; release savepoint "s1"'
+        )
+    assert commands.pop() == "commit"
+
+    assert not commands
+
+
+async def test_named_savepoints_with_repeated_names_works(aconn):
+    """
+    Using the same savepoint name repeatedly works correctly, but bypasses
+    some sanity checks.
+    """
+    # Works correctly if no inner transactions are rolled back
+    async with aconn.transaction(force_rollback=True):
+        async with aconn.transaction("sp"):
+            await insert_row(aconn, "tx1")
+            async with aconn.transaction("sp"):
+                await insert_row(aconn, "tx2")
+                async with aconn.transaction("sp"):
+                    await insert_row(aconn, "tx3")
+        assert await inserted(aconn) == {"tx1", "tx2", "tx3"}
+
+    # Works correctly if one level of inner transaction is rolled back
+    async with aconn.transaction(force_rollback=True):
+        async with aconn.transaction("s1"):
+            await insert_row(aconn, "tx1")
+            async with aconn.transaction("s1", force_rollback=True):
+                await insert_row(aconn, "tx2")
+                async with aconn.transaction("s1"):
+                    await insert_row(aconn, "tx3")
+            assert await inserted(aconn) == {"tx1"}
+        assert await inserted(aconn) == {"tx1"}
+
+    # Works correctly if multiple inner transactions are rolled back
+    # (This scenario mandates releasing savepoints after rolling back to them.)
+    async with aconn.transaction(force_rollback=True):
+        async with aconn.transaction("s1"):
+            await insert_row(aconn, "tx1")
+            async with aconn.transaction("s1") as tx2:
+                await insert_row(aconn, "tx2")
+                async with aconn.transaction("s1"):
+                    await insert_row(aconn, "tx3")
+                    raise Rollback(tx2)
+            assert await inserted(aconn) == {"tx1"}
+        assert await inserted(aconn) == {"tx1"}
+
+    # Will not (always) catch out-of-order exits
+    async with aconn.transaction(force_rollback=True):
+        tx1 = aconn.transaction("s1")
+        tx2 = aconn.transaction("s1")
+        await tx1.__aenter__()
+        await tx2.__aenter__()
+        await tx1.__aexit__(None, None, None)
+        await tx2.__aexit__(None, None, None)
+
+
+async def test_force_rollback_successful_exit(aconn, svcconn):
+    """
+    Transaction started with the force_rollback option enabled discards all
+    changes at the end of the context.
+    """
+    async with aconn.transaction(force_rollback=True):
+        await insert_row(aconn, "foo")
+    assert not await inserted(aconn)
+    assert not inserted(svcconn)
+
+
+async def test_force_rollback_exception_exit(aconn, svcconn):
+    """
+    Transaction started with the force_rollback option enabled discards all
+    changes at the end of the context.
+    """
+    with pytest.raises(ExpectedException):
+        async with aconn.transaction(force_rollback=True):
+            await insert_row(aconn, "foo")
+            raise ExpectedException()
+    assert not await inserted(aconn)
+    assert not inserted(svcconn)
+
+
+async def test_explicit_rollback_discards_changes(aconn, svcconn):
+    """
+    Raising a Rollback exception in the middle of a block exits the block and
+    discards all changes made within that block.
+
+    You can raise any of the following:
+     - Rollback (type)
+     - Rollback() (instance)
+     - Rollback(tx) (instance initialised with reference to the transaction)
+    All of these are equivalent.
+    """
+
+    async def assert_no_rows():
+        assert not await inserted(aconn)
+        assert not inserted(svcconn)
+
+    async with aconn.transaction():
+        await insert_row(aconn, "foo")
+        raise Rollback
+    await assert_no_rows()
+
+    async with aconn.transaction():
+        await insert_row(aconn, "foo")
+        raise Rollback()
+    await assert_no_rows()
+
+    async with aconn.transaction() as tx:
+        await insert_row(aconn, "foo")
+        raise Rollback(tx)
+    await assert_no_rows()
+
+
+async def test_explicit_rollback_outer_tx_unaffected(aconn, svcconn):
+    """
+    Raising a Rollback exception in the middle of a block does not impact an
+    enclosing transaction block.
+    """
+    async with aconn.transaction():
+        await insert_row(aconn, "before")
+        async with aconn.transaction():
+            await insert_row(aconn, "during")
+            raise Rollback
+        assert in_transaction(aconn)
+        assert not inserted(svcconn)
+        await insert_row(aconn, "after")
+    assert await inserted(aconn) == {"before", "after"}
+    assert inserted(svcconn) == {"before", "after"}
+
+
+async def test_explicit_rollback_of_outer_transaction(aconn):
+    """
+    Raising a Rollback exception that references an outer transaction will
+    discard all changes from both inner and outer transaction blocks.
+    """
+    async with aconn.transaction() as outer_tx:
+        await insert_row(aconn, "outer")
+        async with aconn.transaction():
+            await insert_row(aconn, "inner")
+            raise Rollback(outer_tx)
+        assert False, "This line of code should be unreachable."
+    assert not await inserted(aconn)
+
+
+async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(
+    aconn, svcconn
+):
+    """
+    Rolling-back an enclosing transaction does not impact an outer transaction.
+    """
+    async with aconn.transaction():
+        await insert_row(aconn, "outer-before")
+        async with aconn.transaction() as tx_enclosing:
+            await insert_row(aconn, "enclosing")
+            async with aconn.transaction():
+                await insert_row(aconn, "inner")
+                raise Rollback(tx_enclosing)
+        await insert_row(aconn, "outer-after")
+
+        assert await inserted(aconn) == {"outer-before", "outer-after"}
+        assert not inserted(svcconn)  # Not yet committed
+    # Changes committed
+    assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
+@pytest.mark.parametrize("name", [None, "s1"])
+async def test_manual_enter_and_exit_out_of_order_exit_asserts(
+    aconn, name, exc_info
+):
+    """
+    When user is calling __enter__() and __exit__() manually for some reason,
+    provide a helpful error message if they call __exit__() in the wrong order
+    for nested transactions.
+    """
+    tx1, tx2 = AsyncTransaction(aconn, name), AsyncTransaction(aconn)
+    await tx1.__aenter__()
+    await tx2.__aenter__()
+    with pytest.raises(ProgrammingError, match="Out-of-order"):
+        await tx1.__aexit__(*exc_info)
+
+
+@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
+@pytest.mark.parametrize("name", [None, "s1"])
+async def test_manual_exit_without_enter_asserts(aconn, name, exc_info):
+    """
+    When user is calling __enter__() and __exit__() manually for some reason,
+    provide a helpful error message if they call __exit__() without first
+    having called __enter__()
+    """
+    tx = AsyncTransaction(aconn, name)
+    with pytest.raises(ProgrammingError, match="Out-of-order"):
+        await tx.__aexit__(*exc_info)
+
+
+@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
+@pytest.mark.parametrize("name", [None, "s1"])
+async def test_manual_exit_twice_asserts(aconn, name, exc_info):
+    """
+    When user is calling __enter__() and __exit__() manually for some reason,
+    provide a helpful error message if they accidentally call __exit__() twice.
+    """
+    tx = AsyncTransaction(aconn, name)
+    await tx.__aenter__()
+    await tx.__aexit__(*exc_info)
+    with pytest.raises(ProgrammingError, match="Out-of-order"):
+        await tx.__aexit__(*exc_info)