From: Daniele Varrazzo Date: Sat, 14 Nov 2020 23:03:39 +0000 (+0000) Subject: Escape savepoint names X-Git-Tag: 3.0.dev0~351^2~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1384ff1683d66b9c207c68a224e2498618b862f3;p=thirdparty%2Fpsycopg.git Escape savepoint names Added more internal support to the connection to generate internal commands dynamically, thanks to the `sql` module now implemented. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index f186d46cb..ddebe879f 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -18,7 +18,8 @@ from . import cursor from . import errors as e from . import encodings from .pq import TransactionStatus, ExecStatus -from .proto import DumpersMap, LoadersMap, PQGen, RV +from .sql import Composable +from .proto import DumpersMap, LoadersMap, PQGen, RV, Query from .waiting import wait, wait_async from .conninfo import make_conninfo from .generators import notifies @@ -99,7 +100,7 @@ class BaseConnection: self._notify_handlers: List[NotifyHandler] = [] # stack of savepoint names managed by active Transaction() blocks - self._savepoints: Optional[List[bytes]] = None + self._savepoints: Optional[List[str]] = None # (None when there no active Transaction blocks; [] when there is only # one Transaction block, with a top-level transaction and no savepoint) @@ -310,8 +311,14 @@ class Connection(BaseConnection): return self._exec_command(b"rollback") - def _exec_command(self, command: bytes) -> None: + def _exec_command(self, command: Query) -> None: # Caller must hold self.lock + + if isinstance(command, str): + command = command.encode(self.client_encoding) + elif isinstance(command, Composable): + command = command.as_string(self).encode(self.client_encoding) + logger.debug(f"{self}: {command!r}") self.pgconn.send_query(command) results = self.wait(execute(self.pgconn)) diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index 4af553340..783ebce57 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -6,11 +6,12 @@ Transaction context managers returned by Connection.transaction() import logging -from psycopg3.errors import ProgrammingError from types import TracebackType from typing import Optional, Type, TYPE_CHECKING +from . import sql from .pq import TransactionStatus +from psycopg3.errors import ProgrammingError if TYPE_CHECKING: from .connection import Connection @@ -42,13 +43,11 @@ class Transaction: force_rollback: bool, ): self._conn = connection - self._savepoint_name: Optional[bytes] = None + self._savepoint_name: Optional[str] = None if savepoint_name is not None: - if len(savepoint_name) == 0: + if not savepoint_name: raise ValueError("savepoint_name must be a non-empty string") - self._savepoint_name = savepoint_name.encode( - connection.client_encoding - ) + self._savepoint_name = savepoint_name self.force_rollback = force_rollback self._outer_transaction: Optional[bool] = None @@ -59,9 +58,7 @@ class Transaction: @property def savepoint_name(self) -> Optional[str]: - if self._savepoint_name is None: - return None - return self._savepoint_name.decode(self._conn.client_encoding) + return self._savepoint_name def __enter__(self) -> "Transaction": with self._conn.lock: @@ -75,12 +72,16 @@ class Transaction: self._conn._savepoints = [] self._outer_transaction = False if self._savepoint_name is None: - self._savepoint_name = b"s%i" % ( - len(self._conn._savepoints) + 1 + self._savepoint_name = ( + f"s{len(self._conn._savepoints) + 1}" ) if self._savepoint_name is not None: - self._conn._exec_command(b"savepoint " + self._savepoint_name) + self._conn._exec_command( + sql.SQL("savepoint {}").format( + sql.Identifier(self._savepoint_name) + ) + ) self._conn._savepoints.append(self._savepoint_name) return self @@ -106,7 +107,9 @@ class Transaction: if actual != self._savepoint_name: raise out_of_order_err self._conn._exec_command( - b"release savepoint " + self._savepoint_name + sql.SQL("release savepoint {}").format( + sql.Identifier(self._savepoint_name) + ) ) if self._outer_transaction: if self._conn._savepoints is None: @@ -130,8 +133,9 @@ class Transaction: if actual != self._savepoint_name: raise out_of_order_err self._conn._exec_command( - b"rollback to savepoint " + self._savepoint_name + b";" - b"release savepoint " + self._savepoint_name + sql.SQL( + "rollback to savepoint {n}; release savepoint {n}" + ).format(n=sql.Identifier(self._savepoint_name)) ) if self._outer_transaction: if self._conn._savepoints is None: diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 2afa831e6..955543210 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -3,7 +3,8 @@ from contextlib import contextmanager import pytest -from psycopg3 import OperationalError, ProgrammingError, Rollback +from psycopg3 import ProgrammingError, Rollback +from psycopg3.sql import Composable @pytest.fixture(autouse=True) @@ -41,6 +42,11 @@ def assert_commands_issued(conn, *commands): real_exec_command = conn._exec_command def _exec_command(command): + if isinstance(command, bytes): + command = command.decode(conn.client_encoding) + elif isinstance(command, Composable): + command = command.as_string(conn) + commands_actual.append(command) real_exec_command(command) @@ -49,8 +55,8 @@ def assert_commands_issued(conn, *commands): yield finally: conn._exec_command = real_exec_command - commands_expected = [cmd.encode("ascii") for cmd in commands] - assert commands_actual == commands_expected + + assert commands_actual == list(commands) class ExpectedException(Exception): @@ -329,7 +335,6 @@ def test_named_savepoint_empty_string_invalid(conn): conn.transaction(savepoint_name="") -@pytest.mark.xfail(raises=OperationalError, reason="TODO: Escape sp names") def test_named_savepoint_escapes_savepoint_name(conn): with conn.transaction("s-1"): pass @@ -357,28 +362,28 @@ def test_named_savepoints_successful_exit(conn): # Case 2 tx = conn.transaction(savepoint_name="foo") - with assert_commands_issued(conn, "begin", "savepoint foo"): + with assert_commands_issued(conn, "begin", 'savepoint "foo"'): tx.__enter__() assert tx.savepoint_name == "foo" - with assert_commands_issued(conn, "release savepoint foo", "commit"): + with assert_commands_issued(conn, 'release savepoint "foo"', "commit"): tx.__exit__(None, None, None) # Case 3 (with savepoint name provided) with conn.transaction(): tx = conn.transaction(savepoint_name="bar") - with assert_commands_issued(conn, "savepoint bar"): + with assert_commands_issued(conn, 'savepoint "bar"'): tx.__enter__() assert tx.savepoint_name == "bar" - with assert_commands_issued(conn, "release savepoint bar"): + with assert_commands_issued(conn, 'release savepoint "bar"'): tx.__exit__(None, None, None) # Case 3 (with savepoint name auto-generated) with conn.transaction(): tx = conn.transaction() - with assert_commands_issued(conn, "savepoint s1"): + with assert_commands_issued(conn, 'savepoint "s1"'): tx.__enter__() assert tx.savepoint_name == "s1" - with assert_commands_issued(conn, "release savepoint s1"): + with assert_commands_issued(conn, 'release savepoint "s1"'): tx.__exit__(None, None, None) @@ -398,33 +403,35 @@ def test_named_savepoints_exception_exit(conn): # Case 2 tx = conn.transaction(savepoint_name="foo") - with assert_commands_issued(conn, "begin", "savepoint foo"): + with assert_commands_issued(conn, "begin", 'savepoint "foo"'): tx.__enter__() assert tx.savepoint_name == "foo" with assert_commands_issued( - conn, "rollback to savepoint foo;release savepoint foo", "rollback" + conn, + 'rollback to savepoint "foo"; release savepoint "foo"', + "rollback", ): tx.__exit__(*some_exc_info()) # Case 3 (with savepoint name provided) with conn.transaction(): tx = conn.transaction(savepoint_name="bar") - with assert_commands_issued(conn, "savepoint bar"): + with assert_commands_issued(conn, 'savepoint "bar"'): tx.__enter__() assert tx.savepoint_name == "bar" with assert_commands_issued( - conn, "rollback to savepoint bar;release savepoint bar" + conn, 'rollback to savepoint "bar"; release savepoint "bar"' ): tx.__exit__(*some_exc_info()) # Case 3 (with savepoint name auto-generated) with conn.transaction(): tx = conn.transaction() - with assert_commands_issued(conn, "savepoint s1"): + with assert_commands_issued(conn, 'savepoint "s1"'): tx.__enter__() assert tx.savepoint_name == "s1" with assert_commands_issued( - conn, "rollback to savepoint s1;release savepoint s1" + conn, 'rollback to savepoint "s1"; release savepoint "s1"' ): tx.__exit__(*some_exc_info())