From: Daniele Varrazzo Date: Mon, 16 Nov 2020 13:56:44 +0000 (+0000) Subject: Commands to enter/exit a transaction with named savepoint tweaked X-Git-Tag: 3.0.dev0~351^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2054a812b7f9e69c82af8116b685f94cc130bd49;p=thirdparty%2Fpsycopg.git Commands to enter/exit a transaction with named savepoint tweaked - Batch together begin + savepoint - Do without releasing the savepoint immediately before commit/rollback --- diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index a24d7874f..b02d412bc 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -11,7 +11,7 @@ from typing import Generic, List, Optional, Type, Union, TYPE_CHECKING from . import sql from .pq import TransactionStatus -from .proto import ConnectionType, Query +from .proto import ConnectionType from .errors import ProgrammingError if TYPE_CHECKING: @@ -74,13 +74,13 @@ class BaseTransaction(Generic[ConnectionType]): "calling __exit__() manually and getting it wrong?" ) - def _enter_commands(self) -> List[Query]: - commands: List[Query] = [] + def _enter_commands(self) -> List[str]: + commands = [] if self._outer_transaction: assert self._conn._savepoints is None, self._conn._savepoints self._conn._savepoints = [] - commands.append(b"begin") + commands.append("begin") else: if self._conn._savepoints is None: self._conn._savepoints = [] @@ -89,41 +89,41 @@ class BaseTransaction(Generic[ConnectionType]): if self._savepoint_name: commands.append( - sql.SQL("savepoint {}").format( - sql.Identifier(self._savepoint_name) - ) + sql.SQL("savepoint {}") + .format(sql.Identifier(self._savepoint_name)) + .as_string(self._conn) ) self._conn._savepoints.append(self._savepoint_name) return commands - def _commit_commands(self) -> List[Query]: - commands: List[Query] = [] + def _commit_commands(self) -> List[str]: + commands = [] self._pop_savepoint() - if self._savepoint_name: + if self._savepoint_name and not self._outer_transaction: commands.append( - sql.SQL("release savepoint {}").format( - sql.Identifier(self._savepoint_name) - ) + sql.SQL("release savepoint {}") + .format(sql.Identifier(self._savepoint_name)) + .as_string(self._conn) ) if self._outer_transaction: - commands.append(b"commit") + commands.append("commit") return commands - def _rollback_commands(self) -> List[Query]: - commands: List[Query] = [] + def _rollback_commands(self) -> List[str]: + commands = [] self._pop_savepoint() - if self._savepoint_name: + if self._savepoint_name and not self._outer_transaction: commands.append( - sql.SQL( - "rollback to savepoint {n}; release savepoint {n}" - ).format(n=sql.Identifier(self._savepoint_name)) + sql.SQL("rollback to savepoint {n}; release savepoint {n}") + .format(n=sql.Identifier(self._savepoint_name)) + .as_string(self._conn) ) if self._outer_transaction: - commands.append(b"rollback") + commands.append("rollback") return commands @@ -178,9 +178,8 @@ class Transaction(BaseTransaction["Connection"]): return False - def _execute(self, commands: List[Query]) -> None: - for command in commands: - self._conn._exec_command(command) + def _execute(self, commands: List[str]) -> None: + self._conn._exec_command("; ".join(commands)) class AsyncTransaction(BaseTransaction["AsyncConnection"]): @@ -222,6 +221,5 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]): return False - async def _execute(self, commands: List[Query]) -> None: - for command in commands: - await self._conn._exec_command(command) + async def _execute(self, commands: List[str]) -> None: + await self._conn._exec_command("; ".join(commands)) diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 28385ded2..dfa66ab25 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -368,11 +368,9 @@ def test_named_savepoints_successful_exit(conn, commands): # Case 2 tx = Transaction(conn, savepoint_name="foo") tx.__enter__() - assert commands.pop() == "begin" - assert commands.pop() == 'savepoint "foo"' + assert commands.pop() == 'begin; savepoint "foo"' assert tx.savepoint_name == "foo" tx.__exit__(None, None, None) - assert commands.pop() == 'release savepoint "foo"' assert commands.pop() == "commit" # Case 3 (with savepoint name provided) @@ -384,6 +382,7 @@ def test_named_savepoints_successful_exit(conn, commands): assert tx.savepoint_name == "bar" tx.__exit__(None, None, None) assert commands.pop() == 'release savepoint "bar"' + assert not commands assert commands.pop() == "commit" # Case 3 (with savepoint name auto-generated) @@ -395,6 +394,7 @@ def test_named_savepoints_successful_exit(conn, commands): assert tx.savepoint_name == "s1" tx.__exit__(None, None, None) assert commands.pop() == 'release savepoint "s1"' + assert not commands assert commands.pop() == "commit" assert not commands @@ -417,14 +417,9 @@ def test_named_savepoints_exception_exit(conn, commands): # Case 2 tx = Transaction(conn, savepoint_name="foo") tx.__enter__() - assert commands.pop() == "begin" - assert commands.pop() == 'savepoint "foo"' + assert commands.pop() == 'begin; savepoint "foo"' assert tx.savepoint_name == "foo" tx.__exit__(*some_exc_info()) - assert ( - commands.pop() - == 'rollback to savepoint "foo"; release savepoint "foo"' - ) assert commands.pop() == "rollback" # Case 3 (with savepoint name provided) @@ -439,6 +434,7 @@ def test_named_savepoints_exception_exit(conn, commands): commands.pop() == 'rollback to savepoint "bar"; release savepoint "bar"' ) + assert not commands assert commands.pop() == "commit" # Case 3 (with savepoint name auto-generated) @@ -453,6 +449,7 @@ def test_named_savepoints_exception_exit(conn, commands): commands.pop() == 'rollback to savepoint "s1"; release savepoint "s1"' ) + assert not commands assert commands.pop() == "commit" assert not commands diff --git a/tests/test_transaction_async.py b/tests/test_transaction_async.py index 0958a812b..c434fa6a0 100644 --- a/tests/test_transaction_async.py +++ b/tests/test_transaction_async.py @@ -337,11 +337,9 @@ async def test_named_savepoints_successful_exit(aconn, commands): # Case 2 tx = AsyncTransaction(aconn, savepoint_name="foo") await tx.__aenter__() - assert commands.pop() == "begin" - assert commands.pop() == 'savepoint "foo"' + assert commands.pop() == 'begin; savepoint "foo"' assert tx.savepoint_name == "foo" await tx.__aexit__(None, None, None) - assert commands.pop() == 'release savepoint "foo"' assert commands.pop() == "commit" # Case 3 (with savepoint name provided) @@ -353,6 +351,7 @@ async def test_named_savepoints_successful_exit(aconn, commands): assert tx.savepoint_name == "bar" await tx.__aexit__(None, None, None) assert commands.pop() == 'release savepoint "bar"' + assert not commands assert commands.pop() == "commit" # Case 3 (with savepoint name auto-generated) @@ -364,6 +363,7 @@ async def test_named_savepoints_successful_exit(aconn, commands): assert tx.savepoint_name == "s1" await tx.__aexit__(None, None, None) assert commands.pop() == 'release savepoint "s1"' + assert not commands assert commands.pop() == "commit" assert not commands @@ -386,14 +386,9 @@ async def test_named_savepoints_exception_exit(aconn, commands): # Case 2 tx = AsyncTransaction(aconn, savepoint_name="foo") await tx.__aenter__() - assert commands.pop() == "begin" - assert commands.pop() == 'savepoint "foo"' + assert commands.pop() == 'begin; savepoint "foo"' assert tx.savepoint_name == "foo" await tx.__aexit__(*some_exc_info()) - assert ( - commands.pop() - == 'rollback to savepoint "foo"; release savepoint "foo"' - ) assert commands.pop() == "rollback" # Case 3 (with savepoint name provided) @@ -408,6 +403,7 @@ async def test_named_savepoints_exception_exit(aconn, commands): commands.pop() == 'rollback to savepoint "bar"; release savepoint "bar"' ) + assert not commands assert commands.pop() == "commit" # Case 3 (with savepoint name auto-generated) @@ -422,6 +418,7 @@ async def test_named_savepoints_exception_exit(aconn, commands): commands.pop() == 'rollback to savepoint "s1"; release savepoint "s1"' ) + assert not commands assert commands.pop() == "commit" assert not commands