git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: keep Cursor._execmany_returning set until reset
authorDenis Laxalde <denis.laxalde@dalibo.com>
Tue, 29 Mar 2022 07:44:35 +0000 (09:44 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Apr 2022 23:23:22 +0000 (01:23 +0200)
Cursor's _execmany_returning attribute is now initialized at _reset()
and set during _executemany_gen_pipeline(). This way, the attribute is
kept for further result fetches that may occur outside the executemany()
context: namely, this is needed because _execmany_returning is used by
_set_results_from_pipeline(), which would be called by fetch*() methods.

As a consequence, in _fetch_pipeline(), the actual fetch is skipped when
coming from executemany(..., returning=False), as the pgresult would
never be set. This ensures consistent behavior by raising "no result
available" when calling fetch*() from a non-returning executemany().

psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
tests/test_pipeline.py
tests/test_pipeline_async.py

index 98441c9ee31a021fedc3f45e1cf99c29c93f2700..7a982f62a3565e8783e87f7aff1902147f4837ae 100644 (file)
@@ -68,8 +68,6 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._closed = False
         self._last_query: Optional[Query] = None
         self._reset()
-        # None if executemany() not executing, True/False according to returning state
-        self._execmany_returning: Optional[bool] = None
 
     def _reset(self, reset_query: bool = True) -> None:
         self._results: List["PGresult"] = []
@@ -79,6 +77,8 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._rowcount = -1
         self._query: Optional[PostgresQuery]
         self._encoding = "utf-8"
+        # None if executemany() not executing, True/False according to returning state
+        self._execmany_returning: Optional[bool] = None
         if reset_query:
             self._query = None
 
@@ -212,7 +212,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             yield from self._conn._exec_command(cmd)
 
     def _executemany_gen_pipeline(
-        self, query: Query, params_seq: Iterable[Params]
+        self, query: Query, params_seq: Iterable[Params], returning: bool
     ) -> PQGen[None]:
         """
         Generator implementing `Cursor.executemany()` with pipelines available.
@@ -223,6 +223,9 @@ class BaseCursor(Generic[ConnectionType, Row]):
         yield from self._start_query(query)
         self._rowcount = 0
 
+        assert self._execmany_returning is None
+        self._execmany_returning = returning
+
         first = True
         for params in params_seq:
             if first:
@@ -693,9 +696,9 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         try:
             if Pipeline.is_supported():
                 with self._conn.pipeline(), self._conn.lock:
-                    assert self._execmany_returning is None
-                    self._execmany_returning = returning
-                    self._conn.wait(self._executemany_gen_pipeline(query, params_seq))
+                    self._conn.wait(
+                        self._executemany_gen_pipeline(query, params_seq, returning)
+                    )
             else:
                 with self._conn.lock:
                     self._conn.wait(
@@ -703,8 +706,6 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
                     )
         except e.Error as ex:
             raise ex.with_traceback(None)
-        finally:
-            self._execmany_returning = None
 
     def stream(
         self,
@@ -820,7 +821,11 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
             yield copy
 
     def _fetch_pipeline(self) -> None:
-        if not self.pgresult and self._conn._pipeline:
+        if (
+            self._execmany_returning is not False
+            and not self.pgresult
+            and self._conn._pipeline
+        ):
             with self._conn.lock:
                 self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
             assert self.pgresult
index 27f1bb889bb52e4647032b5de36e310b1cc547c6..7bd8b1dbb1941643185ab6c6967a79f05289faa3 100644 (file)
@@ -89,10 +89,8 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         try:
             if Pipeline.is_supported():
                 async with self._conn.pipeline(), self._conn.lock:
-                    assert self._execmany_returning is None
-                    self._execmany_returning = returning
                     await self._conn.wait(
-                        self._executemany_gen_pipeline(query, params_seq)
+                        self._executemany_gen_pipeline(query, params_seq, returning)
                     )
             else:
                 await self._conn.wait(
@@ -100,8 +98,6 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
                 )
         except e.Error as ex:
             raise ex.with_traceback(None)
-        finally:
-            self._execmany_returning = None
 
     async def stream(
         self,
@@ -182,7 +178,11 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
             yield copy
 
     async def _fetch_pipeline(self) -> None:
-        if not self.pgresult and self._conn._pipeline:
+        if (
+            self._execmany_returning is not False
+            and not self.pgresult
+            and self._conn._pipeline
+        ):
             async with self._conn.lock:
                 await self._conn.wait(self._conn._pipeline._fetch_gen(flush=True))
             assert self.pgresult
index 97d58dbd4f883440f3bd8ca6918c0fe06b1ef550..0e64b1f2d3f25ff87c07547688ac19e37d98345f 100644 (file)
@@ -176,6 +176,28 @@ def test_executemany(conn):
         assert cur.nextset() is None
 
 
+def test_executemany_no_returning(conn):
+    conn.autocommit = True
+    conn.execute("drop table if exists execmanypipelinenoreturning")
+    conn.execute(
+        "create unlogged table execmanypipelinenoreturning ("
+        " id serial primary key, num integer)"
+    )
+    with conn.pipeline(), conn.cursor() as cur:
+        cur.executemany(
+            "insert into execmanypipelinenoreturning(num) values (%s)",
+            [(10,), (20,)],
+            returning=False,
+        )
+        assert cur.rowcount == 2
+        with pytest.raises(e.ProgrammingError, match="no result available"):
+            cur.fetchone()
+        assert cur.nextset() is None
+        with pytest.raises(e.ProgrammingError, match="no result available"):
+            cur.fetchone()
+        assert cur.nextset() is None
+
+
 def test_prepared(conn):
     conn.autocommit = True
     with conn.pipeline():
index d7baef894f81c955bde2cf7bcc11f12d17676d8d..783814f3b0ff89f9bb0ca927db6be4dfd5c88f41 100644 (file)
@@ -179,6 +179,28 @@ async def test_executemany(aconn):
         assert cur.nextset() is None
 
 
+async def test_executemany_no_returning(aconn):
+    await aconn.set_autocommit(True)
+    await aconn.execute("drop table if exists execmanypipelinenoreturning")
+    await aconn.execute(
+        "create unlogged table execmanypipelinenoreturning ("
+        " id serial primary key, num integer)"
+    )
+    async with aconn.pipeline(), aconn.cursor() as cur:
+        await cur.executemany(
+            "insert into execmanypipelinenoreturning(num) values (%s)",
+            [(10,), (20,)],
+            returning=False,
+        )
+        assert cur.rowcount == 2
+        with pytest.raises(e.ProgrammingError, match="no result available"):
+            await cur.fetchone()
+        assert cur.nextset() is None
+        with pytest.raises(e.ProgrammingError, match="no result available"):
+            await cur.fetchone()
+        assert cur.nextset() is None
+
+
 async def test_prepared(aconn):
     await aconn.set_autocommit(True)
     async with aconn.pipeline():