]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: make transaction status check account for pipeline mode 299/head
authorDenis Laxalde <denis.laxalde@dalibo.com>
Tue, 10 May 2022 10:11:21 +0000 (12:11 +0200)
committerDenis Laxalde <denis.laxalde@dalibo.com>
Tue, 10 May 2022 13:47:07 +0000 (15:47 +0200)
We turn _check_intrans() into a generator _check_intrans_gen() in order
to call _pipeline._sync_gen() if the connection is in pipeline mode so
as to retrieve an accurate connection status.

This makes the safety guard about 'autocommit' when inside a transaction
work in pipeline mode, thus removing the xfail in transaction tests.

In test_autocommit_unknown, we now catch OperationalError which is
raised by conn.wait() rather than ProgrammingError previously which is
no longer reached.

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_transaction.py
tests/test_transaction_async.py

index 7ecae04c80553ba5a16cb9955ce6f6870de1e687..7a169207d95ff95590508b620bf4d2684dad48d7 100644 (file)
@@ -183,9 +183,10 @@ class BaseConnection(Generic[Row]):
         self._set_autocommit(value)
 
     def _set_autocommit(self, value: bool) -> None:
-        # Base implementation, not thread safe.
-        # Subclasses must call it holding a lock
-        self._check_intrans("autocommit")
+        raise NotImplementedError
+
+    def _set_autocommit_gen(self, value: bool) -> PQGen[None]:
+        yield from self._check_intrans_gen("autocommit")
         self._autocommit = bool(value)
 
     @property
@@ -200,9 +201,10 @@ class BaseConnection(Generic[Row]):
         self._set_isolation_level(value)
 
     def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
-        # Base implementation, not thread safe.
-        # Subclasses must call it holding a lock
-        self._check_intrans("isolation_level")
+        raise NotImplementedError
+
+    def _set_isolation_level_gen(self, value: Optional[IsolationLevel]) -> PQGen[None]:
+        yield from self._check_intrans_gen("isolation_level")
         self._isolation_level = IsolationLevel(value) if value is not None else None
         self._begin_statement = b""
 
@@ -218,9 +220,10 @@ class BaseConnection(Generic[Row]):
         self._set_read_only(value)
 
     def _set_read_only(self, value: Optional[bool]) -> None:
-        # Base implementation, not thread safe.
-        # Subclasses must call it holding a lock
-        self._check_intrans("read_only")
+        raise NotImplementedError
+
+    def _set_read_only_gen(self, value: Optional[bool]) -> PQGen[None]:
+        yield from self._check_intrans_gen("read_only")
         self._read_only = bool(value)
         self._begin_statement = b""
 
@@ -236,15 +239,19 @@ class BaseConnection(Generic[Row]):
         self._set_deferrable(value)
 
     def _set_deferrable(self, value: Optional[bool]) -> None:
-        # Base implementation, not thread safe.
-        # Subclasses must call it holding a lock
-        self._check_intrans("deferrable")
+        raise NotImplementedError
+
+    def _set_deferrable_gen(self, value: Optional[bool]) -> PQGen[None]:
+        yield from self._check_intrans_gen("deferrable")
         self._deferrable = bool(value)
         self._begin_statement = b""
 
-    def _check_intrans(self, attribute: str) -> None:
+    def _check_intrans_gen(self, attribute: str) -> PQGen[None]:
         # Raise an exception if we are in a transaction
         status = self.pgconn.transaction_status
+        if status == TransactionStatus.IDLE and self._pipeline:
+            yield from self._pipeline._sync_gen()
+            status = self.pgconn.transaction_status
         if status != TransactionStatus.IDLE:
             if self._num_transactions:
                 raise e.ProgrammingError(
@@ -937,19 +944,19 @@ class Connection(BaseConnection[Row]):
 
     def _set_autocommit(self, value: bool) -> None:
         with self.lock:
-            super()._set_autocommit(value)
+            self.wait(self._set_autocommit_gen(value))
 
     def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
         with self.lock:
-            super()._set_isolation_level(value)
+            self.wait(self._set_isolation_level_gen(value))
 
     def _set_read_only(self, value: Optional[bool]) -> None:
         with self.lock:
-            super()._set_read_only(value)
+            self.wait(self._set_read_only_gen(value))
 
     def _set_deferrable(self, value: Optional[bool]) -> None:
         with self.lock:
-            super()._set_deferrable(value)
+            self.wait(self._set_deferrable_gen(value))
 
     def tpc_begin(self, xid: Union[Xid, str]) -> None:
         """
index a4eee8f5c02394a80db9361c838e42a7c2a8835a..ef6c2bbd75a5ea4b79720e6ebaeb5a564182dfde 100644 (file)
@@ -347,7 +347,7 @@ class AsyncConnection(BaseConnection[Row]):
     async def set_autocommit(self, value: bool) -> None:
         """Async version of the `~Connection.autocommit` setter."""
         async with self.lock:
-            super()._set_autocommit(value)
+            await self.wait(self._set_autocommit_gen(value))
 
     def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
         self._no_set_async("isolation_level")
@@ -355,7 +355,7 @@ class AsyncConnection(BaseConnection[Row]):
     async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
         """Async version of the `~Connection.isolation_level` setter."""
         async with self.lock:
-            super()._set_isolation_level(value)
+            await self.wait(self._set_isolation_level_gen(value))
 
     def _set_read_only(self, value: Optional[bool]) -> None:
         self._no_set_async("read_only")
@@ -363,7 +363,7 @@ class AsyncConnection(BaseConnection[Row]):
     async def set_read_only(self, value: Optional[bool]) -> None:
         """Async version of the `~Connection.read_only` setter."""
         async with self.lock:
-            super()._set_read_only(value)
+            await self.wait(self._set_read_only_gen(value))
 
     def _set_deferrable(self, value: Optional[bool]) -> None:
         self._no_set_async("deferrable")
@@ -371,7 +371,7 @@ class AsyncConnection(BaseConnection[Row]):
     async def set_deferrable(self, value: Optional[bool]) -> None:
         """Async version of the `~Connection.deferrable` setter."""
         async with self.lock:
-            super()._set_deferrable(value)
+            await self.wait(self._set_deferrable_gen(value))
 
     def _no_set_async(self, attribute: str) -> None:
         raise AttributeError(
index c9eb3ba660139bd1993b775a264d8c89fb1a27aa..663bd1c153b0818aaf1d480a9acb13d50a919352 100644 (file)
@@ -336,7 +336,7 @@ def test_autocommit_inerror(conn):
 def test_autocommit_unknown(conn):
     conn.close()
     assert conn.pgconn.transaction_status == conn.TransactionStatus.UNKNOWN
-    with pytest.raises(psycopg.ProgrammingError):
+    with pytest.raises(psycopg.OperationalError):
         conn.autocommit = True
     assert not conn.autocommit
 
index 16502a51cfe18087f6da7f27a590f64596e1fb9a..912b4c90935b44135de96b0e068eca838646b028 100644 (file)
@@ -340,7 +340,7 @@ async def test_autocommit_inerror(aconn):
 async def test_autocommit_unknown(aconn):
     await aconn.close()
     assert aconn.pgconn.transaction_status == aconn.TransactionStatus.UNKNOWN
-    with pytest.raises(psycopg.ProgrammingError):
+    with pytest.raises(psycopg.OperationalError):
         await aconn.set_autocommit(True)
     assert not aconn.autocommit
 
index 094746e27cf1a3a4528379b9a4b5ab8ac95f14b4..ef5a8971d77407106c9da64105cf2ed38af6581d 100644 (file)
@@ -199,17 +199,12 @@ def test_interaction_dbapi_transaction(conn):
     assert inserted(conn) == {"foo", "baz"}
 
 
-def test_prohibits_use_of_commit_rollback_autocommit(conn, pipeline):
+def test_prohibits_use_of_commit_rollback_autocommit(conn):
     """
     Within a Transaction block, it is forbidden to touch commit, rollback,
     or the autocommit setting on the connection, as this would interfere
     with the transaction scope being managed by the Transaction block.
     """
-    if pipeline:
-        # TODO: Fixing Connection._check_intrans() would require calling
-        # conn._pipeline.sync(), which implies turning _check_intrans() into a
-        # generator method.
-        pytest.xfail("Connection._check_intrans() does not account for pipeline mode")
     conn.autocommit = False
     conn.commit()
     conn.rollback()
index 1af82484e4308bd56098280a2b3e51500cdf6203..4533afa5ec576c700ce42fc8f1e4a01fa13d4ec7 100644 (file)
@@ -142,17 +142,12 @@ async def test_interaction_dbapi_transaction(aconn):
     assert await inserted(aconn) == {"foo", "baz"}
 
 
-async def test_prohibits_use_of_commit_rollback_autocommit(aconn, apipeline):
+async def test_prohibits_use_of_commit_rollback_autocommit(aconn):
     """
     Within a Transaction block, it is forbidden to touch commit, rollback,
     or the autocommit setting on the connection, as this would interfere
     with the transaction scope being managed by the Transaction block.
     """
-    if apipeline:
-        # TODO: Fixing Connection._check_intrans() would require calling
-        # conn._pipeline.sync(), which implies turning _check_intrans() into a
-        # generator method.
-        pytest.xfail("Connection._check_intrans() does not account for pipeline mode")
     await aconn.set_autocommit(False)
     await aconn.commit()
     await aconn.rollback()