from . import errors as e
from . import encodings
from .pq import TransactionStatus, ExecStatus
-from .proto import DumpersMap, LoadersMap, PQGen, RV
+from .sql import Composable
+from .proto import DumpersMap, LoadersMap, PQGen, RV, Query
from .waiting import wait, wait_async
from .conninfo import make_conninfo
from .generators import notifies
self._notify_handlers: List[NotifyHandler] = []
# stack of savepoint names managed by active Transaction() blocks
- self._savepoints: Optional[List[bytes]] = None
+ self._savepoints: Optional[List[str]] = None
# (None when there no active Transaction blocks; [] when there is only
# one Transaction block, with a top-level transaction and no savepoint)
return
self._exec_command(b"rollback")
- def _exec_command(self, command: bytes) -> None:
+ def _exec_command(self, command: Query) -> None:
# Caller must hold self.lock
+
+ if isinstance(command, str):
+ command = command.encode(self.client_encoding)
+ elif isinstance(command, Composable):
+ command = command.as_string(self).encode(self.client_encoding)
+
logger.debug(f"{self}: {command!r}")
self.pgconn.send_query(command)
results = self.wait(execute(self.pgconn))
import logging
-from psycopg3.errors import ProgrammingError
from types import TracebackType
from typing import Optional, Type, TYPE_CHECKING
+from . import sql
from .pq import TransactionStatus
+from psycopg3.errors import ProgrammingError
if TYPE_CHECKING:
from .connection import Connection
force_rollback: bool,
):
self._conn = connection
- self._savepoint_name: Optional[bytes] = None
+ self._savepoint_name: Optional[str] = None
if savepoint_name is not None:
- if len(savepoint_name) == 0:
+ if not savepoint_name:
raise ValueError("savepoint_name must be a non-empty string")
- self._savepoint_name = savepoint_name.encode(
- connection.client_encoding
- )
+ self._savepoint_name = savepoint_name
self.force_rollback = force_rollback
self._outer_transaction: Optional[bool] = None
@property
def savepoint_name(self) -> Optional[str]:
- if self._savepoint_name is None:
- return None
- return self._savepoint_name.decode(self._conn.client_encoding)
+ return self._savepoint_name
def __enter__(self) -> "Transaction":
with self._conn.lock:
self._conn._savepoints = []
self._outer_transaction = False
if self._savepoint_name is None:
- self._savepoint_name = b"s%i" % (
- len(self._conn._savepoints) + 1
+ self._savepoint_name = (
+ f"s{len(self._conn._savepoints) + 1}"
)
if self._savepoint_name is not None:
- self._conn._exec_command(b"savepoint " + self._savepoint_name)
+ self._conn._exec_command(
+ sql.SQL("savepoint {}").format(
+ sql.Identifier(self._savepoint_name)
+ )
+ )
self._conn._savepoints.append(self._savepoint_name)
return self
if actual != self._savepoint_name:
raise out_of_order_err
self._conn._exec_command(
- b"release savepoint " + self._savepoint_name
+ sql.SQL("release savepoint {}").format(
+ sql.Identifier(self._savepoint_name)
+ )
)
if self._outer_transaction:
if self._conn._savepoints is None:
if actual != self._savepoint_name:
raise out_of_order_err
self._conn._exec_command(
- b"rollback to savepoint " + self._savepoint_name + b";"
- b"release savepoint " + self._savepoint_name
+ sql.SQL(
+ "rollback to savepoint {n}; release savepoint {n}"
+ ).format(n=sql.Identifier(self._savepoint_name))
)
if self._outer_transaction:
if self._conn._savepoints is None:
import pytest
-from psycopg3 import OperationalError, ProgrammingError, Rollback
+from psycopg3 import ProgrammingError, Rollback
+from psycopg3.sql import Composable
@pytest.fixture(autouse=True)
real_exec_command = conn._exec_command
def _exec_command(command):
+ if isinstance(command, bytes):
+ command = command.decode(conn.client_encoding)
+ elif isinstance(command, Composable):
+ command = command.as_string(conn)
+
commands_actual.append(command)
real_exec_command(command)
yield
finally:
conn._exec_command = real_exec_command
- commands_expected = [cmd.encode("ascii") for cmd in commands]
- assert commands_actual == commands_expected
+
+ assert commands_actual == list(commands)
class ExpectedException(Exception):
conn.transaction(savepoint_name="")
-@pytest.mark.xfail(raises=OperationalError, reason="TODO: Escape sp names")
def test_named_savepoint_escapes_savepoint_name(conn):
with conn.transaction("s-1"):
pass
# Case 2
tx = conn.transaction(savepoint_name="foo")
- with assert_commands_issued(conn, "begin", "savepoint foo"):
+ with assert_commands_issued(conn, "begin", 'savepoint "foo"'):
tx.__enter__()
assert tx.savepoint_name == "foo"
- with assert_commands_issued(conn, "release savepoint foo", "commit"):
+ with assert_commands_issued(conn, 'release savepoint "foo"', "commit"):
tx.__exit__(None, None, None)
# Case 3 (with savepoint name provided)
with conn.transaction():
tx = conn.transaction(savepoint_name="bar")
- with assert_commands_issued(conn, "savepoint bar"):
+ with assert_commands_issued(conn, 'savepoint "bar"'):
tx.__enter__()
assert tx.savepoint_name == "bar"
- with assert_commands_issued(conn, "release savepoint bar"):
+ with assert_commands_issued(conn, 'release savepoint "bar"'):
tx.__exit__(None, None, None)
# Case 3 (with savepoint name auto-generated)
with conn.transaction():
tx = conn.transaction()
- with assert_commands_issued(conn, "savepoint s1"):
+ with assert_commands_issued(conn, 'savepoint "s1"'):
tx.__enter__()
assert tx.savepoint_name == "s1"
- with assert_commands_issued(conn, "release savepoint s1"):
+ with assert_commands_issued(conn, 'release savepoint "s1"'):
tx.__exit__(None, None, None)
# Case 2
tx = conn.transaction(savepoint_name="foo")
- with assert_commands_issued(conn, "begin", "savepoint foo"):
+ with assert_commands_issued(conn, "begin", 'savepoint "foo"'):
tx.__enter__()
assert tx.savepoint_name == "foo"
with assert_commands_issued(
- conn, "rollback to savepoint foo;release savepoint foo", "rollback"
+ conn,
+ 'rollback to savepoint "foo"; release savepoint "foo"',
+ "rollback",
):
tx.__exit__(*some_exc_info())
# Case 3 (with savepoint name provided)
with conn.transaction():
tx = conn.transaction(savepoint_name="bar")
- with assert_commands_issued(conn, "savepoint bar"):
+ with assert_commands_issued(conn, 'savepoint "bar"'):
tx.__enter__()
assert tx.savepoint_name == "bar"
with assert_commands_issued(
- conn, "rollback to savepoint bar;release savepoint bar"
+ conn, 'rollback to savepoint "bar"; release savepoint "bar"'
):
tx.__exit__(*some_exc_info())
# Case 3 (with savepoint name auto-generated)
with conn.transaction():
tx = conn.transaction()
- with assert_commands_issued(conn, "savepoint s1"):
+ with assert_commands_issued(conn, 'savepoint "s1"'):
tx.__enter__()
assert tx.savepoint_name == "s1"
with assert_commands_issued(
- conn, "rollback to savepoint s1;release savepoint s1"
+ conn, 'rollback to savepoint "s1"; release savepoint "s1"'
):
tx.__exit__(*some_exc_info())