]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Avoid multiple commands in transaction code
authorDenis Laxalde <denis.laxalde@dalibo.com>
Thu, 2 Dec 2021 14:13:08 +0000 (15:13 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Apr 2022 23:17:57 +0000 (01:17 +0200)
In pipeline mode, command strings containing multiple SQL commands are
disallowed so we remove all such usages from transaction code.

Accordingly, all generator methods in transaction classes now do not
return anything (the result was not used previously anyways).

In tests, the 'commands' list defined in patch_exec() is now filled by
appending instead of inserting so that we keep the natural order of
commands in assertions.

psycopg/psycopg/transaction.py
tests/fix_db.py
tests/test_transaction.py
tests/test_transaction_async.py

index 0c3b82c9a362f1d8730132f3e1338a1cf215ee6c..b8e36867e1cc0f9eca2541c150e7cb041eb67478 100644 (file)
@@ -7,14 +7,13 @@ Transaction context managers returned by Connection.transaction()
 import logging
 
 from types import TracebackType
-from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING
+from typing import Generic, Iterator, Optional, Type, Union, TYPE_CHECKING
 
 from . import pq
 from . import sql
 from . import errors as e
 from .pq import TransactionStatus, ConnStatus
 from .abc import ConnectionType, PQGen
-from .pq.abc import PGresult
 
 if TYPE_CHECKING:
     from typing import Any
@@ -85,14 +84,14 @@ class BaseTransaction(Generic[ConnectionType]):
         sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
         return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"
 
-    def _enter_gen(self) -> PQGen[PGresult]:
+    def _enter_gen(self) -> PQGen[None]:
         if self._entered:
             raise TypeError("transaction blocks can be used only once")
         self._entered = True
 
         self._push_savepoint()
-        commands = self._get_enter_commands()
-        return self._conn._exec_command(b"; ".join(commands))
+        for command in self._get_enter_commands():
+            yield from self._conn._exec_command(command)
 
     def _exit_gen(
         self,
@@ -119,14 +118,14 @@ class BaseTransaction(Generic[ConnectionType]):
                 logger.warning("error ignored in rollback of %s: %s", self, exc2)
                 return False
 
-    def _commit_gen(self) -> PQGen[PGresult]:
+    def _commit_gen(self) -> PQGen[None]:
         ex = self._pop_savepoint("commit")
         self._exited = True
         if ex:
             raise ex
 
-        commands = self._get_commit_commands()
-        return self._conn._exec_command(b"; ".join(commands))
+        for command in self._get_commit_commands():
+            yield from self._conn._exec_command(command)
 
     def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
         if isinstance(exc_val, Rollback):
@@ -137,8 +136,8 @@ class BaseTransaction(Generic[ConnectionType]):
         if ex:
             raise ex
 
-        commands = self._get_rollback_commands()
-        yield from self._conn._exec_command(b"; ".join(commands))
+        for command in self._get_rollback_commands():
+            yield from self._conn._exec_command(command)
 
         if isinstance(exc_val, Rollback):
             if not exc_val.transaction or exc_val.transaction is self:
@@ -146,23 +145,20 @@ class BaseTransaction(Generic[ConnectionType]):
 
         return False
 
-    def _get_enter_commands(self) -> List[bytes]:
-        commands = []
+    def _get_enter_commands(self) -> Iterator[bytes]:
         if self._outer_transaction:
-            commands.append(self._conn._get_tx_start_command())
+            yield self._conn._get_tx_start_command()
 
         if self._savepoint_name:
-            commands.append(
+            yield (
                 sql.SQL("SAVEPOINT {}")
                 .format(sql.Identifier(self._savepoint_name))
                 .as_bytes(self._conn)
             )
-        return commands
 
-    def _get_commit_commands(self) -> List[bytes]:
-        commands = []
+    def _get_commit_commands(self) -> Iterator[bytes]:
         if self._savepoint_name and not self._outer_transaction:
-            commands.append(
+            yield (
                 sql.SQL("RELEASE {}")
                 .format(sql.Identifier(self._savepoint_name))
                 .as_bytes(self._conn)
@@ -170,29 +166,28 @@ class BaseTransaction(Generic[ConnectionType]):
 
         if self._outer_transaction:
             assert not self._conn._num_transactions
-            commands.append(b"COMMIT")
+            yield b"COMMIT"
 
-        return commands
-
-    def _get_rollback_commands(self) -> List[bytes]:
-        commands = []
+    def _get_rollback_commands(self) -> Iterator[bytes]:
         if self._savepoint_name and not self._outer_transaction:
-            commands.append(
-                sql.SQL("ROLLBACK TO {n}; RELEASE {n}")
+            yield (
+                sql.SQL("ROLLBACK TO {n}")
+                .format(n=sql.Identifier(self._savepoint_name))
+                .as_bytes(self._conn)
+            )
+            yield (
+                sql.SQL("RELEASE {n}")
                 .format(n=sql.Identifier(self._savepoint_name))
                 .as_bytes(self._conn)
             )
 
         if self._outer_transaction:
             assert not self._conn._num_transactions
-            commands.append(b"ROLLBACK")
+            yield b"ROLLBACK"
 
         # Also clear the prepared statements cache.
         if self._conn._prepared.clear():
-            for cmd in self._conn._prepared.get_maintenance_commands():
-                commands.append(cmd)
-
-        return commands
+            yield from self._conn._prepared.get_maintenance_commands()
 
     def _push_savepoint(self) -> None:
         """
index ab4a6dd3b9a0102bc1a01a546716809f41beaaec..1da888341e3014a094024596f81fc27c485be597 100644 (file)
@@ -203,7 +203,7 @@ def patch_exec(conn, monkeypatch):
         elif isinstance(cmdcopy, sql.Composable):
             cmdcopy = cmdcopy.as_string(conn)
 
-        L.insert(0, cmdcopy)
+        L.append(cmdcopy)
         return _orig_exec_command(command, *args, **kwargs)
 
     monkeypatch.setattr(conn, "_exec_command", _exec_command)
index 68d25df6b7a0bcba2af83b7de6218b0c73400f01..802ea0577c536561b0784040b9cc0b9678c26fb4 100644 (file)
@@ -426,7 +426,7 @@ def test_named_savepoints_successful_exit(conn, commands):
 
     # Case 2
     with conn.transaction(savepoint_name="foo") as tx:
-        assert commands.popall() == ['BEGIN; SAVEPOINT "foo"']
+        assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
         assert tx.savepoint_name == "foo"
     assert commands.popall() == ["COMMIT"]
 
@@ -466,7 +466,7 @@ def test_named_savepoints_exception_exit(conn, commands):
     # Case 2
     with pytest.raises(ExpectedException):
         with conn.transaction(savepoint_name="foo") as tx:
-            assert commands.popall() == ['BEGIN; SAVEPOINT "foo"']
+            assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
             assert tx.savepoint_name == "foo"
             raise ExpectedException
     assert commands.popall() == ["ROLLBACK"]
@@ -479,7 +479,7 @@ def test_named_savepoints_exception_exit(conn, commands):
                 assert commands.popall() == ['SAVEPOINT "bar"']
                 assert tx.savepoint_name == "bar"
                 raise ExpectedException
-        assert commands.popall() == ['ROLLBACK TO "bar"RELEASE "bar"']
+        assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"']
     assert commands.popall() == ["COMMIT"]
 
     # Case 3 (with savepoint name auto-generated)
@@ -490,7 +490,10 @@ def test_named_savepoints_exception_exit(conn, commands):
                 assert commands.popall() == ['SAVEPOINT "_pg3_2"']
                 assert tx.savepoint_name == "_pg3_2"
                 raise ExpectedException
-        assert commands.popall() == ['ROLLBACK TO "_pg3_2"; RELEASE "_pg3_2"']
+        assert commands.popall() == [
+            'ROLLBACK TO "_pg3_2"',
+            'RELEASE "_pg3_2"',
+        ]
     assert commands.popall() == ["COMMIT"]
 
 
index d6832b98f9dffc9ed53da3fee0383c8bff76f9b5..3c01c463153b17c138ad7bb79b74493c809b61b8 100644 (file)
@@ -371,7 +371,7 @@ async def test_named_savepoints_successful_exit(aconn, acommands):
 
     # Case 2
     async with aconn.transaction(savepoint_name="foo") as tx:
-        assert commands.popall() == ['BEGIN; SAVEPOINT "foo"']
+        assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
         assert tx.savepoint_name == "foo"
     assert commands.popall() == ["COMMIT"]
 
@@ -413,7 +413,7 @@ async def test_named_savepoints_exception_exit(aconn, acommands):
     # Case 2
     with pytest.raises(ExpectedException):
         async with aconn.transaction(savepoint_name="foo") as tx:
-            assert commands.popall() == ['BEGIN; SAVEPOINT "foo"']
+            assert commands.popall() == ["BEGIN", 'SAVEPOINT "foo"']
             assert tx.savepoint_name == "foo"
             raise ExpectedException
     assert commands.popall() == ["ROLLBACK"]
@@ -426,7 +426,7 @@ async def test_named_savepoints_exception_exit(aconn, acommands):
                 assert commands.popall() == ['SAVEPOINT "bar"']
                 assert tx.savepoint_name == "bar"
                 raise ExpectedException
-        assert commands.popall() == ['ROLLBACK TO "bar"RELEASE "bar"']
+        assert commands.popall() == ['ROLLBACK TO "bar"', 'RELEASE "bar"']
     assert commands.popall() == ["COMMIT"]
 
     # Case 3 (with savepoint name auto-generated)
@@ -437,7 +437,10 @@ async def test_named_savepoints_exception_exit(aconn, acommands):
                 assert commands.popall() == ['SAVEPOINT "_pg3_2"']
                 assert tx.savepoint_name == "_pg3_2"
                 raise ExpectedException
-        assert commands.popall() == ['ROLLBACK TO "_pg3_2"; RELEASE "_pg3_2"']
+        assert commands.popall() == [
+            'ROLLBACK TO "_pg3_2"',
+            'RELEASE "_pg3_2"',
+        ]
     assert commands.popall() == ["COMMIT"]