]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: always validate PrepareManager cache in pipeline mode
authorDenis Laxalde <denis.laxalde@dalibo.com>
Thu, 8 Jun 2023 09:21:42 +0000 (11:21 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 12 Jun 2023 20:29:46 +0000 (22:29 +0200)
Previously, when processing results in pipeline mode
(BasePipeline._process_results()), we'd run
'cursor._check_results(results)' early before calling
_prepared.validate() with prepared statement information. However, if
this check step fails, for example if the pipeline got aborted due to a
previous error, the latter step (PrepareManager cache validation) was
not run.

We fix this by reversing the logic, and checking results last.

However, this is not enough, because the results processing logic in
BasePipeline._fetch_gen() or _communicate_gen(), which sequentially
walked through fetched results, would typically stop at the first
exception and thus possibly never go through the step of validating
PrepareManager cache if a previous error happened.

We fix that by making sure that *all* results are processed, possibly
capturing the first exception and then re-raising it. In both
_communicate_gen() and _fetch_gen(), we no longer store results in the
'to_process' like, but process then upon reception as this logic is no
longer needed.

Fix #585.

docs/news.rst
psycopg/psycopg/_pipeline.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index 38f5e037c800e144d198fd4fbf4bb1367b53d11c..1f2377717c72e78f4798861845c3960eed56f961 100644 (file)
@@ -7,6 +7,16 @@
 ``psycopg`` release notes
 =========================
 
+Future releases
+---------------
+
+Psycopg 3.1.10
+^^^^^^^^^^^^^^
+
+- Fix prepared statement cache validation when exiting pipeline mode (or
+  `~Cursor.executemany()`) in case an error occurred within the pipeline
+  (:ticket:`#585`).
+
 Current release
 ---------------
 
index ecd6f0628d815b9d79ce6974c000f4a8830c29cd..e0c564a21d59d722fde0b0a7e02054226e1bc773 100644 (file)
@@ -139,9 +139,16 @@ class BasePipeline:
         results, which are then processed.
         """
         fetched = yield from pipeline_communicate(self.pgconn, self.command_queue)
-        to_process = [(self.result_queue.popleft(), results) for results in fetched]
-        for queued, results in to_process:
-            self._process_results(queued, results)
+        exception = None
+        for results in fetched:
+            queued = self.result_queue.popleft()
+            try:
+                self._process_results(queued, results)
+            except e.Error as exc:
+                if exception is None:
+                    exception = exc
+        if exception is not None:
+            raise exception
 
     def _fetch_gen(self, *, flush: bool) -> PQGen[None]:
         """Fetch available results from the connection and process them with
@@ -159,7 +166,7 @@ class BasePipeline:
             self.pgconn.send_flush_request()
             yield from send(self.pgconn)
 
-        to_process = []
+        exception = None
         while self.result_queue:
             results = yield from fetch_many(self.pgconn)
             if not results:
@@ -167,10 +174,13 @@ class BasePipeline:
                 # commands.
                 break
             queued = self.result_queue.popleft()
-            to_process.append((queued, results))
-
-        for queued, results in to_process:
-            self._process_results(queued, results)
+            try:
+                self._process_results(queued, results)
+            except e.Error as exc:
+                if exception is None:
+                    exception = exc
+        if exception is not None:
+            raise exception
 
     def _process_results(
         self, queued: PendingResult, results: List["PGresult"]
@@ -190,11 +200,11 @@ class BasePipeline:
                 raise e.PipelineAborted("pipeline aborted")
         else:
             cursor, prepinfo = queued
-            cursor._set_results_from_pipeline(results)
             if prepinfo:
                 key, prep, name = prepinfo
                 # Update the prepare state of the query.
                 cursor._conn._prepared.validate(key, prep, name, results)
+            cursor._set_results_from_pipeline(results)
 
     def _enqueue_sync(self) -> None:
         """Enqueue a PQpipelineSync() command."""
index 9d5301d7fa9d020763288d88adaebdb712edbf12..cfe39e07a9fe6a82a80083b1bc05a3c937ccbecc 100644 (file)
@@ -427,6 +427,22 @@ def test_auto_prepare(conn):
     assert res == [0] * 5 + [1] * 5
 
 
+def test_prepare_error(conn):
+    """Regression test for GH issue #585.
+
+    An invalid prepared statement, in a pipeline, should be discarded at exit
+    and not reused.
+    """
+    conn.autocommit = True
+    stmt = "INSERT INTO nosuchtable(data) VALUES (%s)"
+    with pytest.raises(psycopg.errors.UndefinedTable):
+        with conn.pipeline():
+            conn.execute(stmt, ["foo"], prepare=True)
+    assert not conn._prepared._names
+    with pytest.raises(psycopg.errors.UndefinedTable):
+        conn.execute(stmt, ["bar"])
+
+
 def test_transaction(conn):
     notices = []
     conn.add_notice_handler(lambda diag: notices.append(diag.message_primary))
index 3fcf29393fbdbae563f6a8cb2b3e9df4af7d5ea9..88c8d30bd3888bf8f26b328be6b974df92038626 100644 (file)
@@ -428,6 +428,22 @@ async def test_auto_prepare(aconn):
     assert res == [0] * 5 + [1] * 5
 
 
+async def test_prepare_error(aconn):
+    """Regression test for GH issue #585.
+
+    An invalid prepared statement, in a pipeline, should be discarded at exit
+    and not reused.
+    """
+    await aconn.set_autocommit(True)
+    stmt = "INSERT INTO nosuchtable(data) VALUES (%s)"
+    with pytest.raises(psycopg.errors.UndefinedTable):
+        async with aconn.pipeline():
+            await aconn.execute(stmt, ["foo"], prepare=True)
+    assert not aconn._prepared._names
+    with pytest.raises(psycopg.errors.UndefinedTable):
+        await aconn.execute(stmt, ["bar"])
+
+
 async def test_transaction(aconn):
     notices = []
     aconn.add_notice_handler(lambda diag: notices.append(diag.message_primary))