]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: sync nested pipeline with pending commands upon enter
authorDenis Laxalde <denis.laxalde@dalibo.com>
Wed, 25 May 2022 06:51:05 +0000 (08:51 +0200)
committerDenis Laxalde <denis.laxalde@dalibo.com>
Wed, 25 May 2022 07:28:44 +0000 (09:28 +0200)
When entering a nested pipeline and the outer one has pending commands,
we now sync the pipeline. This is probably less surprising at it makes
the implicit transaction from a nested pipeline isolated from the outer
one.

With this, the explicit Sync when entering a transaction is no longer
needed.

Fix #309.

docs/advanced/pipeline.rst
psycopg/psycopg/_pipeline.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index 12a7eb187a65c94ff57ffc8a4c505b316e2a1fdf..21f27b5b63dd30ec30f2625c543b5f315b44f6f3 100644 (file)
@@ -239,6 +239,7 @@ point is established by Psycopg:
 - using the `Pipeline.sync()` method;
 - on `Connection.commit()` or `~Connection.rollback()`;
 - at the end of a `!Pipeline` block;
+- possibly when opening a nested `!Pipeline` block;
 - using a fetch method such as `Cursor.fetchone()` (which only flushes the
   query but doesn't issue a Sync and doesn't reset a pipeline state error).
 
index e11c7b423709ccf4f936e2d262bc73cab7f24036..787184860c05f14b5c1965b57026f5fe412b0166 100644 (file)
@@ -77,9 +77,11 @@ class BasePipeline:
             BasePipeline._is_supported = pq_version >= 140000
         return BasePipeline._is_supported
 
-    def _enter(self) -> None:
+    def _enter_gen(self) -> PQGen[None]:
         if self.level == 0:
             self.pgconn.enter_pipeline_mode()
+        elif self.command_queue:
+            yield from self._sync_gen()
         self.level += 1
 
     def _exit(self) -> None:
@@ -191,7 +193,8 @@ class Pipeline(BasePipeline):
             raise ex.with_traceback(None)
 
     def __enter__(self) -> "Pipeline":
-        self._enter()
+        with self._conn.lock:
+            self._conn.wait(self._enter_gen())
         return self
 
     def __exit__(
@@ -239,7 +242,8 @@ class AsyncPipeline(BasePipeline):
             raise ex.with_traceback(None)
 
     async def __aenter__(self) -> "AsyncPipeline":
-        self._enter()
+        async with self._conn.lock:
+            await self._conn.wait(self._enter_gen())
         return self
 
     async def __aexit__(
index 4c610bf3f23692f3718f4ba402ce8f064b565d8a..58cb63c6a7958a70a9518dd566e849ce05a8d6ec 100644 (file)
@@ -903,7 +903,6 @@ class Connection(BaseConnection[Row]):
         """
         tx = Transaction(self, savepoint_name, force_rollback)
         if self._pipeline:
-            self._pipeline.sync()
             with self.pipeline(), tx, self.pipeline():
                 yield tx
         else:
index ae8715da1408f0f3ae2531dd37d9d55f38d05180..3a2bc91afac624c1c383c364c18361ab9b3843ae 100644 (file)
@@ -293,7 +293,6 @@ class AsyncConnection(BaseConnection[Row]):
         """
         tx = AsyncTransaction(self, savepoint_name, force_rollback)
         if self._pipeline:
-            await self._pipeline.sync()
             async with self.pipeline(), tx, self.pipeline():
                 yield tx
         else:
index 03a299d3615e4e6e1f8a62b920ea0073f5eff4f7..841379bb9830b30422df1586185e523c7f36de6c 100644 (file)
@@ -402,6 +402,9 @@ def test_auto_prepare(conn):
 
 
 def test_transaction(conn):
+    notices = []
+    conn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
     with conn.pipeline():
         with conn.transaction():
             cur = conn.execute("select 'tx'")
@@ -416,6 +419,8 @@ def test_transaction(conn):
         (r,) = cur.fetchone()
         assert r == "rb"
 
+    assert not notices
+
 
 def test_transaction_nested(conn):
     with conn.pipeline():
index 70468688c85c6f76d4d16293a1b2029faaad27cc..c170dfa23528e02c14c852863ce0f848c61a2720 100644 (file)
@@ -402,6 +402,9 @@ async def test_auto_prepare(aconn):
 
 
 async def test_transaction(aconn):
+    notices = []
+    aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
+
     async with aconn.pipeline():
         async with aconn.transaction():
             cur = await aconn.execute("select 'tx'")
@@ -416,6 +419,8 @@ async def test_transaction(aconn):
         (r,) = await cur.fetchone()
         assert r == "rb"
 
+    assert not notices
+
 
 async def test_transaction_nested(aconn):
     async with aconn.pipeline():