From: Denis Laxalde Date: Thu, 2 Dec 2021 14:13:08 +0000 (+0100) Subject: Avoid multiple commands in transaction code X-Git-Tag: 3.1~146^2~14 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cadcdc0cb0742d48f283513d73516579a767e8d5;p=thirdparty%2Fpsycopg.git Avoid multiple commands in transaction code In pipeline mode, command strings containing multiple SQL commands are disallowed so we remove all such usages from transaction code. Accordingly, all generator methods in transaction classes now do not return anything (the result was not used previously anyways). In tests, the 'commands' list defined in patch_exec() is now filled by appending instead of inserting so that we keep the natural order of commands in assertions. --- diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py index 0c3b82c9a..b8e36867e 100644 --- a/psycopg/psycopg/transaction.py +++ b/psycopg/psycopg/transaction.py @@ -7,14 +7,13 @@ Transaction context managers returned by Connection.transaction() import logging from types import TracebackType -from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING +from typing import Generic, Iterator, Optional, Type, Union, TYPE_CHECKING from . import pq from . import sql from . import errors as e from .pq import TransactionStatus, ConnStatus from .abc import ConnectionType, PQGen -from .pq.abc import PGresult if TYPE_CHECKING: from typing import Any @@ -85,14 +84,14 @@ class BaseTransaction(Generic[ConnectionType]): sp = f"{self.savepoint_name!r} " if self.savepoint_name else "" return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>" - def _enter_gen(self) -> PQGen[PGresult]: + def _enter_gen(self) -> PQGen[None]: if self._entered: raise TypeError("transaction blocks can be used only once") self._entered = True self._push_savepoint() - commands = self._get_enter_commands() - return self._conn._exec_command(b"; ".join(commands)) + for command in self._get_enter_commands(): + yield from self._conn._exec_command(command) def _exit_gen( self, @@ -119,14 +118,14 @@ class BaseTransaction(Generic[ConnectionType]): logger.warning("error ignored in rollback of %s: %s", self, exc2) return False - def _commit_gen(self) -> PQGen[PGresult]: + def _commit_gen(self) -> PQGen[None]: ex = self._pop_savepoint("commit") self._exited = True if ex: raise ex - commands = self._get_commit_commands() - return self._conn._exec_command(b"; ".join(commands)) + for command in self._get_commit_commands(): + yield from self._conn._exec_command(command) def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]: if isinstance(exc_val, Rollback): @@ -137,8 +136,8 @@ class BaseTransaction(Generic[ConnectionType]): if ex: raise ex - commands = self._get_rollback_commands() - yield from self._conn._exec_command(b"; ".join(commands)) + for command in self._get_rollback_commands(): + yield from self._conn._exec_command(command) if isinstance(exc_val, Rollback): if not exc_val.transaction or exc_val.transaction is self: @@ -146,23 +145,20 @@ class BaseTransaction(Generic[ConnectionType]): return False - def _get_enter_commands(self) -> List[bytes]: - commands = [] + def _get_enter_commands(self) -> Iterator[bytes]: if self._outer_transaction: - commands.append(self._conn._get_tx_start_command()) + yield self._conn._get_tx_start_command() if self._savepoint_name: - commands.append( + yield ( sql.SQL("SAVEPOINT {}") .format(sql.Identifier(self._savepoint_name)) .as_bytes(self._conn) ) - return commands - def _get_commit_commands(self) -> List[bytes]: - commands = [] + def _get_commit_commands(self) -> Iterator[bytes]: if self._savepoint_name and not self._outer_transaction: - commands.append( + yield ( sql.SQL("RELEASE {}") .format(sql.Identifier(self._savepoint_name)) .as_bytes(self._conn) @@ -170,29 +166,28 @@ class BaseTransaction(Generic[ConnectionType]): if self._outer_transaction: assert not self._conn._num_transactions - commands.append(b"COMMIT") + yield b"COMMIT" - return commands - - def _get_rollback_commands(self) -> List[bytes]: - commands = [] + def _get_rollback_commands(self) -> Iterator[bytes]: if self._savepoint_name and not self._outer_transaction: - commands.append( - sql.SQL("ROLLBACK TO {n}; RELEASE {n}") + yield ( + sql.SQL("ROLLBACK TO {n}") + .format(n=sql.Identifier(self._savepoint_name)) + .as_bytes(self._conn) + ) + yield ( + sql.SQL("RELEASE {n}") .format(n=sql.Identifier(self._savepoint_name)) .as_bytes(self._conn) ) if self._outer_transaction: assert not self._conn._num_transactions - commands.append(b"ROLLBACK") + yield b"ROLLBACK" # Also clear the prepared statements cache. if self._conn._prepared.clear(): - for cmd in self._conn._prepared.get_maintenance_commands(): - commands.append(cmd) - - return commands + yield from self._conn._prepared.get_maintenance_commands() def _push_savepoint(self) -> None: """ diff --git a/tests/fix_db.py b/tests/fix_db.py index ab4a6dd3b..1da888341 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -203,7 +203,7 @@ def patch_exec(conn, monkeypatch): elif isinstance(cmdcopy, sql.Composable): cmdcopy = cmdcopy.as_string(conn) - L.insert(0, cmdcopy) + L.append(cmdcopy) return _orig_exec_command(command, *args, **kwargs) monkeypatch.setattr(conn, "_exec_command", _exec_command) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 68d25df6b..802ea0577 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -426,7 +426,7 @@ def test_named_savepoints_successful_exit(conn, commands): # Case 2 with conn.transaction(savepoint_name="foo") as tx: - assert commands.popall() == ['BEGIN; SAVEPOINT "foo"'] + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] assert tx.savepoint_name == "foo" assert commands.popall() == ["COMMIT"] @@ -466,7 +466,7 @@ def test_named_savepoints_exception_exit(conn, commands): # Case 2 with pytest.raises(ExpectedException): with conn.transaction(savepoint_name="foo") as tx: - assert commands.popall() == ['BEGIN; SAVEPOINT "foo"'] + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] assert tx.savepoint_name == "foo" raise ExpectedException assert commands.popall() == ["ROLLBACK"] @@ -479,7 +479,7 @@ def test_named_savepoints_exception_exit(conn, commands): assert commands.popall() == ['SAVEPOINT "bar"'] assert tx.savepoint_name == "bar" raise ExpectedException - assert commands.popall() == ['ROLLBACK TO "bar"; RELEASE "bar"'] + assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"'] assert commands.popall() == ["COMMIT"] # Case 3 (with savepoint name auto-generated) @@ -490,7 +490,10 @@ def test_named_savepoints_exception_exit(conn, commands): assert commands.popall() == ['SAVEPOINT "_pg3_2"'] assert tx.savepoint_name == "_pg3_2" raise ExpectedException - assert commands.popall() == ['ROLLBACK TO "_pg3_2"; RELEASE "_pg3_2"'] + assert commands.popall() == [ + 'ROLLBACK TO "_pg3_2"', + 'RELEASE "_pg3_2"', + ] assert commands.popall() == ["COMMIT"] diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py index d6832b98f..3c01c4631 100644 --- a/tests/test_transaction_async.py +++ b/tests/test_transaction_async.py @@ -371,7 +371,7 @@ async def test_named_savepoints_successful_exit(aconn, acommands): # Case 2 async with aconn.transaction(savepoint_name="foo") as tx: - assert commands.popall() == ['BEGIN; SAVEPOINT "foo"'] + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] assert tx.savepoint_name == "foo" assert commands.popall() == ["COMMIT"] @@ -413,7 +413,7 @@ async def test_named_savepoints_exception_exit(aconn, acommands): # Case 2 with pytest.raises(ExpectedException): async with aconn.transaction(savepoint_name="foo") as tx: - assert commands.popall() == ['BEGIN; SAVEPOINT "foo"'] + assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"'] assert tx.savepoint_name == "foo" raise ExpectedException assert commands.popall() == ["ROLLBACK"] @@ -426,7 +426,7 @@ async def test_named_savepoints_exception_exit(aconn, acommands): assert commands.popall() == ['SAVEPOINT "bar"'] assert tx.savepoint_name == "bar" raise ExpectedException - assert commands.popall() == ['ROLLBACK TO "bar"; RELEASE "bar"'] + assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"'] assert commands.popall() == ["COMMIT"] # Case 3 (with savepoint name auto-generated) @@ -437,7 +437,10 @@ async def test_named_savepoints_exception_exit(aconn, acommands): assert commands.popall() == ['SAVEPOINT "_pg3_2"'] assert tx.savepoint_name == "_pg3_2" raise ExpectedException - assert commands.popall() == ['ROLLBACK TO "_pg3_2"; RELEASE "_pg3_2"'] + assert commands.popall() == [ + 'ROLLBACK TO "_pg3_2"', + 'RELEASE "_pg3_2"', + ] assert commands.popall() == ["COMMIT"]