]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Connection.transaction is a context manager
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 15 Nov 2020 20:30:03 +0000 (20:30 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 04:01:14 +0000 (04:01 +0000)
It will help to avoid an async with (await conn.transaction()) on async
connections.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/transaction.py
tests/test_transaction.py

index ddebe879f4f43f564ea8ce0001ae1df9cfef98af..7badb6e78848e8e63aed60254e8543f6268dd0c9 100644 (file)
@@ -12,6 +12,7 @@ from typing import Any, AsyncIterator, Callable, Iterator, List, NamedTuple
 from typing import Optional, Type, TYPE_CHECKING, Union
 from weakref import ref, ReferenceType
 from functools import partial
+from contextlib import contextmanager
 
 from . import pq
 from . import cursor
@@ -328,12 +329,15 @@ class Connection(BaseConnection):
                 f" {pq.error_message(results[-1], encoding=self.client_encoding)}"
             )
 
+    @contextmanager
     def transaction(
         self,
         savepoint_name: Optional[str] = None,
         force_rollback: bool = False,
-    ) -> Transaction:
-        return Transaction(self, savepoint_name, force_rollback)
+    ) -> Iterator[Transaction]:
+        tx = Transaction(self, savepoint_name, force_rollback)
+        with tx:
+            yield tx
 
     @classmethod
     def wait(cls, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
index 783ebce57ee9b1edd0d90af4e70055f90907344c..0c37211077fc3b4c1df52e4ffc9814e430260002 100644 (file)
@@ -39,8 +39,8 @@ class Transaction:
     def __init__(
         self,
         connection: "Connection",
-        savepoint_name: Optional[str],
-        force_rollback: bool,
+        savepoint_name: Optional[str] = None,
+        force_rollback: bool = False,
     ):
         self._conn = connection
         self._savepoint_name: Optional[str] = None
index 95554321057cac47d1f374ff6e1315a9cc9d7582..84a092db3547f09983216e34c754cf89db5b0f5a 100644 (file)
@@ -5,6 +5,7 @@ import pytest
 
 from psycopg3 import ProgrammingError, Rollback
 from psycopg3.sql import Composable
+from psycopg3.transaction import Transaction
 
 
 @pytest.fixture(autouse=True)
@@ -332,7 +333,8 @@ def test_named_savepoint_empty_string_invalid(conn):
     invalid SQL command and having that fail with an OperationalError).
     """
     with pytest.raises(ValueError):
-        conn.transaction(savepoint_name="")
+        with conn.transaction(savepoint_name=""):
+            pass
 
 
 def test_named_savepoint_escapes_savepoint_name(conn):
@@ -353,7 +355,8 @@ def test_named_savepoints_successful_exit(conn):
     ...and exiting the context successfully will "commit" the same.
     """
     # Case 1
-    tx = conn.transaction()
+    # Using Transaction explicitly becase conn.transaction() enters the contetx
+    tx = Transaction(conn)
     with assert_commands_issued(conn, "begin"):
         tx.__enter__()
     assert tx.savepoint_name is None
@@ -361,7 +364,7 @@ def test_named_savepoints_successful_exit(conn):
         tx.__exit__(None, None, None)
 
     # Case 2
-    tx = conn.transaction(savepoint_name="foo")
+    tx = Transaction(conn, savepoint_name="foo")
     with assert_commands_issued(conn, "begin", 'savepoint "foo"'):
         tx.__enter__()
     assert tx.savepoint_name == "foo"
@@ -369,8 +372,8 @@ def test_named_savepoints_successful_exit(conn):
         tx.__exit__(None, None, None)
 
     # Case 3 (with savepoint name provided)
-    with conn.transaction():
-        tx = conn.transaction(savepoint_name="bar")
+    with Transaction(conn):
+        tx = Transaction(conn, savepoint_name="bar")
         with assert_commands_issued(conn, 'savepoint "bar"'):
             tx.__enter__()
         assert tx.savepoint_name == "bar"
@@ -379,7 +382,7 @@ def test_named_savepoints_successful_exit(conn):
 
     # Case 3 (with savepoint name auto-generated)
     with conn.transaction():
-        tx = conn.transaction()
+        tx = Transaction(conn)
         with assert_commands_issued(conn, 'savepoint "s1"'):
             tx.__enter__()
         assert tx.savepoint_name == "s1"
@@ -394,7 +397,7 @@ def test_named_savepoints_exception_exit(conn):
     be rolled-back as appropriate.
     """
     # Case 1
-    tx = conn.transaction()
+    tx = Transaction(conn)
     with assert_commands_issued(conn, "begin"):
         tx.__enter__()
     assert tx.savepoint_name is None
@@ -402,7 +405,7 @@ def test_named_savepoints_exception_exit(conn):
         tx.__exit__(*some_exc_info())
 
     # Case 2
-    tx = conn.transaction(savepoint_name="foo")
+    tx = Transaction(conn, savepoint_name="foo")
     with assert_commands_issued(conn, "begin", 'savepoint "foo"'):
         tx.__enter__()
     assert tx.savepoint_name == "foo"
@@ -415,7 +418,7 @@ def test_named_savepoints_exception_exit(conn):
 
     # Case 3 (with savepoint name provided)
     with conn.transaction():
-        tx = conn.transaction(savepoint_name="bar")
+        tx = Transaction(conn, savepoint_name="bar")
         with assert_commands_issued(conn, 'savepoint "bar"'):
             tx.__enter__()
         assert tx.savepoint_name == "bar"
@@ -426,7 +429,7 @@ def test_named_savepoints_exception_exit(conn):
 
     # Case 3 (with savepoint name auto-generated)
     with conn.transaction():
-        tx = conn.transaction()
+        tx = Transaction(conn)
         with assert_commands_issued(conn, 'savepoint "s1"'):
             tx.__enter__()
         assert tx.savepoint_name == "s1"
@@ -520,18 +523,26 @@ def test_explicit_rollback_discards_changes(conn, svcconn):
      - Rollback(tx) (instance initialised with reference to the transaction)
     All of these are equivalent.
     """
-    tx = conn.transaction()
-    for to_raise in (
-        Rollback,
-        Rollback(),
-        Rollback(tx),
-    ):
-        with tx:
-            insert_row(conn, "foo")
-            raise to_raise
-        assert_rows(conn, set(""))
+
+    def assert_no_rows():
+        assert_rows(conn, set())
         assert_rows(svcconn, set())
 
+    with conn.transaction():
+        insert_row(conn, "foo")
+        raise Rollback
+    assert_no_rows()
+
+    with conn.transaction():
+        insert_row(conn, "foo")
+        raise Rollback()
+    assert_no_rows()
+
+    with conn.transaction() as tx:
+        insert_row(conn, "foo")
+        raise Rollback(tx)
+    assert_no_rows()
+
 
 def test_explicit_rollback_outer_tx_unaffected(conn, svcconn):
     """
@@ -555,8 +566,7 @@ def test_explicit_rollback_of_outer_transaction(conn):
     Raising a Rollback exception that references an outer transaction will
     discard all changes from both inner and outer transaction blocks.
     """
-    outer_tx = conn.transaction()
-    with outer_tx:
+    with conn.transaction() as outer_tx:
         insert_row(conn, "outer")
         with conn.transaction():
             insert_row(conn, "inner")
@@ -591,7 +601,7 @@ def test_manual_enter_and_exit_out_of_order_exit_asserts(conn, name, exc_info):
     provide a helpful error message if they call __exit__() in the wrong order
     for nested transactions.
     """
-    tx1, tx2 = conn.transaction(name), conn.transaction()
+    tx1, tx2 = Transaction(conn, name), Transaction(conn)
     tx1.__enter__()
     tx2.__enter__()
     with pytest.raises(ProgrammingError, match="Out-of-order"):
@@ -606,7 +616,7 @@ def test_manual_exit_without_enter_asserts(conn, name, exc_info):
     provide a helpful error message if they call __exit__() without first
     having called __enter__()
     """
-    tx = conn.transaction(name)
+    tx = Transaction(conn, name)
     with pytest.raises(ProgrammingError, match="Out-of-order"):
         tx.__exit__(*exc_info)
 
@@ -618,7 +628,7 @@ def test_manual_exit_twice_asserts(conn, name, exc_info):
     When user is calling __enter__() and __exit__() manually for some reason,
     provide a helpful error message if they accidentally call __exit__() twice.
     """
-    tx = conn.transaction(name)
+    tx = Transaction(conn, name)
     tx.__enter__()
     tx.__exit__(*exc_info)
     with pytest.raises(ProgrammingError, match="Out-of-order"):