From: Daniele Varrazzo Date: Sun, 15 Nov 2020 20:30:03 +0000 (+0000) Subject: Connection.transaction is a context manager X-Git-Tag: 3.0.dev0~351^2~10 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=964d5750b78893bb8ebe9b38790098781e7d9884;p=thirdparty%2Fpsycopg.git Connection.transaction is a context manager It will help to avoid an async with (await conn.transaction()) on async connections. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index ddebe879f..7badb6e78 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -12,6 +12,7 @@ from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple from typing import Optional, Type, TYPE_CHECKING, Union from weakref import ref, ReferenceType from functools import partial +from contextlib import contextmanager from . import pq from . import cursor @@ -328,12 +329,15 @@ class Connection(BaseConnection): f" {pq.error_message(results[-1], encoding=self.client_encoding)}" ) + @contextmanager def transaction( self, savepoint_name: Optional[str] = None, force_rollback: bool = False, - ) -> Transaction: - return Transaction(self, savepoint_name, force_rollback) + ) -> Iterator[Transaction]: + tx = Transaction(self, savepoint_name, force_rollback) + with tx: + yield tx @classmethod def wait(cls, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV: diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index 783ebce57..0c3721107 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -39,8 +39,8 @@ class Transaction: def __init__( self, connection: "Connection", - savepoint_name: Optional[str], - force_rollback: bool, + savepoint_name: Optional[str] = None, + force_rollback: bool = False, ): self._conn = connection self._savepoint_name: Optional[str] = None diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 955543210..84a092db3 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -5,6 +5,7 @@ import pytest from psycopg3 import ProgrammingError, Rollback from psycopg3.sql import Composable +from psycopg3.transaction import Transaction @pytest.fixture(autouse=True) @@ -332,7 +333,8 @@ def test_named_savepoint_empty_string_invalid(conn): invalid SQL command and having that fail with an OperationalError). """ with pytest.raises(ValueError): - conn.transaction(savepoint_name="") + with conn.transaction(savepoint_name=""): + pass def test_named_savepoint_escapes_savepoint_name(conn): @@ -353,7 +355,8 @@ def test_named_savepoints_successful_exit(conn): ...and exiting the context successfully will "commit" the same. """ # Case 1 - tx = conn.transaction() + # Using Transaction explicitly becase conn.transaction() enters the contetx + tx = Transaction(conn) with assert_commands_issued(conn, "begin"): tx.__enter__() assert tx.savepoint_name is None @@ -361,7 +364,7 @@ def test_named_savepoints_successful_exit(conn): tx.__exit__(None, None, None) # Case 2 - tx = conn.transaction(savepoint_name="foo") + tx = Transaction(conn, savepoint_name="foo") with assert_commands_issued(conn, "begin", 'savepoint "foo"'): tx.__enter__() assert tx.savepoint_name == "foo" @@ -369,8 +372,8 @@ def test_named_savepoints_successful_exit(conn): tx.__exit__(None, None, None) # Case 3 (with savepoint name provided) - with conn.transaction(): - tx = conn.transaction(savepoint_name="bar") + with Transaction(conn): + tx = Transaction(conn, savepoint_name="bar") with assert_commands_issued(conn, 'savepoint "bar"'): tx.__enter__() assert tx.savepoint_name == "bar" @@ -379,7 +382,7 @@ def test_named_savepoints_successful_exit(conn): # Case 3 (with savepoint name auto-generated) with conn.transaction(): - tx = conn.transaction() + tx = Transaction(conn) with assert_commands_issued(conn, 'savepoint "s1"'): tx.__enter__() assert tx.savepoint_name == "s1" @@ -394,7 +397,7 @@ def test_named_savepoints_exception_exit(conn): be rolled-back as appropriate. """ # Case 1 - tx = conn.transaction() + tx = Transaction(conn) with assert_commands_issued(conn, "begin"): tx.__enter__() assert tx.savepoint_name is None @@ -402,7 +405,7 @@ def test_named_savepoints_exception_exit(conn): tx.__exit__(*some_exc_info()) # Case 2 - tx = conn.transaction(savepoint_name="foo") + tx = Transaction(conn, savepoint_name="foo") with assert_commands_issued(conn, "begin", 'savepoint "foo"'): tx.__enter__() assert tx.savepoint_name == "foo" @@ -415,7 +418,7 @@ def test_named_savepoints_exception_exit(conn): # Case 3 (with savepoint name provided) with conn.transaction(): - tx = conn.transaction(savepoint_name="bar") + tx = Transaction(conn, savepoint_name="bar") with assert_commands_issued(conn, 'savepoint "bar"'): tx.__enter__() assert tx.savepoint_name == "bar" @@ -426,7 +429,7 @@ def test_named_savepoints_exception_exit(conn): # Case 3 (with savepoint name auto-generated) with conn.transaction(): - tx = conn.transaction() + tx = Transaction(conn) with assert_commands_issued(conn, 'savepoint "s1"'): tx.__enter__() assert tx.savepoint_name == "s1" @@ -520,18 +523,26 @@ def test_explicit_rollback_discards_changes(conn, svcconn): - Rollback(tx) (instance initialised with reference to the transaction) All of these are equivalent. """ - tx = conn.transaction() - for to_raise in ( - Rollback, - Rollback(), - Rollback(tx), - ): - with tx: - insert_row(conn, "foo") - raise to_raise - assert_rows(conn, set("")) + + def assert_no_rows(): + assert_rows(conn, set()) assert_rows(svcconn, set()) + with conn.transaction(): + insert_row(conn, "foo") + raise Rollback + assert_no_rows() + + with conn.transaction(): + insert_row(conn, "foo") + raise Rollback() + assert_no_rows() + + with conn.transaction() as tx: + insert_row(conn, "foo") + raise Rollback(tx) + assert_no_rows() + def test_explicit_rollback_outer_tx_unaffected(conn, svcconn): """ @@ -555,8 +566,7 @@ def test_explicit_rollback_of_outer_transaction(conn): Raising a Rollback exception that references an outer transaction will discard all changes from both inner and outer transaction blocks. """ - outer_tx = conn.transaction() - with outer_tx: + with conn.transaction() as outer_tx: insert_row(conn, "outer") with conn.transaction(): insert_row(conn, "inner") @@ -591,7 +601,7 @@ def test_manual_enter_and_exit_out_of_order_exit_asserts(conn, name, exc_info): provide a helpful error message if they call __exit__() in the wrong order for nested transactions. """ - tx1, tx2 = conn.transaction(name), conn.transaction() + tx1, tx2 = Transaction(conn, name), Transaction(conn) tx1.__enter__() tx2.__enter__() with pytest.raises(ProgrammingError, match="Out-of-order"): @@ -606,7 +616,7 @@ def test_manual_exit_without_enter_asserts(conn, name, exc_info): provide a helpful error message if they call __exit__() without first having called __enter__() """ - tx = conn.transaction(name) + tx = Transaction(conn, name) with pytest.raises(ProgrammingError, match="Out-of-order"): tx.__exit__(*exc_info) @@ -618,7 +628,7 @@ def test_manual_exit_twice_asserts(conn, name, exc_info): When user is calling __enter__() and __exit__() manually for some reason, provide a helpful error message if they accidentally call __exit__() twice. """ - tx = conn.transaction(name) + tx = Transaction(conn, name) tx.__enter__() tx.__exit__(*exc_info) with pytest.raises(ProgrammingError, match="Out-of-order"):