From: Daniele Varrazzo Date: Sun, 15 Nov 2020 20:40:34 +0000 (+0000) Subject: Anonymous savepoints represented by empty string X-Git-Tag: 3.0.dev0~351^2~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=da3f13c15545bb75bdd8c6a6051fc333a510ad22;p=thirdparty%2Fpsycopg.git Anonymous savepoints represented by empty string It makes signatures and types easier. Empty string is not valid anyway, and bonkers types are not checked. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 7badb6e78..45d786d63 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -335,8 +335,7 @@ class Connection(BaseConnection): savepoint_name: Optional[str] = None, force_rollback: bool = False, ) -> Iterator[Transaction]: - tx = Transaction(self, savepoint_name, force_rollback) - with tx: + with Transaction(self, savepoint_name or "", force_rollback) as tx: yield tx @classmethod diff --git a/psycopg3/psycopg3/transaction.py b/psycopg3/psycopg3/transaction.py index 0c3721107..0f0bbfaa7 100644 --- a/psycopg3/psycopg3/transaction.py +++ b/psycopg3/psycopg3/transaction.py @@ -39,15 +39,11 @@ class Transaction: def __init__( self, connection: "Connection", - savepoint_name: Optional[str] = None, + savepoint_name: str = "", force_rollback: bool = False, ): self._conn = connection - self._savepoint_name: Optional[str] = None - if savepoint_name is not None: - if not savepoint_name: - raise ValueError("savepoint_name must be a non-empty string") - self._savepoint_name = savepoint_name + self._savepoint_name = savepoint_name self.force_rollback = force_rollback self._outer_transaction: Optional[bool] = None @@ -57,7 +53,7 @@ class Transaction: return self._conn @property - def savepoint_name(self) -> Optional[str]: + def savepoint_name(self) -> str: return self._savepoint_name def __enter__(self) -> "Transaction": @@ -71,12 +67,12 @@ class Transaction: if self._conn._savepoints is None: self._conn._savepoints = [] self._outer_transaction = False - if self._savepoint_name is None: + if not self._savepoint_name: self._savepoint_name = ( f"s{len(self._conn._savepoints) + 1}" ) - if self._savepoint_name is not None: + if self._savepoint_name: self._conn._exec_command( sql.SQL("savepoint {}").format( sql.Identifier(self._savepoint_name) @@ -152,7 +148,7 @@ class Transaction: def __repr__(self) -> str: args = [f"connection={self.connection}"] - if self.savepoint_name is not None: + if not self.savepoint_name: args.append(f"savepoint_name={self.savepoint_name!r}") if self.force_rollback: args.append("force_rollback=True") diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 84a092db3..026a7fa6e 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -327,16 +327,6 @@ def test_nested_three_levels_successful_exit(conn, svcconn): assert_rows(svcconn, {"one", "two", "three"}) -def test_named_savepoint_empty_string_invalid(conn): - """ - Raise validate savepoint_name up-front (rather than later constructing an - invalid SQL command and having that fail with an OperationalError). - """ - with pytest.raises(ValueError): - with conn.transaction(savepoint_name=""): - pass - - def test_named_savepoint_escapes_savepoint_name(conn): with conn.transaction("s-1"): pass @@ -359,7 +349,7 @@ def test_named_savepoints_successful_exit(conn): tx = Transaction(conn) with assert_commands_issued(conn, "begin"): tx.__enter__() - assert tx.savepoint_name is None + assert not tx.savepoint_name with assert_commands_issued(conn, "commit"): tx.__exit__(None, None, None) @@ -400,7 +390,7 @@ def test_named_savepoints_exception_exit(conn): tx = Transaction(conn) with assert_commands_issued(conn, "begin"): tx.__enter__() - assert tx.savepoint_name is None + assert not tx.savepoint_name with assert_commands_issued(conn, "rollback"): tx.__exit__(*some_exc_info())