From: Daniele Varrazzo Date: Mon, 16 Nov 2020 03:38:26 +0000 (+0000) Subject: Define at init rather than at enter if the transaction is top-level X-Git-Tag: 3.0.dev0~351^2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=73c8c9f01f4f95730725b53f91024aa9a76aa9a9;p=thirdparty%2Fpsycopg.git Define at init rather than at enter if the transaction is top-level --- diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index 8879d1582..b159b65b0 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -49,8 +49,9 @@ class BaseTransaction(Generic[ConnectionType]): self._conn = connection self._savepoint_name = savepoint_name or "" self.force_rollback = force_rollback - - self._outer_transaction: Optional[bool] = None + self._outer_transaction = ( + connection.pgconn.transaction_status == TransactionStatus.IDLE + ) @property def connection(self) -> ConnectionType: @@ -89,15 +90,13 @@ class BaseTransaction(Generic[ConnectionType]): class Transaction(BaseTransaction["Connection"]): def __enter__(self) -> "Transaction": with self._conn.lock: - if self._conn.pgconn.transaction_status == TransactionStatus.IDLE: + if self._outer_transaction: assert self._conn._savepoints is None, self._conn._savepoints self._conn._savepoints = [] - self._outer_transaction = True self._conn._exec_command(b"begin") else: if self._conn._savepoints is None: self._conn._savepoints = [] - self._outer_transaction = False if not self._savepoint_name: self._savepoint_name = ( f"s{len(self._conn._savepoints) + 1}" @@ -118,8 +117,6 @@ class Transaction(BaseTransaction["Connection"]): exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> bool: - if self._outer_transaction is None: - raise self._out_of_order_err with self._conn.lock: if not exc_val and not self.force_rollback: return self._commit() @@ -167,15 +164,13 @@ class Transaction(BaseTransaction["Connection"]): class AsyncTransaction(BaseTransaction["AsyncConnection"]): async def __aenter__(self) -> "AsyncTransaction": async with self._conn.lock: - if self._conn.pgconn.transaction_status == TransactionStatus.IDLE: + if self._outer_transaction: assert self._conn._savepoints is None, self._conn._savepoints self._conn._savepoints = [] - self._outer_transaction = True await self._conn._exec_command(b"begin") else: if self._conn._savepoints is None: self._conn._savepoints = [] - self._outer_transaction = False if not self._savepoint_name: self._savepoint_name = ( f"s{len(self._conn._savepoints) + 1}" @@ -196,8 +191,6 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]): exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> bool: - if self._outer_transaction is None: - raise self._out_of_order_err async with self._conn.lock: if not exc_val and not self.force_rollback: return await self._commit() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 04e6f8143..28385ded2 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -613,21 +613,6 @@ def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn): assert inserted(svcconn) == {"outer-before", "outer-after"} -@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) -@pytest.mark.parametrize("name", [None, "s1"]) -def test_manual_enter_and_exit_out_of_order_exit_asserts(conn, name, exc_info): - """ - When user is calling __enter__() and __exit__() manually for some reason, - provide a helpful error message if they call __exit__() in the wrong order - for nested transactions. - """ - tx1, tx2 = Transaction(conn, name), Transaction(conn) - tx1.__enter__() - tx2.__enter__() - with pytest.raises(ProgrammingError, match="Out-of-order"): - tx1.__exit__(*exc_info) - - @pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) @pytest.mark.parametrize("name", [None, "s1"]) def test_manual_exit_without_enter_asserts(conn, name, exc_info): diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py index ceac50eef..0958a812b 100644 --- a/tests/test_transaction_async.py +++ b/tests/test_transaction_async.py @@ -584,23 +584,6 @@ async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected( assert inserted(svcconn) == {"outer-before", "outer-after"} -@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) -@pytest.mark.parametrize("name", [None, "s1"]) -async def test_manual_enter_and_exit_out_of_order_exit_asserts( - aconn, name, exc_info -): - """ - When user is calling __enter__() and __exit__() manually for some reason, - provide a helpful error message if they call __exit__() in the wrong order - for nested transactions. - """ - tx1, tx2 = AsyncTransaction(aconn, name), AsyncTransaction(aconn) - await tx1.__aenter__() - await tx2.__aenter__() - with pytest.raises(ProgrammingError, match="Out-of-order"): - await tx1.__aexit__(*exc_info) - - @pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) @pytest.mark.parametrize("name", [None, "s1"]) async def test_manual_exit_without_enter_asserts(aconn, name, exc_info):