From: Daniele Varrazzo Date: Mon, 16 Nov 2020 16:26:38 +0000 (+0000) Subject: Fixed transaction behaviour when there is a transaction already started X-Git-Tag: 3.0.dev0~351^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=57399571c76e1f146204618df3c484a94d1ac7a6;p=thirdparty%2Fpsycopg.git Fixed transaction behaviour when there is a transaction already started State management simplified too. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 016af7f01..2dd262de4 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -106,10 +106,10 @@ class BaseConnection: self._notice_handlers: List[NoticeHandler] = [] self._notify_handlers: List[NotifyHandler] = [] - # stack of savepoint names managed by active Transaction() blocks - self._savepoints: Optional[List[str]] = None - # (None when there no active Transaction blocks; [] when there is only - # one Transaction block, with a top-level transaction and no savepoint) + # Stack of savepoint names managed by current transaction blocks. + # the first item is "" in case the outermost Transaction must manage + # only a begin/commit and not a savepoint. + self._savepoints: List[str] = [] wself = ref(self) @@ -135,7 +135,7 @@ class BaseConnection: # subclasses must call it holding a lock status = self.pgconn.transaction_status if status != TransactionStatus.IDLE: - if self._savepoints is not None: + if self._savepoints: raise e.ProgrammingError( "couldn't change autocommit state: " "connection.transaction() context in progress" @@ -295,7 +295,7 @@ class Connection(BaseConnection): def commit(self) -> None: """Commit any pending transaction to the database.""" with self.lock: - if self._savepoints is not None: + if self._savepoints: raise e.ProgrammingError( "Explicit commit() forbidden within a Transaction " "context. (Transaction will be automatically committed " @@ -308,7 +308,7 @@ class Connection(BaseConnection): def rollback(self) -> None: """Roll back to the start of any pending transaction.""" with self.lock: - if self._savepoints is not None: + if self._savepoints: raise e.ProgrammingError( "Explicit rollback() forbidden within a Transaction " "context. (Either raise Transaction.Rollback() or allow " @@ -447,7 +447,7 @@ class AsyncConnection(BaseConnection): async with self.lock: if self.pgconn.transaction_status == TransactionStatus.IDLE: return - if self._savepoints is not None: + if self._savepoints: raise e.ProgrammingError( "Explicit commit() forbidden within a Transaction " "context. (Transaction will be automatically committed " @@ -459,7 +459,7 @@ class AsyncConnection(BaseConnection): async with self.lock: if self.pgconn.transaction_status == TransactionStatus.IDLE: return - if self._savepoints is not None: + if self._savepoints: raise e.ProgrammingError( "Explicit rollback() forbidden within a Transaction " "context. (Either raise Transaction.Rollback() or allow " diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index b02d412bc..8e11ab04b 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -12,7 +12,6 @@ from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING from . import sql from .pq import TransactionStatus from .proto import ConnectionType -from .errors import ProgrammingError if TYPE_CHECKING: from .connection import Connection, AsyncConnection # noqa: F401 @@ -47,18 +46,28 @@ class BaseTransaction(Generic[ConnectionType]): force_rollback: bool = False, ): self._conn = connection - self._savepoint_name = savepoint_name or "" self.force_rollback = force_rollback - self._outer_transaction = ( - connection.pgconn.transaction_status == TransactionStatus.IDLE - ) + self._yolo = True + + if connection.pgconn.transaction_status == TransactionStatus.IDLE: + # outer transaction: if no name it's only a begin, else + # there will be an additional savepoint + self._outer_transaction = True + assert not connection._savepoints + self._savepoint_name = savepoint_name or "" + else: + # inner transaction: it always has a name + self._outer_transaction = False + self._savepoint_name = ( + savepoint_name or f"s{len(self._conn._savepoints) + 1}" + ) @property def connection(self) -> ConnectionType: return self._conn @property - def savepoint_name(self) -> str: + def savepoint_name(self) -> Optional[str]: return self._savepoint_name def __repr__(self) -> str: @@ -69,23 +78,14 @@ class BaseTransaction(Generic[ConnectionType]): args.append("force_rollback=True") return f"{self.__class__.__qualname__}({', '.join(args)})" - _out_of_order_err = ProgrammingError( - "Out-of-order Transaction context exits. Are you " - "calling __exit__() manually and getting it wrong?" - ) - def _enter_commands(self) -> List[str]: - commands = [] + assert self._yolo + self._yolo = False + commands = [] if self._outer_transaction: - assert self._conn._savepoints is None, self._conn._savepoints - self._conn._savepoints = [] + assert not self._conn._savepoints, self._conn._savepoints commands.append("begin") - else: - if self._conn._savepoints is None: - self._conn._savepoints = [] - if not self._savepoint_name: - self._savepoint_name = f"s{len(self._conn._savepoints) + 1}" if self._savepoint_name: commands.append( @@ -93,52 +93,46 @@ class BaseTransaction(Generic[ConnectionType]): .format(sql.Identifier(self._savepoint_name)) .as_string(self._conn) ) - self._conn._savepoints.append(self._savepoint_name) + self._conn._savepoints.append(self._savepoint_name) return commands def _commit_commands(self) -> List[str]: - commands = [] + assert self._conn._savepoints[-1] == self._savepoint_name + self._conn._savepoints.pop() - self._pop_savepoint() + commands = [] if self._savepoint_name and not self._outer_transaction: commands.append( sql.SQL("release savepoint {}") .format(sql.Identifier(self._savepoint_name)) .as_string(self._conn) ) + if self._outer_transaction: + assert not self._conn._savepoints commands.append("commit") return commands def _rollback_commands(self) -> List[str]: - commands = [] + assert self._conn._savepoints[-1] == self._savepoint_name + self._conn._savepoints.pop() - self._pop_savepoint() + commands = [] if self._savepoint_name and not self._outer_transaction: commands.append( sql.SQL("rollback to savepoint {n}; release savepoint {n}") .format(n=sql.Identifier(self._savepoint_name)) .as_string(self._conn) ) + if self._outer_transaction: + assert not self._conn._savepoints commands.append("rollback") return commands - def _pop_savepoint(self) -> None: - if self._savepoint_name: - if self._conn._savepoints is None: - raise self._out_of_order_err - actual = self._conn._savepoints.pop() - if actual != self._savepoint_name: - raise self._out_of_order_err - if self._outer_transaction: - if self._conn._savepoints is None or self._conn._savepoints: - raise self._out_of_order_err - self._conn._savepoints = None - class Transaction(BaseTransaction["Connection"]): def __enter__(self) -> "Transaction": diff --git a/tests/test_transaction.py b/tests/test_transaction.py index dfa66ab25..c5d8890bd 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -135,6 +135,21 @@ def test_rollback_on_exception_exit(conn): assert not inserted(conn) +def test_interaction_dbapi_transaction(conn): + insert_row(conn, "foo") + + with conn.transaction(): + insert_row(conn, "bar") + raise Rollback + + with conn.transaction(): + insert_row(conn, "baz") + + assert in_transaction(conn) + conn.commit() + assert inserted(conn) == {"foo", "baz"} + + def test_prohibits_use_of_commit_rollback_autocommit(conn): """ Within a Transaction block, it is forbidden to touch commit, rollback, @@ -365,6 +380,20 @@ def test_named_savepoints_successful_exit(conn, commands): tx.__exit__(None, None, None) assert commands.pop() == "commit" + # Case 1 (with a transaction already started) + conn.cursor().execute("select 1") + assert commands.pop() == "begin" + tx = Transaction(conn) + tx.__enter__() + assert commands.pop() == 'savepoint "s1"' + assert tx.savepoint_name == "s1" + tx.__exit__(None, None, None) + assert commands.pop() == 'release savepoint "s1"' + assert not commands + conn.rollback() + assert commands.pop() == "rollback" + assert not commands + # Case 2 tx = Transaction(conn, savepoint_name="foo") tx.__enter__() @@ -390,10 +419,10 @@ def test_named_savepoints_successful_exit(conn, commands): assert commands.pop() == "begin" tx = Transaction(conn) tx.__enter__() - assert commands.pop() == 'savepoint "s1"' - assert tx.savepoint_name == "s1" + assert commands.pop() == 'savepoint "s2"' + assert tx.savepoint_name == "s2" tx.__exit__(None, None, None) - assert commands.pop() == 'release savepoint "s1"' + assert commands.pop() == 'release savepoint "s2"' assert not commands assert commands.pop() == "commit" @@ -442,12 +471,12 @@ def test_named_savepoints_exception_exit(conn, commands): assert commands.pop() == "begin" tx = Transaction(conn) tx.__enter__() - assert commands.pop() == 'savepoint "s1"' - assert tx.savepoint_name == "s1" + assert commands.pop() == 'savepoint "s2"' + assert tx.savepoint_name == "s2" tx.__exit__(*some_exc_info()) assert ( commands.pop() - == 'rollback to savepoint "s1"; release savepoint "s1"' + == 'rollback to savepoint "s2"; release savepoint "s2"' ) assert not commands assert commands.pop() == "commit" @@ -608,30 +637,3 @@ def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(conn, svcconn): assert not inserted(svcconn) # Not yet committed # Changes committed assert inserted(svcconn) == {"outer-before", "outer-after"} - - -@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) -@pytest.mark.parametrize("name", [None, "s1"]) -def test_manual_exit_without_enter_asserts(conn, name, exc_info): - """ - When user is calling __enter__() and __exit__() manually for some reason, - provide a helpful error message if they call __exit__() without first - having called __enter__() - """ - tx = Transaction(conn, name) - with pytest.raises(ProgrammingError, match="Out-of-order"): - tx.__exit__(*exc_info) - - -@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) -@pytest.mark.parametrize("name", [None, "s1"]) -def test_manual_exit_twice_asserts(conn, name, exc_info): - """ - When user is calling __enter__() and __exit__() manually for some reason, - provide a helpful error message if they accidentally call __exit__() twice. - """ - tx = Transaction(conn, name) - tx.__enter__() - tx.__exit__(*exc_info) - with pytest.raises(ProgrammingError, match="Out-of-order"): - tx.__exit__(*exc_info) diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py index c434fa6a0..d4a3889d6 100644 --- a/tests/test_transaction_async.py +++ b/tests/test_transaction_async.py @@ -97,6 +97,21 @@ async def test_rollback_on_exception_exit(aconn): assert not await inserted(aconn) +async def test_interaction_dbapi_transaction(aconn): + await insert_row(aconn, "foo") + + async with aconn.transaction(): + await insert_row(aconn, "bar") + raise Rollback + + async with aconn.transaction(): + await insert_row(aconn, "baz") + + assert in_transaction(aconn) + await aconn.commit() + assert await inserted(aconn) == {"foo", "baz"} + + async def test_prohibits_use_of_commit_rollback_autocommit(aconn): """ Within a Transaction block, it is forbidden to touch commit, rollback, @@ -334,6 +349,20 @@ async def test_named_savepoints_successful_exit(aconn, commands): await tx.__aexit__(None, None, None) assert commands.pop() == "commit" + # Case 1 (with a transaction already started) + await (await aconn.cursor()).execute("select 1") + assert commands.pop() == "begin" + tx = AsyncTransaction(aconn) + await tx.__aenter__() + assert commands.pop() == 'savepoint "s1"' + assert tx.savepoint_name == "s1" + await tx.__aexit__(None, None, None) + assert commands.pop() == 'release savepoint "s1"' + assert not commands + await aconn.rollback() + assert commands.pop() == "rollback" + assert not commands + # Case 2 tx = AsyncTransaction(aconn, savepoint_name="foo") await tx.__aenter__() @@ -359,10 +388,10 @@ async def test_named_savepoints_successful_exit(aconn, commands): assert commands.pop() == "begin" tx = AsyncTransaction(aconn) await tx.__aenter__() - assert commands.pop() == 'savepoint "s1"' - assert tx.savepoint_name == "s1" + assert commands.pop() == 'savepoint "s2"' + assert tx.savepoint_name == "s2" await tx.__aexit__(None, None, None) - assert commands.pop() == 'release savepoint "s1"' + assert commands.pop() == 'release savepoint "s2"' assert not commands assert commands.pop() == "commit" @@ -411,12 +440,12 @@ async def test_named_savepoints_exception_exit(aconn, commands): assert commands.pop() == "begin" tx = AsyncTransaction(aconn) await tx.__aenter__() - assert commands.pop() == 'savepoint "s1"' - assert tx.savepoint_name == "s1" + assert commands.pop() == 'savepoint "s2"' + assert tx.savepoint_name == "s2" await tx.__aexit__(*some_exc_info()) assert ( commands.pop() - == 'rollback to savepoint "s1"; release savepoint "s1"' + == 'rollback to savepoint "s2"; release savepoint "s2"' ) assert not commands assert commands.pop() == "commit" @@ -579,30 +608,3 @@ async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected( assert not inserted(svcconn) # Not yet committed # Changes committed assert inserted(svcconn) == {"outer-before", "outer-after"} - - -@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) -@pytest.mark.parametrize("name", [None, "s1"]) -async def test_manual_exit_without_enter_asserts(aconn, name, exc_info): - """ - When user is calling __enter__() and __exit__() manually for some reason, - provide a helpful error message if they call __exit__() without first - having called __enter__() - """ - tx = AsyncTransaction(aconn, name) - with pytest.raises(ProgrammingError, match="Out-of-order"): - await tx.__aexit__(*exc_info) - - -@pytest.mark.parametrize("exc_info", [(None, None, None), some_exc_info()]) -@pytest.mark.parametrize("name", [None, "s1"]) -async def test_manual_exit_twice_asserts(aconn, name, exc_info): - """ - When user is calling __enter__() and __exit__() manually for some reason, - provide a helpful error message if they accidentally call __exit__() twice. - """ - tx = AsyncTransaction(aconn, name) - await tx.__aenter__() - await tx.__aexit__(*exc_info) - with pytest.raises(ProgrammingError, match="Out-of-order"): - await tx.__aexit__(*exc_info)