]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Commands to enter/exit a transaction with named savepoint tweaked
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 13:56:44 +0000 (13:56 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 13:59:51 +0000 (13:59 +0000)
- Batch together begin + savepoint
- Do without releasing the savepoint immediately before commit/rollback

psycopg3/psycopg3/transaction.py
tests/test_transaction.py
tests/test_transaction_async.py

index a24d7874fdfb3e7ca24ba57dc8b058e6207bf9e0..b02d412bc1a41175551315649e1a655c068ce6d5 100644 (file)
@@ -11,7 +11,7 @@ from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING
 
 from . import sql
 from .pq import TransactionStatus
-from .proto import ConnectionType, Query
+from .proto import ConnectionType
 from .errors import ProgrammingError
 
 if TYPE_CHECKING:
@@ -74,13 +74,13 @@ class BaseTransaction(Generic[ConnectionType]):
         "calling __exit__() manually and getting it wrong?"
     )
 
-    def _enter_commands(self) -> List[Query]:
-        commands: List[Query] = []
+    def _enter_commands(self) -> List[str]:
+        commands = []
 
         if self._outer_transaction:
             assert self._conn._savepoints is None, self._conn._savepoints
             self._conn._savepoints = []
-            commands.append(b"begin")
+            commands.append("begin")
         else:
             if self._conn._savepoints is None:
                 self._conn._savepoints = []
@@ -89,41 +89,41 @@ class BaseTransaction(Generic[ConnectionType]):
 
         if self._savepoint_name:
             commands.append(
-                sql.SQL("savepoint {}").format(
-                    sql.Identifier(self._savepoint_name)
-                )
+                sql.SQL("savepoint {}")
+                .format(sql.Identifier(self._savepoint_name))
+                .as_string(self._conn)
             )
             self._conn._savepoints.append(self._savepoint_name)
 
         return commands
 
-    def _commit_commands(self) -> List[Query]:
-        commands: List[Query] = []
+    def _commit_commands(self) -> List[str]:
+        commands = []
 
         self._pop_savepoint()
-        if self._savepoint_name:
+        if self._savepoint_name and not self._outer_transaction:
             commands.append(
-                sql.SQL("release savepoint {}").format(
-                    sql.Identifier(self._savepoint_name)
-                )
+                sql.SQL("release savepoint {}")
+                .format(sql.Identifier(self._savepoint_name))
+                .as_string(self._conn)
             )
         if self._outer_transaction:
-            commands.append(b"commit")
+            commands.append("commit")
 
         return commands
 
-    def _rollback_commands(self) -> List[Query]:
-        commands: List[Query] = []
+    def _rollback_commands(self) -> List[str]:
+        commands = []
 
         self._pop_savepoint()
-        if self._savepoint_name:
+        if self._savepoint_name and not self._outer_transaction:
             commands.append(
-                sql.SQL(
-                    "rollback to savepoint {n}; release savepoint {n}"
-                ).format(n=sql.Identifier(self._savepoint_name))
+                sql.SQL("rollback to savepoint {n}; release savepoint {n}")
+                .format(n=sql.Identifier(self._savepoint_name))
+                .as_string(self._conn)
             )
         if self._outer_transaction:
-            commands.append(b"rollback")
+            commands.append("rollback")
 
         return commands
 
@@ -178,9 +178,8 @@ class Transaction(BaseTransaction["Connection"]):
 
         return False
 
-    def _execute(self, commands: List[Query]) -> None:
-        for command in commands:
-            self._conn._exec_command(command)
+    def _execute(self, commands: List[str]) -> None:
+        self._conn._exec_command("; ".join(commands))
 
 
 class AsyncTransaction(BaseTransaction["AsyncConnection"]):
@@ -222,6 +221,5 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]):
 
         return False
 
-    async def _execute(self, commands: List[Query]) -> None:
-        for command in commands:
-            await self._conn._exec_command(command)
+    async def _execute(self, commands: List[str]) -> None:
+        await self._conn._exec_command("; ".join(commands))
index 28385ded2f98ca55a10226aaeda3db82786c33b8..dfa66ab25f2239a8df7100bb371fa58ed4f1651e 100644 (file)
@@ -368,11 +368,9 @@ def test_named_savepoints_successful_exit(conn, commands):
     # Case 2
     tx = Transaction(conn, savepoint_name="foo")
     tx.__enter__()
-    assert commands.pop() == "begin"
-    assert commands.pop() == 'savepoint "foo"'
+    assert commands.pop() == 'begin; savepoint "foo"'
     assert tx.savepoint_name == "foo"
     tx.__exit__(None, None, None)
-    assert commands.pop() == 'release savepoint "foo"'
     assert commands.pop() == "commit"
 
     # Case 3 (with savepoint name provided)
@@ -384,6 +382,7 @@ def test_named_savepoints_successful_exit(conn, commands):
         assert tx.savepoint_name == "bar"
         tx.__exit__(None, None, None)
         assert commands.pop() == 'release savepoint "bar"'
+        assert not commands
     assert commands.pop() == "commit"
 
     # Case 3 (with savepoint name auto-generated)
@@ -395,6 +394,7 @@ def test_named_savepoints_successful_exit(conn, commands):
         assert tx.savepoint_name == "s1"
         tx.__exit__(None, None, None)
         assert commands.pop() == 'release savepoint "s1"'
+        assert not commands
     assert commands.pop() == "commit"
 
     assert not commands
@@ -417,14 +417,9 @@ def test_named_savepoints_exception_exit(conn, commands):
     # Case 2
     tx = Transaction(conn, savepoint_name="foo")
     tx.__enter__()
-    assert commands.pop() == "begin"
-    assert commands.pop() == 'savepoint "foo"'
+    assert commands.pop() == 'begin; savepoint "foo"'
     assert tx.savepoint_name == "foo"
     tx.__exit__(*some_exc_info())
-    assert (
-        commands.pop()
-        == 'rollback to savepoint "foo"; release savepoint "foo"'
-    )
     assert commands.pop() == "rollback"
 
     # Case 3 (with savepoint name provided)
@@ -439,6 +434,7 @@ def test_named_savepoints_exception_exit(conn, commands):
             commands.pop()
             == 'rollback to savepoint "bar"; release savepoint "bar"'
         )
+        assert not commands
     assert commands.pop() == "commit"
 
     # Case 3 (with savepoint name auto-generated)
@@ -453,6 +449,7 @@ def test_named_savepoints_exception_exit(conn, commands):
             commands.pop()
             == 'rollback to savepoint "s1"; release savepoint "s1"'
         )
+        assert not commands
     assert commands.pop() == "commit"
 
     assert not commands
index 0958a812be6f102ba03a6d884a2d8549fcc0e9c1..c434fa6a01c2be80c1dad3fd7496461245c32b75 100644 (file)
@@ -337,11 +337,9 @@ async def test_named_savepoints_successful_exit(aconn, commands):
     # Case 2
     tx = AsyncTransaction(aconn, savepoint_name="foo")
     await tx.__aenter__()
-    assert commands.pop() == "begin"
-    assert commands.pop() == 'savepoint "foo"'
+    assert commands.pop() == 'begin; savepoint "foo"'
     assert tx.savepoint_name == "foo"
     await tx.__aexit__(None, None, None)
-    assert commands.pop() == 'release savepoint "foo"'
     assert commands.pop() == "commit"
 
     # Case 3 (with savepoint name provided)
@@ -353,6 +351,7 @@ async def test_named_savepoints_successful_exit(aconn, commands):
         assert tx.savepoint_name == "bar"
         await tx.__aexit__(None, None, None)
         assert commands.pop() == 'release savepoint "bar"'
+        assert not commands
     assert commands.pop() == "commit"
 
     # Case 3 (with savepoint name auto-generated)
@@ -364,6 +363,7 @@ async def test_named_savepoints_successful_exit(aconn, commands):
         assert tx.savepoint_name == "s1"
         await tx.__aexit__(None, None, None)
         assert commands.pop() == 'release savepoint "s1"'
+        assert not commands
     assert commands.pop() == "commit"
 
     assert not commands
@@ -386,14 +386,9 @@ async def test_named_savepoints_exception_exit(aconn, commands):
     # Case 2
     tx = AsyncTransaction(aconn, savepoint_name="foo")
     await tx.__aenter__()
-    assert commands.pop() == "begin"
-    assert commands.pop() == 'savepoint "foo"'
+    assert commands.pop() == 'begin; savepoint "foo"'
     assert tx.savepoint_name == "foo"
     await tx.__aexit__(*some_exc_info())
-    assert (
-        commands.pop()
-        == 'rollback to savepoint "foo"; release savepoint "foo"'
-    )
     assert commands.pop() == "rollback"
 
     # Case 3 (with savepoint name provided)
@@ -408,6 +403,7 @@ async def test_named_savepoints_exception_exit(aconn, commands):
             commands.pop()
             == 'rollback to savepoint "bar"; release savepoint "bar"'
         )
+        assert not commands
     assert commands.pop() == "commit"
 
     # Case 3 (with savepoint name auto-generated)
@@ -422,6 +418,7 @@ async def test_named_savepoints_exception_exit(aconn, commands):
             commands.pop()
             == 'rollback to savepoint "s1"; release savepoint "s1"'
         )
+        assert not commands
     assert commands.pop() == "commit"
 
     assert not commands