# Copyright (C) 2020 The Psycopg Team
-import logging
+import sys
import asyncio
+import logging
import threading
from types import TracebackType
from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
from functools import partial
from contextlib import contextmanager
+if sys.version_info >= (3, 7):
+ from contextlib import asynccontextmanager
+else:
+ from .utils.context import asynccontextmanager
+
from . import pq
from . import cursor
from . import errors as e
from .waiting import wait, wait_async
from .conninfo import make_conninfo
from .generators import notifies
-from .transaction import Transaction
+from .transaction import Transaction, AsyncTransaction
logger = logging.getLogger(__name__)
package_logger = logging.getLogger("psycopg3")
savepoint_name: Optional[str] = None,
force_rollback: bool = False,
) -> Iterator[Transaction]:
- with Transaction(self, savepoint_name or "", force_rollback) as tx:
+ with Transaction(self, savepoint_name, force_rollback) as tx:
yield tx
@classmethod
if self.pgconn.transaction_status != TransactionStatus.IDLE:
return
- self.pgconn.send_query(b"begin")
- (pgres,) = await self.wait(execute(self.pgconn))
- if pgres.status != ExecStatus.COMMAND_OK:
- raise e.OperationalError(
- "error on begin:"
- f" {pq.error_message(pgres, encoding=self.client_encoding)}"
- )
+ await self._exec_command(b"begin")
async def commit(self) -> None:
async with self.lock:
if self.pgconn.transaction_status == TransactionStatus.IDLE:
return
- await self._exec(b"commit")
+ if self._savepoints is not None:
+ raise e.ProgrammingError(
+ "Explicit commit() forbidden within a Transaction "
+ "context. (Transaction will be automatically committed "
+ "on successful exit from context.)"
+ )
+ await self._exec_command(b"commit")
async def rollback(self) -> None:
async with self.lock:
if self.pgconn.transaction_status == TransactionStatus.IDLE:
return
- await self._exec(b"rollback")
+ if self._savepoints is not None:
+ raise e.ProgrammingError(
+ "Explicit rollback() forbidden within a Transaction "
+ "context. (Either raise Transaction.Rollback() or allow "
+ "an exception to propagate out of the context.)"
+ )
+ await self._exec_command(b"rollback")
- async def _exec(self, command: bytes) -> None:
+ async def _exec_command(self, command: Query) -> None:
# Caller must hold self.lock
+
+ if isinstance(command, str):
+ command = command.encode(self.client_encoding)
+ elif isinstance(command, Composable):
+ command = command.as_string(self).encode(self.client_encoding)
+
self.pgconn.send_query(command)
- (pgres,) = await self.wait(execute(self.pgconn))
- if pgres.status != ExecStatus.COMMAND_OK:
+ results = await self.wait(execute(self.pgconn))
+ if results[-1].status != ExecStatus.COMMAND_OK:
raise e.OperationalError(
f"error on {command.decode('utf8')}:"
- f" {pq.error_message(pgres, encoding=self.client_encoding)}"
+ f" {pq.error_message(results[-1], encoding=self.client_encoding)}"
)
+ @asynccontextmanager
+ async def transaction(
+ self,
+ savepoint_name: Optional[str] = None,
+ force_rollback: bool = False,
+ ) -> AsyncIterator[AsyncTransaction]:
+ tx = AsyncTransaction(self, savepoint_name, force_rollback)
+ async with tx:
+ yield tx
+
@classmethod
async def wait(cls, gen: PQGen[RV]) -> RV:
return await wait_async(gen)
import logging
from types import TracebackType
-from typing import Optional, Type, TYPE_CHECKING
+from typing import Generic, Optional, Type, Union, TYPE_CHECKING
from . import sql
from .pq import TransactionStatus
-from psycopg3.errors import ProgrammingError
+from .proto import ConnectionType
+from .errors import ProgrammingError
if TYPE_CHECKING:
- from .connection import Connection
+ from .connection import Connection, AsyncConnection # noqa: F401
_log = logging.getLogger(__name__)
enclosing transactions contexts up to and including the one specified.
"""
- def __init__(self, transaction: Optional["Transaction"] = None):
+ def __init__(
+ self,
+ transaction: Union["Transaction", "AsyncTransaction", None] = None,
+ ):
self.transaction = transaction
def __repr__(self) -> str:
return f"{self.__class__.__qualname__}({self.transaction!r})"
-class Transaction:
+class BaseTransaction(Generic[ConnectionType]):
def __init__(
self,
- connection: "Connection",
- savepoint_name: str = "",
+ connection: ConnectionType,
+ savepoint_name: Optional[str] = None,
force_rollback: bool = False,
):
self._conn = connection
- self._savepoint_name = savepoint_name
+ self._savepoint_name = savepoint_name or ""
self.force_rollback = force_rollback
self._outer_transaction: Optional[bool] = None
@property
- def connection(self) -> "Connection":
+ def connection(self) -> ConnectionType:
return self._conn
@property
def savepoint_name(self) -> str:
return self._savepoint_name
+ def __repr__(self) -> str:
+ args = [f"connection={self.connection}"]
+ if not self.savepoint_name:
+ args.append(f"savepoint_name={self.savepoint_name!r}")
+ if self.force_rollback:
+ args.append("force_rollback=True")
+ return f"{self.__class__.__qualname__}({', '.join(args)})"
+
+ _out_of_order_err = ProgrammingError(
+ "Out-of-order Transaction context exits. Are you "
+ "calling __exit__() manually and getting it wrong?"
+ )
+
+ def _pop_savepoint(self) -> None:
+ if self._conn._savepoints is None:
+ raise self._out_of_order_err
+ actual = self._conn._savepoints.pop()
+ if actual != self._savepoint_name:
+ raise self._out_of_order_err
+
+
+class Transaction(BaseTransaction["Connection"]):
def __enter__(self) -> "Transaction":
with self._conn.lock:
if self._conn.pgconn.transaction_status == TransactionStatus.IDLE:
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
- out_of_order_err = ProgrammingError(
- "Out-of-order Transaction context exits. Are you "
- "calling __exit__() manually and getting it wrong?"
- )
if self._outer_transaction is None:
- raise out_of_order_err
+ raise self._out_of_order_err
with self._conn.lock:
- if exc_type is None and not self.force_rollback:
- # Commit changes made in the transaction context
- if self._savepoint_name:
- if self._conn._savepoints is None:
- raise out_of_order_err
- actual = self._conn._savepoints.pop()
- if actual != self._savepoint_name:
- raise out_of_order_err
- self._conn._exec_command(
- sql.SQL("release savepoint {}").format(
- sql.Identifier(self._savepoint_name)
- )
- )
- if self._outer_transaction:
- if self._conn._savepoints is None:
- raise out_of_order_err
- if len(self._conn._savepoints) != 0:
- raise out_of_order_err
- self._conn._exec_command(b"commit")
- self._conn._savepoints = None
+ if not exc_val and not self.force_rollback:
+ return self._commit()
+ else:
+ return self._rollback(exc_val)
+
+ def _commit(self) -> bool:
+ """Commit changes made in the transaction context."""
+ if self._savepoint_name:
+ self._pop_savepoint()
+ self._conn._exec_command(
+ sql.SQL("release savepoint {}").format(
+ sql.Identifier(self._savepoint_name)
+ )
+ )
+ if self._outer_transaction:
+ if self._conn._savepoints is None or self._conn._savepoints:
+ raise self._out_of_order_err
+ self._conn._exec_command(b"commit")
+ self._conn._savepoints = None
+
+ return False # discarded
+
+ def _rollback(self, exc_val: Optional[BaseException]) -> bool:
+ # Rollback changes made in the transaction context
+ if isinstance(exc_val, Rollback):
+ _log.debug(
+ f"{self._conn}: Explicit rollback from: ", exc_info=True
+ )
+
+ if self._savepoint_name:
+ self._pop_savepoint()
+ self._conn._exec_command(
+ sql.SQL(
+ "rollback to savepoint {n}; release savepoint {n}"
+ ).format(n=sql.Identifier(self._savepoint_name))
+ )
+ if self._outer_transaction:
+ if self._conn._savepoints is None or self._conn._savepoints:
+ raise self._out_of_order_err
+ self._conn._exec_command(b"rollback")
+ self._conn._savepoints = None
+
+ if isinstance(exc_val, Rollback):
+ if exc_val.transaction in (self, None):
+ return True # Swallow the exception
+
+ return False
+
+
+class AsyncTransaction(BaseTransaction["AsyncConnection"]):
+ async def __aenter__(self) -> "AsyncTransaction":
+ async with self._conn.lock:
+ if self._conn.pgconn.transaction_status == TransactionStatus.IDLE:
+ assert self._conn._savepoints is None, self._conn._savepoints
+ self._conn._savepoints = []
+ self._outer_transaction = True
+ await self._conn._exec_command(b"begin")
else:
- # Rollback changes made in the transaction context
- if isinstance(exc_val, Rollback):
- _log.debug(
- f"{self._conn}: Explicit rollback from: ",
- exc_info=True,
+ if self._conn._savepoints is None:
+ self._conn._savepoints = []
+ self._outer_transaction = False
+ if not self._savepoint_name:
+ self._savepoint_name = (
+ f"s{len(self._conn._savepoints) + 1}"
)
- if self._savepoint_name:
- if self._conn._savepoints is None:
- raise out_of_order_err
- actual = self._conn._savepoints.pop()
- if actual != self._savepoint_name:
- raise out_of_order_err
- self._conn._exec_command(
- sql.SQL(
- "rollback to savepoint {n}; release savepoint {n}"
- ).format(n=sql.Identifier(self._savepoint_name))
+ if self._savepoint_name:
+ await self._conn._exec_command(
+ sql.SQL("savepoint {}").format(
+ sql.Identifier(self._savepoint_name)
)
- if self._outer_transaction:
- if self._conn._savepoints is None:
- raise out_of_order_err
- if len(self._conn._savepoints) != 0:
- raise out_of_order_err
- self._conn._exec_command(b"rollback")
- self._conn._savepoints = None
-
- if isinstance(exc_val, Rollback):
- if exc_val.transaction in (self, None):
- return True # Swallow the exception
- return False
+ )
+ self._conn._savepoints.append(self._savepoint_name)
+ return self
- def __repr__(self) -> str:
- args = [f"connection={self.connection}"]
- if not self.savepoint_name:
- args.append(f"savepoint_name={self.savepoint_name!r}")
- if self.force_rollback:
- args.append("force_rollback=True")
- return f"{self.__class__.__qualname__}({', '.join(args)})"
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> bool:
+ if self._outer_transaction is None:
+ raise self._out_of_order_err
+ async with self._conn.lock:
+ if not exc_val and not self.force_rollback:
+ return await self._commit()
+ else:
+ return await self._rollback(exc_val)
+
+ async def _commit(self) -> bool:
+ """Commit changes made in the transaction context."""
+ if self._savepoint_name:
+ self._pop_savepoint()
+ await self._conn._exec_command(
+ sql.SQL("release savepoint {}").format(
+ sql.Identifier(self._savepoint_name)
+ )
+ )
+ if self._outer_transaction:
+ if self._conn._savepoints is None or self._conn._savepoints:
+ raise self._out_of_order_err
+ await self._conn._exec_command(b"commit")
+ self._conn._savepoints = None
+
+ return False # discarded
+
+ async def _rollback(self, exc_val: Optional[BaseException]) -> bool:
+ # Rollback changes made in the transaction context
+ if isinstance(exc_val, Rollback):
+ _log.debug(
+ f"{self._conn}: Explicit rollback from: ", exc_info=True
+ )
+
+ if self._savepoint_name:
+ self._pop_savepoint()
+ await self._conn._exec_command(
+ sql.SQL(
+ "rollback to savepoint {n}; release savepoint {n}"
+ ).format(n=sql.Identifier(self._savepoint_name))
+ )
+ if self._outer_transaction:
+ if self._conn._savepoints is None or self._conn._savepoints:
+ raise self._out_of_order_err
+ await self._conn._exec_command(b"rollback")
+ self._conn._savepoints = None
+
+ if isinstance(exc_val, Rollback):
+ if exc_val.transaction in (self, None):
+ return True # Swallow the exception
+
+ return False
import sys
-from contextlib import contextmanager
import pytest
-from psycopg3 import ProgrammingError, Rollback
+from psycopg3 import Connection, ProgrammingError, Rollback
from psycopg3.sql import Composable
from psycopg3.transaction import Transaction
@pytest.fixture(autouse=True)
-def test_table(svcconn):
+def create_test_table(svcconn):
"""
Creates a table called 'test_table' for use in tests.
"""
conn.cursor().execute("INSERT INTO test_table VALUES (%s)", (value,))
-def assert_rows(conn, expected):
- rows = conn.cursor().execute("SELECT * FROM test_table").fetchall()
- assert set(v for (v,) in rows) == expected
+def inserted(conn):
+ sql = "SELECT * FROM test_table"
+ if isinstance(conn, Connection):
+ rows = conn.cursor().execute(sql).fetchall()
+ return set(v for (v,) in rows)
+ else:
+ async def f():
+ cur = await conn.cursor()
+ await cur.execute(sql)
+ rows = await cur.fetchall()
+ return set(v for (v,) in rows)
-def assert_not_in_transaction(conn):
- assert conn.pgconn.transaction_status == conn.TransactionStatus.IDLE
+ return f()
-def assert_in_transaction(conn):
- assert conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS
+@pytest.fixture
+def commands(monkeypatch):
+ """The queue of commands issued internally by connections.
+ Not concurrency safe as it mokkeypatches a class method, but good enough
+ for tests.
+ """
+ _orig_exec_command = Connection._exec_command
+ L = []
-@contextmanager
-def assert_commands_issued(conn, *commands):
- commands_actual = []
- real_exec_command = conn._exec_command
-
- def _exec_command(command):
+ def _exec_command(self, command):
if isinstance(command, bytes):
- command = command.decode(conn.client_encoding)
+ command = command.decode(self.client_encoding)
elif isinstance(command, Composable):
- command = command.as_string(conn)
+ command = command.as_string(self)
- commands_actual.append(command)
- real_exec_command(command)
+ L.insert(0, command)
+ _orig_exec_command(self, command)
+
+ monkeypatch.setattr(Connection, "_exec_command", _exec_command)
+ yield L
- try:
- conn._exec_command = _exec_command
- yield
- finally:
- conn._exec_command = real_exec_command
- assert commands_actual == list(commands)
+def in_transaction(conn):
+ if conn.pgconn.transaction_status == conn.TransactionStatus.IDLE:
+ return False
+ elif conn.pgconn.transaction_status == conn.TransactionStatus.INTRANS:
+ return True
+ else:
+ assert False, conn.pgconn.transaction_status
class ExpectedException(Exception):
def test_basic(conn):
"""Basic use of transaction() to BEGIN and COMMIT a transaction."""
- assert_not_in_transaction(conn)
+ assert not in_transaction(conn)
with conn.transaction():
- assert_in_transaction(conn)
- assert_not_in_transaction(conn)
+ assert in_transaction(conn)
+ assert not in_transaction(conn)
def test_exposes_associated_connection(conn):
def test_begins_on_enter(conn):
"""Transaction does not begin until __enter__() is called."""
tx = conn.transaction()
- assert_not_in_transaction(conn)
+ assert not in_transaction(conn)
with tx:
- assert_in_transaction(conn)
- assert_not_in_transaction(conn)
+ assert in_transaction(conn)
+ assert not in_transaction(conn)
def test_commit_on_successful_exit(conn):
with conn.transaction():
insert_row(conn, "foo")
- assert_not_in_transaction(conn)
- assert_rows(conn, {"foo"})
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"foo"}
def test_rollback_on_exception_exit(conn):
insert_row(conn, "foo")
raise ExpectedException("This discards the insert")
- assert_not_in_transaction(conn)
- assert_rows(conn, set())
+ assert not in_transaction(conn)
+ assert not inserted(conn)
def test_prohibits_use_of_commit_rollback_autocommit(conn):
* Changes made within Transaction context are committed
"""
conn.autocommit = False
- assert_not_in_transaction(conn)
+ assert not in_transaction(conn)
with conn.transaction():
insert_row(conn, "new")
- assert_not_in_transaction(conn)
+ assert not in_transaction(conn)
# Changes committed
- assert_rows(conn, {"new"})
- assert_rows(svcconn, {"new"})
+ assert inserted(conn) == {"new"}
+ assert inserted(svcconn) == {"new"}
def test_autocommit_off_but_no_tx_started_exception_exit(conn, svcconn):
* Changes made within Transaction context are discarded
"""
conn.autocommit = False
- assert_not_in_transaction(conn)
+ assert not in_transaction(conn)
with pytest.raises(ExpectedException):
with conn.transaction():
insert_row(conn, "new")
raise ExpectedException()
- assert_not_in_transaction(conn)
+ assert not in_transaction(conn)
# Changes discarded
- assert_rows(conn, set())
- assert_rows(svcconn, set())
+ assert not inserted(conn)
+ assert not inserted(svcconn)
def test_autocommit_off_and_tx_in_progress_successful_exit(conn, svcconn):
"""
conn.autocommit = False
insert_row(conn, "prior")
- assert_in_transaction(conn)
+ assert in_transaction(conn)
with conn.transaction():
insert_row(conn, "new")
- assert_in_transaction(conn)
- assert_rows(conn, {"prior", "new"})
+ assert in_transaction(conn)
+ assert inserted(conn) == {"prior", "new"}
# Nothing committed yet; changes not visible on another connection
- assert_rows(svcconn, set())
+ assert not inserted(svcconn)
def test_autocommit_off_and_tx_in_progress_exception_exit(conn, svcconn):
"""
conn.autocommit = False
insert_row(conn, "prior")
- assert_in_transaction(conn)
+ assert in_transaction(conn)
with pytest.raises(ExpectedException):
with conn.transaction():
insert_row(conn, "new")
raise ExpectedException()
- assert_in_transaction(conn)
- assert_rows(conn, {"prior"})
+ assert in_transaction(conn)
+ assert inserted(conn) == {"prior"}
# Nothing committed yet; changes not visible on another connection
- assert_rows(svcconn, set())
+ assert not inserted(svcconn)
def test_nested_all_changes_persisted_on_successful_exit(conn, svcconn):
with conn.transaction():
insert_row(conn, "inner")
insert_row(conn, "outer-after")
- assert_not_in_transaction(conn)
- assert_rows(conn, {"outer-before", "inner", "outer-after"})
- assert_rows(svcconn, {"outer-before", "inner", "outer-after"})
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"outer-before", "inner", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "inner", "outer-after"}
def test_nested_all_changes_discarded_on_outer_exception(conn, svcconn):
with conn.transaction():
insert_row(conn, "inner")
raise ExpectedException()
- assert_not_in_transaction(conn)
- assert_rows(conn, set())
- assert_rows(svcconn, set())
+ assert not in_transaction(conn)
+ assert not inserted(conn)
+ assert not inserted(svcconn)
def test_nested_all_changes_discarded_on_inner_exception(conn, svcconn):
with conn.transaction():
insert_row(conn, "inner")
raise ExpectedException()
- assert_not_in_transaction(conn)
- assert_rows(conn, set())
- assert_rows(svcconn, set())
+ assert not in_transaction(conn)
+ assert not inserted(conn)
+ assert not inserted(svcconn)
def test_nested_inner_scope_exception_handled_in_outer_scope(conn, svcconn):
insert_row(conn, "inner")
raise ExpectedException()
insert_row(conn, "outer-after")
- assert_not_in_transaction(conn)
- assert_rows(conn, {"outer-before", "outer-after"})
- assert_rows(svcconn, {"outer-before", "outer-after"})
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"outer-before", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
def test_nested_three_levels_successful_exit(conn, svcconn):
insert_row(conn, "two")
with conn.transaction(): # SAVEPOINT s2
insert_row(conn, "three")
- assert_not_in_transaction(conn)
- assert_rows(conn, {"one", "two", "three"})
- assert_rows(svcconn, {"one", "two", "three"})
+ assert not in_transaction(conn)
+ assert inserted(conn) == {"one", "two", "three"}
+ assert inserted(svcconn) == {"one", "two", "three"}
def test_named_savepoint_escapes_savepoint_name(conn):
pass
-def test_named_savepoints_successful_exit(conn):
+def test_named_savepoints_successful_exit(conn, commands):
"""
Entering a transaction context will do one of these these things:
1. Begin an outer transaction (if one isn't already in progress)
# Case 1
# Using Transaction explicitly becase conn.transaction() enters the contetx
tx = Transaction(conn)
- with assert_commands_issued(conn, "begin"):
- tx.__enter__()
+ assert not commands
+ tx.__enter__()
+ assert commands.pop() == "begin"
assert not tx.savepoint_name
- with assert_commands_issued(conn, "commit"):
- tx.__exit__(None, None, None)
+ tx.__exit__(None, None, None)
+ assert commands.pop() == "commit"
# Case 2
tx = Transaction(conn, savepoint_name="foo")
- with assert_commands_issued(conn, "begin", 'savepoint "foo"'):
- tx.__enter__()
+ tx.__enter__()
+ assert commands.pop() == "begin"
+ assert commands.pop() == 'savepoint "foo"'
assert tx.savepoint_name == "foo"
- with assert_commands_issued(conn, 'release savepoint "foo"', "commit"):
- tx.__exit__(None, None, None)
+ tx.__exit__(None, None, None)
+ assert commands.pop() == 'release savepoint "foo"'
+ assert commands.pop() == "commit"
# Case 3 (with savepoint name provided)
- with Transaction(conn):
+ with conn.transaction():
+ assert commands.pop() == "begin"
tx = Transaction(conn, savepoint_name="bar")
- with assert_commands_issued(conn, 'savepoint "bar"'):
- tx.__enter__()
+ tx.__enter__()
+ assert commands.pop() == 'savepoint "bar"'
assert tx.savepoint_name == "bar"
- with assert_commands_issued(conn, 'release savepoint "bar"'):
- tx.__exit__(None, None, None)
+ tx.__exit__(None, None, None)
+ assert commands.pop() == 'release savepoint "bar"'
+ assert commands.pop() == "commit"
# Case 3 (with savepoint name auto-generated)
with conn.transaction():
+ assert commands.pop() == "begin"
tx = Transaction(conn)
- with assert_commands_issued(conn, 'savepoint "s1"'):
- tx.__enter__()
+ tx.__enter__()
+ assert commands.pop() == 'savepoint "s1"'
assert tx.savepoint_name == "s1"
- with assert_commands_issued(conn, 'release savepoint "s1"'):
- tx.__exit__(None, None, None)
+ tx.__exit__(None, None, None)
+ assert commands.pop() == 'release savepoint "s1"'
+ assert commands.pop() == "commit"
+
+ assert not commands
-def test_named_savepoints_exception_exit(conn):
+def test_named_savepoints_exception_exit(conn, commands):
"""
Same as the previous test but checks that when exiting the context with an
exception, whatever transaction and/or savepoint was started on enter will
"""
# Case 1
tx = Transaction(conn)
- with assert_commands_issued(conn, "begin"):
- tx.__enter__()
+ tx.__enter__()
+ assert commands.pop() == "begin"
assert not tx.savepoint_name
- with assert_commands_issued(conn, "rollback"):
- tx.__exit__(*some_exc_info())
+ tx.__exit__(*some_exc_info())
+ assert commands.pop() == "rollback"
# Case 2
tx = Transaction(conn, savepoint_name="foo")
- with assert_commands_issued(conn, "begin", 'savepoint "foo"'):
- tx.__enter__()
+ tx.__enter__()
+ assert commands.pop() == "begin"
+ assert commands.pop() == 'savepoint "foo"'
assert tx.savepoint_name == "foo"
- with assert_commands_issued(
- conn,
- 'rollback to savepoint "foo"; release savepoint "foo"',
- "rollback",
- ):
- tx.__exit__(*some_exc_info())
+ tx.__exit__(*some_exc_info())
+ assert (
+ commands.pop()
+ == 'rollback to savepoint "foo"; release savepoint "foo"'
+ )
+ assert commands.pop() == "rollback"
# Case 3 (with savepoint name provided)
with conn.transaction():
+ assert commands.pop() == "begin"
tx = Transaction(conn, savepoint_name="bar")
- with assert_commands_issued(conn, 'savepoint "bar"'):
- tx.__enter__()
+ tx.__enter__()
+ assert commands.pop() == 'savepoint "bar"'
assert tx.savepoint_name == "bar"
- with assert_commands_issued(
- conn, 'rollback to savepoint "bar"; release savepoint "bar"'
- ):
- tx.__exit__(*some_exc_info())
+ tx.__exit__(*some_exc_info())
+ assert (
+ commands.pop()
+ == 'rollback to savepoint "bar"; release savepoint "bar"'
+ )
+ assert commands.pop() == "commit"
# Case 3 (with savepoint name auto-generated)
with conn.transaction():
+ assert commands.pop() == "begin"
tx = Transaction(conn)
- with assert_commands_issued(conn, 'savepoint "s1"'):
- tx.__enter__()
+ tx.__enter__()
+ assert commands.pop() == 'savepoint "s1"'
assert tx.savepoint_name == "s1"
- with assert_commands_issued(
- conn, 'rollback to savepoint "s1"; release savepoint "s1"'
- ):
- tx.__exit__(*some_exc_info())
+ tx.__exit__(*some_exc_info())
+ assert (
+ commands.pop()
+ == 'rollback to savepoint "s1"; release savepoint "s1"'
+ )
+ assert commands.pop() == "commit"
+
+ assert not commands
def test_named_savepoints_with_repeated_names_works(conn):
insert_row(conn, "tx2")
with conn.transaction("sp"):
insert_row(conn, "tx3")
- assert_rows(conn, {"tx1", "tx2", "tx3"})
+ assert inserted(conn) == {"tx1", "tx2", "tx3"}
# Works correctly if one level of inner transaction is rolled back
with conn.transaction(force_rollback=True):
insert_row(conn, "tx2")
with conn.transaction("s1"):
insert_row(conn, "tx3")
- assert_rows(conn, {"tx1"})
- assert_rows(conn, {"tx1"})
+ assert inserted(conn) == {"tx1"}
+ assert inserted(conn) == {"tx1"}
# Works correctly if multiple inner transactions are rolled back
# (This scenario mandates releasing savepoints after rolling back to them.)
with conn.transaction("s1"):
insert_row(conn, "tx3")
raise Rollback(tx2)
- assert_rows(conn, {"tx1"})
- assert_rows(conn, {"tx1"})
+ assert inserted(conn) == {"tx1"}
+ assert inserted(conn) == {"tx1"}
# Will not (always) catch out-of-order exits
with conn.transaction(force_rollback=True):
"""
with conn.transaction(force_rollback=True):
insert_row(conn, "foo")
- assert_rows(conn, set())
- assert_rows(svcconn, set())
+ assert not inserted(conn)
+ assert not inserted(svcconn)
def test_force_rollback_exception_exit(conn, svcconn):
with conn.transaction(force_rollback=True):
insert_row(conn, "foo")
raise ExpectedException()
- assert_rows(conn, set())
- assert_rows(svcconn, set())
+ assert not inserted(conn)
+ assert not inserted(svcconn)
def test_explicit_rollback_discards_changes(conn, svcconn):
"""
def assert_no_rows():
- assert_rows(conn, set())
- assert_rows(svcconn, set())
+ assert not inserted(conn)
+ assert not inserted(svcconn)
with conn.transaction():
insert_row(conn, "foo")
with conn.transaction():
insert_row(conn, "during")
raise Rollback
- assert_in_transaction(conn)
- assert_rows(svcconn, set())
+ assert in_transaction(conn)
+ assert not inserted(svcconn)
insert_row(conn, "after")
- assert_rows(conn, {"before", "after"})
- assert_rows(svcconn, {"before", "after"})
+ assert inserted(conn) == {"before", "after"}
+ assert inserted(svcconn) == {"before", "after"}
def test_explicit_rollback_of_outer_transaction(conn):
insert_row(conn, "inner")
raise Rollback(outer_tx)
assert False, "This line of code should be unreachable."
- assert_rows(conn, set())
+ assert not inserted(conn)
def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn):
raise Rollback(tx_enclosing)
insert_row(conn, "outer-after")
- assert_rows(conn, {"outer-before", "outer-after"})
- assert_rows(svcconn, set()) # Not yet committed
- assert_rows(svcconn, {"outer-before", "outer-after"}) # Changes committed
+ assert inserted(conn) == {"outer-before", "outer-after"}
+ assert not inserted(svcconn) # Not yet committed
+ # Changes committed
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
--- /dev/null
+import pytest
+
+from psycopg3 import AsyncConnection, ProgrammingError, Rollback
+from psycopg3.sql import Composable
+from psycopg3.transaction import AsyncTransaction
+
+from .test_transaction import (
+ in_transaction,
+ ExpectedException,
+ some_exc_info,
+ inserted,
+)
+from .test_transaction import create_test_table # noqa # autouse fixture
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+async def commands(monkeypatch):
+ """The queue of commands issued internally by connections.
+
+ Not concurrency safe as it mokkeypatches a class method, but good enough
+ for tests.
+ """
+ _orig_exec_command = AsyncConnection._exec_command
+ L = []
+
+ async def _exec_command(self, command):
+ if isinstance(command, bytes):
+ command = command.decode(self.client_encoding)
+ elif isinstance(command, Composable):
+ command = command.as_string(self)
+
+ L.insert(0, command)
+ await _orig_exec_command(self, command)
+
+ monkeypatch.setattr(AsyncConnection, "_exec_command", _exec_command)
+ yield L
+
+
+async def insert_row(aconn, value):
+ await (await aconn.cursor()).execute(
+ "INSERT INTO test_table VALUES (%s)", (value,)
+ )
+
+
+async def test_basic(aconn):
+ """Basic use of transaction() to BEGIN and COMMIT a transaction."""
+ assert not in_transaction(aconn)
+ async with aconn.transaction():
+ assert in_transaction(aconn)
+ assert not in_transaction(aconn)
+
+
+async def test_exposes_associated_connection(aconn):
+ """Transaction exposes its connection as a read-only property."""
+ async with aconn.transaction() as tx:
+ assert tx.connection is aconn
+ with pytest.raises(AttributeError):
+ tx.connection = aconn
+
+
+async def test_exposes_savepoint_name(aconn):
+ """Transaction exposes its savepoint name as a read-only property."""
+ async with aconn.transaction(savepoint_name="foo") as tx:
+ assert tx.savepoint_name == "foo"
+ with pytest.raises(AttributeError):
+ tx.savepoint_name = "bar"
+
+
+async def test_begins_on_enter(aconn):
+ """Transaction does not begin until __enter__() is called."""
+ tx = aconn.transaction()
+ assert not in_transaction(aconn)
+ async with tx:
+ assert in_transaction(aconn)
+ assert not in_transaction(aconn)
+
+
+async def test_commit_on_successful_exit(aconn):
+ """Changes are committed on successful exit from the `with` block."""
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"foo"}
+
+
+async def test_rollback_on_exception_exit(aconn):
+ """Changes are rolled back if an exception escapes the `with` block."""
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+ raise ExpectedException("This discards the insert")
+
+ assert not in_transaction(aconn)
+ assert not await inserted(aconn)
+
+
+async def test_prohibits_use_of_commit_rollback_autocommit(aconn):
+ """
+ Within a Transaction block, it is forbidden to touch commit, rollback,
+ or the autocommit setting on the connection, as this would interfere
+ with the transaction scope being managed by the Transaction block.
+ """
+ await aconn.set_autocommit(False)
+ await aconn.commit()
+ await aconn.rollback()
+
+ async with aconn.transaction():
+ with pytest.raises(ProgrammingError):
+ await aconn.set_autocommit(False)
+ with pytest.raises(ProgrammingError):
+ await aconn.commit()
+ with pytest.raises(ProgrammingError):
+ await aconn.rollback()
+
+ await aconn.set_autocommit(False)
+ await aconn.commit()
+ await aconn.rollback()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+async def test_preserves_autocommit(aconn, autocommit):
+ """
+ Connection.autocommit is unchanged both during and after Transaction block.
+ """
+ await aconn.set_autocommit(autocommit)
+ async with aconn.transaction():
+ assert aconn.autocommit is autocommit
+ assert aconn.autocommit is autocommit
+
+
+async def test_autocommit_off_but_no_tx_started_successful_exit(
+ aconn, svcconn
+):
+ """
+ Scenario:
+ * Connection has autocommit off but no transaction has been initiated
+ before entering the Transaction context
+ * Code exits Transaction context successfully
+
+ Outcome:
+ * Changes made within Transaction context are committed
+ """
+ await aconn.set_autocommit(False)
+ assert not in_transaction(aconn)
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ assert not in_transaction(aconn)
+
+ # Changes committed
+ assert await inserted(aconn) == {"new"}
+ assert inserted(svcconn) == {"new"}
+
+
+async def test_autocommit_off_but_no_tx_started_exception_exit(aconn, svcconn):
+ """
+ Scenario:
+ * Connection has autocommit off but no transaction has been initiated
+ before entering the Transaction context
+ * Code exits Transaction context with an exception
+
+ Outcome:
+ * Changes made within Transaction context are discarded
+ """
+ await aconn.set_autocommit(False)
+ assert not in_transaction(aconn)
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ raise ExpectedException()
+ assert not in_transaction(aconn)
+
+ # Changes discarded
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_autocommit_off_and_tx_in_progress_successful_exit(
+ aconn, svcconn
+):
+ """
+ Scenario:
+ * Connection has autocommit off but and a transaction is already in
+ progress before entering the Transaction context
+ * Code exits Transaction context successfully
+
+ Outcome:
+ * Changes made within Transaction context are left intact
+ * Outer transaction is left running, and no changes are visible to an
+ outside observer from another connection.
+ """
+ await aconn.set_autocommit(False)
+ await insert_row(aconn, "prior")
+ assert in_transaction(aconn)
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ assert in_transaction(aconn)
+ assert await inserted(aconn) == {"prior", "new"}
+ # Nothing committed yet; changes not visible on another connection
+ assert not inserted(svcconn)
+
+
+async def test_autocommit_off_and_tx_in_progress_exception_exit(
+ aconn, svcconn
+):
+ """
+ Scenario:
+ * Connection has autocommit off but and a transaction is already in
+ progress before entering the Transaction context
+ * Code exits Transaction context with an exception
+
+ Outcome:
+ * Changes made before the Transaction context are left intact
+ * Changes made within Transaction context are discarded
+ * Outer transaction is left running, and no changes are visible to an
+ outside observer from another connection.
+ """
+ await aconn.set_autocommit(False)
+ await insert_row(aconn, "prior")
+ assert in_transaction(aconn)
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "new")
+ raise ExpectedException()
+ assert in_transaction(aconn)
+ assert await inserted(aconn) == {"prior"}
+ # Nothing committed yet; changes not visible on another connection
+ assert not inserted(svcconn)
+
+
+async def test_nested_all_changes_persisted_on_successful_exit(aconn, svcconn):
+ """Changes from nested transaction contexts are all persisted on exit."""
+ async with aconn.transaction():
+ await insert_row(aconn, "outer-before")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ await insert_row(aconn, "outer-after")
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"outer-before", "inner", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "inner", "outer-after"}
+
+
+async def test_nested_all_changes_discarded_on_outer_exception(aconn, svcconn):
+ """
+ Changes from nested transaction contexts are discarded when an exception
+ raised in outer context escapes.
+ """
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "outer")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise ExpectedException()
+ assert not in_transaction(aconn)
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_nested_all_changes_discarded_on_inner_exception(aconn, svcconn):
+ """
+ Changes from nested transaction contexts are discarded when an exception
+ raised in inner context escapes the outer context.
+ """
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "outer")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise ExpectedException()
+ assert not in_transaction(aconn)
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_nested_inner_scope_exception_handled_in_outer_scope(
+ aconn, svcconn
+):
+ """
+ An exception escaping the inner transaction context causes changes made
+ within that inner context to be discarded, but the error can then be
+ handled in the outer context, allowing changes made in the outer context
+ (both before, and after, the inner context) to be successfully committed.
+ """
+ async with aconn.transaction():
+ await insert_row(aconn, "outer-before")
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise ExpectedException()
+ await insert_row(aconn, "outer-after")
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"outer-before", "outer-after"}
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+async def test_nested_three_levels_successful_exit(aconn, svcconn):
+ """Exercise management of more than one savepoint."""
+ async with aconn.transaction(): # BEGIN
+ await insert_row(aconn, "one")
+ async with aconn.transaction(): # SAVEPOINT s1
+ await insert_row(aconn, "two")
+ async with aconn.transaction(): # SAVEPOINT s2
+ await insert_row(aconn, "three")
+ assert not in_transaction(aconn)
+ assert await inserted(aconn) == {"one", "two", "three"}
+ assert inserted(svcconn) == {"one", "two", "three"}
+
+
+async def test_named_savepoint_escapes_savepoint_name(aconn):
+ async with aconn.transaction("s-1"):
+ pass
+ async with aconn.transaction("s1; drop table students"):
+ pass
+
+
+async def test_named_savepoints_successful_exit(aconn, commands):
+ """
+ Entering a transaction context will do one of these these things:
+ 1. Begin an outer transaction (if one isn't already in progress)
+ 2. Begin an outer transaction and create a savepoint (if one is named)
+ 3. Create a savepoint (if a transaction is already in progress)
+ either using the name provided, or auto-generating a savepoint name.
+
+ ...and exiting the context successfully will "commit" the same.
+ """
+ # Case 1
+ # Using Transaction explicitly becase conn.transaction() enters the contetx
+ tx = AsyncTransaction(aconn)
+ await tx.__aenter__()
+ assert commands.pop() == "begin"
+ assert not tx.savepoint_name
+ await tx.__aexit__(None, None, None)
+ assert commands.pop() == "commit"
+
+ # Case 2
+ tx = AsyncTransaction(aconn, savepoint_name="foo")
+ await tx.__aenter__()
+ assert commands.pop() == "begin"
+ assert commands.pop() == 'savepoint "foo"'
+ assert tx.savepoint_name == "foo"
+ await tx.__aexit__(None, None, None)
+ assert commands.pop() == 'release savepoint "foo"'
+ assert commands.pop() == "commit"
+
+ # Case 3 (with savepoint name provided)
+ async with aconn.transaction():
+ assert commands.pop() == "begin"
+ tx = AsyncTransaction(aconn, savepoint_name="bar")
+ await tx.__aenter__()
+ assert commands.pop() == 'savepoint "bar"'
+ assert tx.savepoint_name == "bar"
+ await tx.__aexit__(None, None, None)
+ assert commands.pop() == 'release savepoint "bar"'
+ assert commands.pop() == "commit"
+
+ # Case 3 (with savepoint name auto-generated)
+ async with aconn.transaction():
+ assert commands.pop() == "begin"
+ tx = AsyncTransaction(aconn)
+ await tx.__aenter__()
+ assert commands.pop() == 'savepoint "s1"'
+ assert tx.savepoint_name == "s1"
+ await tx.__aexit__(None, None, None)
+ assert commands.pop() == 'release savepoint "s1"'
+ assert commands.pop() == "commit"
+
+ assert not commands
+
+
+async def test_named_savepoints_exception_exit(aconn, commands):
+ """
+ Same as the previous test but checks that when exiting the context with an
+ exception, whatever transaction and/or savepoint was started on enter will
+ be rolled-back as appropriate.
+ """
+ # Case 1
+ tx = AsyncTransaction(aconn)
+ await tx.__aenter__()
+ assert commands.pop() == "begin"
+ assert not tx.savepoint_name
+ await tx.__aexit__(*some_exc_info())
+ assert commands.pop() == "rollback"
+
+ # Case 2
+ tx = AsyncTransaction(aconn, savepoint_name="foo")
+ await tx.__aenter__()
+ assert commands.pop() == "begin"
+ assert commands.pop() == 'savepoint "foo"'
+ assert tx.savepoint_name == "foo"
+ await tx.__aexit__(*some_exc_info())
+ assert (
+ commands.pop()
+ == 'rollback to savepoint "foo"; release savepoint "foo"'
+ )
+ assert commands.pop() == "rollback"
+
+ # Case 3 (with savepoint name provided)
+ async with aconn.transaction():
+ assert commands.pop() == "begin"
+ tx = AsyncTransaction(aconn, savepoint_name="bar")
+ await tx.__aenter__()
+ assert commands.pop() == 'savepoint "bar"'
+ assert tx.savepoint_name == "bar"
+ await tx.__aexit__(*some_exc_info())
+ assert (
+ commands.pop()
+ == 'rollback to savepoint "bar"; release savepoint "bar"'
+ )
+ assert commands.pop() == "commit"
+
+ # Case 3 (with savepoint name auto-generated)
+ async with aconn.transaction():
+ assert commands.pop() == "begin"
+ tx = AsyncTransaction(aconn)
+ await tx.__aenter__()
+ assert commands.pop() == 'savepoint "s1"'
+ assert tx.savepoint_name == "s1"
+ await tx.__aexit__(*some_exc_info())
+ assert (
+ commands.pop()
+ == 'rollback to savepoint "s1"; release savepoint "s1"'
+ )
+ assert commands.pop() == "commit"
+
+ assert not commands
+
+
+async def test_named_savepoints_with_repeated_names_works(aconn):
+ """
+ Using the same savepoint name repeatedly works correctly, but bypasses
+ some sanity checks.
+ """
+ # Works correctly if no inner transactions are rolled back
+ async with aconn.transaction(force_rollback=True):
+ async with aconn.transaction("sp"):
+ await insert_row(aconn, "tx1")
+ async with aconn.transaction("sp"):
+ await insert_row(aconn, "tx2")
+ async with aconn.transaction("sp"):
+ await insert_row(aconn, "tx3")
+ assert await inserted(aconn) == {"tx1", "tx2", "tx3"}
+
+ # Works correctly if one level of inner transaction is rolled back
+ async with aconn.transaction(force_rollback=True):
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx1")
+ async with aconn.transaction("s1", force_rollback=True):
+ await insert_row(aconn, "tx2")
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx3")
+ assert await inserted(aconn) == {"tx1"}
+ assert await inserted(aconn) == {"tx1"}
+
+ # Works correctly if multiple inner transactions are rolled back
+ # (This scenario mandates releasing savepoints after rolling back to them.)
+ async with aconn.transaction(force_rollback=True):
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx1")
+ async with aconn.transaction("s1") as tx2:
+ await insert_row(aconn, "tx2")
+ async with aconn.transaction("s1"):
+ await insert_row(aconn, "tx3")
+ raise Rollback(tx2)
+ assert await inserted(aconn) == {"tx1"}
+ assert await inserted(aconn) == {"tx1"}
+
+ # Will not (always) catch out-of-order exits
+ async with aconn.transaction(force_rollback=True):
+ tx1 = aconn.transaction("s1")
+ tx2 = aconn.transaction("s1")
+ await tx1.__aenter__()
+ await tx2.__aenter__()
+ await tx1.__aexit__(None, None, None)
+ await tx2.__aexit__(None, None, None)
+
+
+async def test_force_rollback_successful_exit(aconn, svcconn):
+ """
+ Transaction started with the force_rollback option enabled discards all
+ changes at the end of the context.
+ """
+ async with aconn.transaction(force_rollback=True):
+ await insert_row(aconn, "foo")
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_force_rollback_exception_exit(aconn, svcconn):
+ """
+ Transaction started with the force_rollback option enabled discards all
+ changes at the end of the context.
+ """
+ with pytest.raises(ExpectedException):
+ async with aconn.transaction(force_rollback=True):
+ await insert_row(aconn, "foo")
+ raise ExpectedException()
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+
+async def test_explicit_rollback_discards_changes(aconn, svcconn):
+ """
+ Raising a Rollback exception in the middle of a block exits the block and
+ discards all changes made within that block.
+
+ You can raise any of the following:
+ - Rollback (type)
+ - Rollback() (instance)
+ - Rollback(tx) (instance initialised with reference to the transaction)
+ All of these are equivalent.
+ """
+
+ async def assert_no_rows():
+ assert not await inserted(aconn)
+ assert not inserted(svcconn)
+
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+ raise Rollback
+ await assert_no_rows()
+
+ async with aconn.transaction():
+ await insert_row(aconn, "foo")
+ raise Rollback()
+ await assert_no_rows()
+
+ async with aconn.transaction() as tx:
+ await insert_row(aconn, "foo")
+ raise Rollback(tx)
+ await assert_no_rows()
+
+
+async def test_explicit_rollback_outer_tx_unaffected(aconn, svcconn):
+ """
+ Raising a Rollback exception in the middle of a block does not impact an
+ enclosing transaction block.
+ """
+ async with aconn.transaction():
+ await insert_row(aconn, "before")
+ async with aconn.transaction():
+ await insert_row(aconn, "during")
+ raise Rollback
+ assert in_transaction(aconn)
+ assert not inserted(svcconn)
+ await insert_row(aconn, "after")
+ assert await inserted(aconn) == {"before", "after"}
+ assert inserted(svcconn) == {"before", "after"}
+
+
+async def test_explicit_rollback_of_outer_transaction(aconn):
+ """
+ Raising a Rollback exception that references an outer transaction will
+ discard all changes from both inner and outer transaction blocks.
+ """
+ async with aconn.transaction() as outer_tx:
+ await insert_row(aconn, "outer")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise Rollback(outer_tx)
+ assert False, "This line of code should be unreachable."
+ assert not await inserted(aconn)
+
+
+async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(
+ aconn, svcconn
+):
+ """
+ Rolling-back an enclosing transaction does not impact an outer transaction.
+ """
+ async with aconn.transaction():
+ await insert_row(aconn, "outer-before")
+ async with aconn.transaction() as tx_enclosing:
+ await insert_row(aconn, "enclosing")
+ async with aconn.transaction():
+ await insert_row(aconn, "inner")
+ raise Rollback(tx_enclosing)
+ await insert_row(aconn, "outer-after")
+
+ assert await inserted(aconn) == {"outer-before", "outer-after"}
+ assert not inserted(svcconn) # Not yet committed
+ # Changes committed
+ assert inserted(svcconn) == {"outer-before", "outer-after"}
+
+
+@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
+@pytest.mark.parametrize("name", [None, "s1"])
+async def test_manual_enter_and_exit_out_of_order_exit_asserts(
+ aconn, name, exc_info
+):
+ """
+ When user is calling __enter__() and __exit__() manually for some reason,
+ provide a helpful error message if they call __exit__() in the wrong order
+ for nested transactions.
+ """
+ tx1, tx2 = AsyncTransaction(aconn, name), AsyncTransaction(aconn)
+ await tx1.__aenter__()
+ await tx2.__aenter__()
+ with pytest.raises(ProgrammingError, match="Out-of-order"):
+ await tx1.__aexit__(*exc_info)
+
+
+@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
+@pytest.mark.parametrize("name", [None, "s1"])
+async def test_manual_exit_without_enter_asserts(aconn, name, exc_info):
+ """
+ When user is calling __enter__() and __exit__() manually for some reason,
+ provide a helpful error message if they call __exit__() without first
+ having called __enter__()
+ """
+ tx = AsyncTransaction(aconn, name)
+ with pytest.raises(ProgrammingError, match="Out-of-order"):
+ await tx.__aexit__(*exc_info)
+
+
+@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
+@pytest.mark.parametrize("name", [None, "s1"])
+async def test_manual_exit_twice_asserts(aconn, name, exc_info):
+ """
+ When user is calling __enter__() and __exit__() manually for some reason,
+ provide a helpful error message if they accidentally call __exit__() twice.
+ """
+ tx = AsyncTransaction(aconn, name)
+ await tx.__aenter__()
+ await tx.__aexit__(*exc_info)
+ with pytest.raises(ProgrammingError, match="Out-of-order"):
+ await tx.__aexit__(*exc_info)