]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Anonymous savepoints represented by empty string
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 15 Nov 2020 20:40:34 +0000 (20:40 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 16 Nov 2020 04:02:06 +0000 (04:02 +0000)
It makes signatures and types easier. Empty string is not valid anyway,
and bonkers types are not checked.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/transaction.py
tests/test_transaction.py

index 7badb6e78848e8e63aed60254e8543f6268dd0c9..45d786d639c7fdf8db53b86e697674417cf3bd67 100644 (file)
@@ -335,8 +335,7 @@ class Connection(BaseConnection):
         savepoint_name: Optional[str] = None,
         force_rollback: bool = False,
     ) -> Iterator[Transaction]:
-        tx = Transaction(self, savepoint_name, force_rollback)
-        with tx:
+        with Transaction(self, savepoint_name or "", force_rollback) as tx:
             yield tx
 
     @classmethod
index 0c37211077fc3b4c1df52e4ffc9814e430260002..0f0bbfaa7472db1968a5aad713781e5c1f87b7fa 100644 (file)
@@ -39,15 +39,11 @@ class Transaction:
     def __init__(
         self,
         connection: "Connection",
-        savepoint_name: Optional[str] = None,
+        savepoint_name: str = "",
         force_rollback: bool = False,
     ):
         self._conn = connection
-        self._savepoint_name: Optional[str] = None
-        if savepoint_name is not None:
-            if not savepoint_name:
-                raise ValueError("savepoint_name must be a non-empty string")
-            self._savepoint_name = savepoint_name
+        self._savepoint_name = savepoint_name
         self.force_rollback = force_rollback
 
         self._outer_transaction: Optional[bool] = None
@@ -57,7 +53,7 @@ class Transaction:
         return self._conn
 
     @property
-    def savepoint_name(self) -> Optional[str]:
+    def savepoint_name(self) -> str:
         return self._savepoint_name
 
     def __enter__(self) -> "Transaction":
@@ -71,12 +67,12 @@ class Transaction:
                 if self._conn._savepoints is None:
                     self._conn._savepoints = []
                 self._outer_transaction = False
-                if self._savepoint_name is None:
+                if not self._savepoint_name:
                     self._savepoint_name = (
                         f"s{len(self._conn._savepoints) + 1}"
                     )
 
-            if self._savepoint_name is not None:
+            if self._savepoint_name:
                 self._conn._exec_command(
                     sql.SQL("savepoint {}").format(
                         sql.Identifier(self._savepoint_name)
@@ -152,7 +148,7 @@ class Transaction:
 
     def __repr__(self) -> str:
         args = [f"connection={self.connection}"]
-        if self.savepoint_name is not None:
+        if not self.savepoint_name:
             args.append(f"savepoint_name={self.savepoint_name!r}")
         if self.force_rollback:
             args.append("force_rollback=True")
index 84a092db3547f09983216e34c754cf89db5b0f5a..026a7fa6e7ebdc0f8419bb08ac3795bb0382d817 100644 (file)
@@ -327,16 +327,6 @@ def test_nested_three_levels_successful_exit(conn, svcconn):
     assert_rows(svcconn, {"one", "two", "three"})
 
 
-def test_named_savepoint_empty_string_invalid(conn):
-    """
-    Raise validate savepoint_name up-front (rather than later constructing an
-    invalid SQL command and having that fail with an OperationalError).
-    """
-    with pytest.raises(ValueError):
-        with conn.transaction(savepoint_name=""):
-            pass
-
-
 def test_named_savepoint_escapes_savepoint_name(conn):
     with conn.transaction("s-1"):
         pass
@@ -359,7 +349,7 @@ def test_named_savepoints_successful_exit(conn):
     tx = Transaction(conn)
     with assert_commands_issued(conn, "begin"):
         tx.__enter__()
-    assert tx.savepoint_name is None
+    assert not tx.savepoint_name
     with assert_commands_issued(conn, "commit"):
         tx.__exit__(None, None, None)
 
@@ -400,7 +390,7 @@ def test_named_savepoints_exception_exit(conn):
     tx = Transaction(conn)
     with assert_commands_issued(conn, "begin"):
         tx.__enter__()
-    assert tx.savepoint_name is None
+    assert not tx.savepoint_name
     with assert_commands_issued(conn, "rollback"):
         tx.__exit__(*some_exc_info())