]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Escape savepoint names
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 14 Nov 2020 23:03:39 +0000 (23:03 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 14 Nov 2020 23:03:39 +0000 (23:03 +0000)
Added more internal support to the connection to generate internal
commands dynamically, thanks to the `sql` module now implemented.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/transaction.py
tests/test_transaction.py

index f186d46cb0eaa0d8f9d4d2d1ae4b5971b65b9b1e..ddebe879f4f43f564ea8ce0001ae1df9cfef98af 100644 (file)
@@ -18,7 +18,8 @@ from . import cursor
 from . import errors as e
 from . import encodings
 from .pq import TransactionStatus, ExecStatus
-from .proto import DumpersMap, LoadersMap, PQGen, RV
+from .sql import Composable
+from .proto import DumpersMap, LoadersMap, PQGen, RV, Query
 from .waiting import wait, wait_async
 from .conninfo import make_conninfo
 from .generators import notifies
@@ -99,7 +100,7 @@ class BaseConnection:
         self._notify_handlers: List[NotifyHandler] = []
 
         # stack of savepoint names managed by active Transaction() blocks
-        self._savepoints: Optional[List[bytes]] = None
+        self._savepoints: Optional[List[str]] = None
         # (None when there no active Transaction blocks; [] when there is only
         # one Transaction block, with a top-level transaction and no savepoint)
 
@@ -310,8 +311,14 @@ class Connection(BaseConnection):
                 return
             self._exec_command(b"rollback")
 
-    def _exec_command(self, command: bytes) -> None:
+    def _exec_command(self, command: Query) -> None:
         # Caller must hold self.lock
+
+        if isinstance(command, str):
+            command = command.encode(self.client_encoding)
+        elif isinstance(command, Composable):
+            command = command.as_string(self).encode(self.client_encoding)
+
         logger.debug(f"{self}: {command!r}")
         self.pgconn.send_query(command)
         results = self.wait(execute(self.pgconn))
index 4af553340ff5e198a8f4ed9441c1ce39a18b9a34..783ebce57ee9b1edd0d90af4e70055f90907344c 100644 (file)
@@ -6,11 +6,12 @@ Transaction context managers returned by Connection.transaction()
 
 import logging
 
-from psycopg3.errors import ProgrammingError
 from types import TracebackType
 from typing import Optional, Type, TYPE_CHECKING
 
+from . import sql
 from .pq import TransactionStatus
+from psycopg3.errors import ProgrammingError
 
 if TYPE_CHECKING:
     from .connection import Connection
@@ -42,13 +43,11 @@ class Transaction:
         force_rollback: bool,
     ):
         self._conn = connection
-        self._savepoint_name: Optional[bytes] = None
+        self._savepoint_name: Optional[str] = None
         if savepoint_name is not None:
-            if len(savepoint_name) == 0:
+            if not savepoint_name:
                 raise ValueError("savepoint_name must be a non-empty string")
-            self._savepoint_name = savepoint_name.encode(
-                connection.client_encoding
-            )
+            self._savepoint_name = savepoint_name
         self.force_rollback = force_rollback
 
         self._outer_transaction: Optional[bool] = None
@@ -59,9 +58,7 @@ class Transaction:
 
     @property
     def savepoint_name(self) -> Optional[str]:
-        if self._savepoint_name is None:
-            return None
-        return self._savepoint_name.decode(self._conn.client_encoding)
+        return self._savepoint_name
 
     def __enter__(self) -> "Transaction":
         with self._conn.lock:
@@ -75,12 +72,16 @@ class Transaction:
                     self._conn._savepoints = []
                 self._outer_transaction = False
                 if self._savepoint_name is None:
-                    self._savepoint_name = b"s%i" % (
-                        len(self._conn._savepoints) + 1
+                    self._savepoint_name = (
+                        f"s{len(self._conn._savepoints) + 1}"
                     )
 
             if self._savepoint_name is not None:
-                self._conn._exec_command(b"savepoint " + self._savepoint_name)
+                self._conn._exec_command(
+                    sql.SQL("savepoint {}").format(
+                        sql.Identifier(self._savepoint_name)
+                    )
+                )
                 self._conn._savepoints.append(self._savepoint_name)
         return self
 
@@ -106,7 +107,9 @@ class Transaction:
                     if actual != self._savepoint_name:
                         raise out_of_order_err
                     self._conn._exec_command(
-                        b"release savepoint " + self._savepoint_name
+                        sql.SQL("release savepoint {}").format(
+                            sql.Identifier(self._savepoint_name)
+                        )
                     )
                 if self._outer_transaction:
                     if self._conn._savepoints is None:
@@ -130,8 +133,9 @@ class Transaction:
                     if actual != self._savepoint_name:
                         raise out_of_order_err
                     self._conn._exec_command(
-                        b"rollback to savepoint " + self._savepoint_name + b";"
-                        b"release savepoint " + self._savepoint_name
+                        sql.SQL(
+                            "rollback to savepoint {n}; release savepoint {n}"
+                        ).format(n=sql.Identifier(self._savepoint_name))
                     )
                 if self._outer_transaction:
                     if self._conn._savepoints is None:
index 2afa831e6184745b4c2f212ff20f56aa4c823342..95554321057cac47d1f374ff6e1315a9cc9d7582 100644 (file)
@@ -3,7 +3,8 @@ from contextlib import contextmanager
 
 import pytest
 
-from psycopg3 import OperationalError, ProgrammingError, Rollback
+from psycopg3 import ProgrammingError, Rollback
+from psycopg3.sql import Composable
 
 
 @pytest.fixture(autouse=True)
@@ -41,6 +42,11 @@ def assert_commands_issued(conn, *commands):
     real_exec_command = conn._exec_command
 
     def _exec_command(command):
+        if isinstance(command, bytes):
+            command = command.decode(conn.client_encoding)
+        elif isinstance(command, Composable):
+            command = command.as_string(conn)
+
         commands_actual.append(command)
         real_exec_command(command)
 
@@ -49,8 +55,8 @@ def assert_commands_issued(conn, *commands):
         yield
     finally:
         conn._exec_command = real_exec_command
-    commands_expected = [cmd.encode("ascii") for cmd in commands]
-    assert commands_actual == commands_expected
+
+    assert commands_actual == list(commands)
 
 
 class ExpectedException(Exception):
@@ -329,7 +335,6 @@ def test_named_savepoint_empty_string_invalid(conn):
         conn.transaction(savepoint_name="")
 
 
-@pytest.mark.xfail(raises=OperationalError, reason="TODO: Escape sp names")
 def test_named_savepoint_escapes_savepoint_name(conn):
     with conn.transaction("s-1"):
         pass
@@ -357,28 +362,28 @@ def test_named_savepoints_successful_exit(conn):
 
     # Case 2
     tx = conn.transaction(savepoint_name="foo")
-    with assert_commands_issued(conn, "begin", "savepoint foo"):
+    with assert_commands_issued(conn, "begin", 'savepoint "foo"'):
         tx.__enter__()
     assert tx.savepoint_name == "foo"
-    with assert_commands_issued(conn, "release savepoint foo", "commit"):
+    with assert_commands_issued(conn, 'release savepoint "foo"', "commit"):
         tx.__exit__(None, None, None)
 
     # Case 3 (with savepoint name provided)
     with conn.transaction():
         tx = conn.transaction(savepoint_name="bar")
-        with assert_commands_issued(conn, "savepoint bar"):
+        with assert_commands_issued(conn, 'savepoint "bar"'):
             tx.__enter__()
         assert tx.savepoint_name == "bar"
-        with assert_commands_issued(conn, "release savepoint bar"):
+        with assert_commands_issued(conn, 'release savepoint "bar"'):
             tx.__exit__(None, None, None)
 
     # Case 3 (with savepoint name auto-generated)
     with conn.transaction():
         tx = conn.transaction()
-        with assert_commands_issued(conn, "savepoint s1"):
+        with assert_commands_issued(conn, 'savepoint "s1"'):
             tx.__enter__()
         assert tx.savepoint_name == "s1"
-        with assert_commands_issued(conn, "release savepoint s1"):
+        with assert_commands_issued(conn, 'release savepoint "s1"'):
             tx.__exit__(None, None, None)
 
 
@@ -398,33 +403,35 @@ def test_named_savepoints_exception_exit(conn):
 
     # Case 2
     tx = conn.transaction(savepoint_name="foo")
-    with assert_commands_issued(conn, "begin", "savepoint foo"):
+    with assert_commands_issued(conn, "begin", 'savepoint "foo"'):
         tx.__enter__()
     assert tx.savepoint_name == "foo"
     with assert_commands_issued(
-        conn, "rollback to savepoint foo;release savepoint foo", "rollback"
+        conn,
+        'rollback to savepoint "foo"; release savepoint "foo"',
+        "rollback",
     ):
         tx.__exit__(*some_exc_info())
 
     # Case 3 (with savepoint name provided)
     with conn.transaction():
         tx = conn.transaction(savepoint_name="bar")
-        with assert_commands_issued(conn, "savepoint bar"):
+        with assert_commands_issued(conn, 'savepoint "bar"'):
             tx.__enter__()
         assert tx.savepoint_name == "bar"
         with assert_commands_issued(
-            conn, "rollback to savepoint bar;release savepoint bar"
+            conn, 'rollback to savepoint "bar"; release savepoint "bar"'
         ):
             tx.__exit__(*some_exc_info())
 
     # Case 3 (with savepoint name auto-generated)
     with conn.transaction():
         tx = conn.transaction()
-        with assert_commands_issued(conn, "savepoint s1"):
+        with assert_commands_issued(conn, 'savepoint "s1"'):
             tx.__enter__()
         assert tx.savepoint_name == "s1"
         with assert_commands_issued(
-            conn, "rollback to savepoint s1;release savepoint s1"
+            conn, 'rollback to savepoint "s1"; release savepoint "s1"'
         ):
             tx.__exit__(*some_exc_info())