]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: wrap transaction in pipelines if the connection has one 298/head
authorDenis Laxalde <denis@laxalde.org>
Mon, 9 May 2022 19:14:59 +0000 (21:14 +0200)
committerDenis Laxalde <denis@laxalde.org>
Mon, 9 May 2022 19:24:13 +0000 (21:24 +0200)
When entering a transaction on a connection in pipeline mode, we open an
inner pipeline to ensure that a Sync is emitted at the end of
transaction thus restoring the connection in its expected state (i.e.
the same as in non-pipeline mode).

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index 4440c2285c1391366c05ebdcf07def18e4830b79..f8a43bb9241b7fd999d6a21868061f74a7f6585b 100644 (file)
@@ -869,8 +869,13 @@ class Connection(BaseConnection[Row]):
             block even if there were no error (e.g. to try a no-op process).
         :rtype: Transaction
         """
-        with Transaction(self, savepoint_name, force_rollback) as tx:
-            yield tx
+        tx = Transaction(self, savepoint_name, force_rollback)
+        if self._pipeline:
+            with tx, self.pipeline():
+                yield tx
+        else:
+            with tx:
+                yield tx
 
     def notifies(self) -> Generator[Notify, None, None]:
         """
index af37e1618c4027066e7e9578a73b1366afc720be..606c68386e10ab34825aac26bd83308265c07a89 100644 (file)
@@ -282,8 +282,12 @@ class AsyncConnection(BaseConnection[Row]):
         :rtype: AsyncTransaction
         """
         tx = AsyncTransaction(self, savepoint_name, force_rollback)
-        async with tx:
-            yield tx
+        if self._pipeline:
+            async with tx, self.pipeline():
+                yield tx
+        else:
+            async with tx:
+                yield tx
 
     async def notifies(self) -> AsyncGenerator[Notify, None]:
         while 1:
index d8783f35ba5c368dea7f486b412a8549ab633832..29816cf23f4c667eaa79a66483b68a118e686c5e 100644 (file)
@@ -237,7 +237,6 @@ def test_errors_raised_on_transaction_exit(conn):
             with conn.transaction():
                 conn.execute("select 1 from nosuchtable")
                 here = True
-        conn.rollback()  # TODO: inconsistent with non-pipeline.
         cur1 = conn.execute("select 1")
     assert here
     cur2 = conn.execute("select 2")
index c69fef788a03df7e109a73d555d7967057c42ba0..b3dfef755ae16712dccf7824dbb3ecffd9db7f18 100644 (file)
@@ -240,7 +240,6 @@ async def test_errors_raised_on_transaction_exit(aconn):
             async with aconn.transaction():
                 await aconn.execute("select 1 from nosuchtable")
                 here = True
-        await aconn.rollback()  # TODO: inconsistent with non-pipeline.
         cur1 = await aconn.execute("select 1")
     assert here
     cur2 = await aconn.execute("select 2")