]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Define at init rather than at enter if the transaction is top-level
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 03:38:26 +0000 (03:38 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 04:02:30 +0000 (04:02 +0000)
psycopg3/psycopg3/transaction.py
tests/test_transaction.py
tests/test_transaction_async.py

index 8879d158250714b1095433ac8da2b660a8d15604..b159b65b04e2005b76521e49770e6220e6cda458 100644 (file)
@@ -49,8 +49,9 @@ class BaseTransaction(Generic[ConnectionType]):
         self._conn = connection
         self._savepoint_name = savepoint_name or ""
         self.force_rollback = force_rollback
-
-        self._outer_transaction: Optional[bool] = None
+        self._outer_transaction = (
+            connection.pgconn.transaction_status == TransactionStatus.IDLE
+        )
 
     @property
     def connection(self) -> ConnectionType:
@@ -89,15 +90,13 @@ class BaseTransaction(Generic[ConnectionType]):
 class Transaction(BaseTransaction["Connection"]):
     def __enter__(self) -> "Transaction":
         with self._conn.lock:
-            if self._conn.pgconn.transaction_status == TransactionStatus.IDLE:
+            if self._outer_transaction:
                 assert self._conn._savepoints is None, self._conn._savepoints
                 self._conn._savepoints = []
-                self._outer_transaction = True
                 self._conn._exec_command(b"begin")
             else:
                 if self._conn._savepoints is None:
                     self._conn._savepoints = []
-                self._outer_transaction = False
                 if not self._savepoint_name:
                     self._savepoint_name = (
                         f"s{len(self._conn._savepoints) + 1}"
@@ -118,8 +117,6 @@ class Transaction(BaseTransaction["Connection"]):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> bool:
-        if self._outer_transaction is None:
-            raise self._out_of_order_err
         with self._conn.lock:
             if not exc_val and not self.force_rollback:
                 return self._commit()
@@ -167,15 +164,13 @@ class Transaction(BaseTransaction["Connection"]):
 class AsyncTransaction(BaseTransaction["AsyncConnection"]):
     async def __aenter__(self) -> "AsyncTransaction":
         async with self._conn.lock:
-            if self._conn.pgconn.transaction_status == TransactionStatus.IDLE:
+            if self._outer_transaction:
                 assert self._conn._savepoints is None, self._conn._savepoints
                 self._conn._savepoints = []
-                self._outer_transaction = True
                 await self._conn._exec_command(b"begin")
             else:
                 if self._conn._savepoints is None:
                     self._conn._savepoints = []
-                self._outer_transaction = False
                 if not self._savepoint_name:
                     self._savepoint_name = (
                         f"s{len(self._conn._savepoints) + 1}"
@@ -196,8 +191,6 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]):
         exc_val: Optional[BaseException],
         exc_tb: Optional[TracebackType],
     ) -> bool:
-        if self._outer_transaction is None:
-            raise self._out_of_order_err
         async with self._conn.lock:
             if not exc_val and not self.force_rollback:
                 return await self._commit()
index 04e6f81436bc62dab5b5478647c1cc28a59d6c9b..28385ded2f98ca55a10226aaeda3db82786c33b8 100644 (file)
@@ -613,21 +613,6 @@ def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn):
     assert inserted(svcconn) == {"outer-before", "outer-after"}
 
 
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-def test_manual_enter_and_exit_out_of_order_exit_asserts(conn, name, exc_info):
-    """
-    When user is calling __enter__() and __exit__() manually for some reason,
-    provide a helpful error message if they call __exit__() in the wrong order
-    for nested transactions.
-    """
-    tx1, tx2 = Transaction(conn, name), Transaction(conn)
-    tx1.__enter__()
-    tx2.__enter__()
-    with pytest.raises(ProgrammingError, match="Out-of-order"):
-        tx1.__exit__(*exc_info)
-
-
 @pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
 @pytest.mark.parametrize("name", [None, "s1"])
 def test_manual_exit_without_enter_asserts(conn, name, exc_info):
index ceac50eef0d81f658203d5b90b83a8b203eca5d0..0958a812be6f102ba03a6d884a2d8549fcc0e9c1 100644 (file)
@@ -584,23 +584,6 @@ async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(
     assert inserted(svcconn) == {"outer-before", "outer-after"}
 
 
-@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
-@pytest.mark.parametrize("name", [None, "s1"])
-async def test_manual_enter_and_exit_out_of_order_exit_asserts(
-    aconn, name, exc_info
-):
-    """
-    When user is calling __enter__() and __exit__() manually for some reason,
-    provide a helpful error message if they call __exit__() in the wrong order
-    for nested transactions.
-    """
-    tx1, tx2 = AsyncTransaction(aconn, name), AsyncTransaction(aconn)
-    await tx1.__aenter__()
-    await tx2.__aenter__()
-    with pytest.raises(ProgrammingError, match="Out-of-order"):
-        await tx1.__aexit__(*exc_info)
-
-
 @pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()])
 @pytest.mark.parametrize("name", [None, "s1"])
 async def test_manual_exit_without_enter_asserts(aconn, name, exc_info):