]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: fix 'executemany()' when pipeline mode is not available
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 27 Mar 2022 17:57:17 +0000 (19:57 +0200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Apr 2022 23:23:22 +0000 (01:23 +0200)
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py

index 87dbcb847f00422782fb014a61c104aa30086732..40aacf15fe2e8e3943e6f9bac7fc92c2b4b8cfd7 100644 (file)
@@ -21,6 +21,7 @@ from .rows import Row, RowMaker, RowFactory
 from ._column import Column
 from ._cmodule import _psycopg
 from ._queries import PostgresQuery
+from ._pipeline import Pipeline
 from ._encodings import pgconn_encoding
 from ._preparing import Prepare
 
@@ -210,10 +211,12 @@ class BaseCursor(Generic[ConnectionType, Row]):
         for cmd in self._conn._prepared.get_maintenance_commands():
             yield from self._conn._exec_command(cmd)
 
-    def _executemany_gen(
-        self, query: Query, params_seq: Iterable[Params], returning: bool
+    def _executemany_gen_pipeline(
+        self, query: Query, params_seq: Iterable[Params]
     ) -> PQGen[None]:
-        """Generator implementing `Cursor.executemany()`."""
+        """
+        Generator implementing `Cursor.executemany()` with pipelines available.
+        """
         pipeline = self._conn._pipeline
         assert pipeline
 
@@ -239,6 +242,44 @@ class BaseCursor(Generic[ConnectionType, Row]):
 
         yield from pipeline._flush_gen()
 
+    def _executemany_gen_no_pipeline(
+        self, query: Query, params_seq: Iterable[Params], returning: bool
+    ) -> PQGen[None]:
+        """
+        Generator implementing `Cursor.executemany()` with pipelines not available.
+        """
+        yield from self._start_query(query)
+        first = True
+        nrows = 0
+        for params in params_seq:
+            if first:
+                pgq = self._convert_query(query, params)
+                self._query = pgq
+                first = False
+            else:
+                pgq.dump(params)
+
+            results = yield from self._maybe_prepare_gen(pgq, prepare=True)
+            assert results is not None
+            self._check_results(results)
+            if returning and results[0].status == ExecStatus.TUPLES_OK:
+                self._results.extend(results)
+
+            for res in results:
+                nrows += res.command_tuples or 0
+
+        if self._results:
+            self._set_current_result(0)
+
+        # Override rowcount for the first result. Calls to nextset() will change
+        # it to the value of that result only, but we hope nobody will notice.
+        # You haven't read this comment.
+        self._rowcount = nrows
+        self._last_query = query
+
+        for cmd in self._conn._prepared.get_maintenance_commands():
+            yield from self._conn._exec_command(cmd)
+
     def _maybe_prepare_gen(
         self,
         pgq: PostgresQuery,
@@ -478,7 +519,6 @@ class BaseCursor(Generic[ConnectionType, Row]):
             # TODO: bug we also end up here on executemany() if run from inside
             # a pipeline block. This causes a wrong rowcount. As it isn't so
             # serious, currently leaving it this way.
-            first_batch = not self._results
             self._results.extend(results)
             if first_batch:
                 self._set_current_result(0)
@@ -656,11 +696,16 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         Execute the same command with a sequence of input data.
         """
         try:
-            with self._conn.pipeline():
-                with self._conn.lock:
+            if Pipeline.is_supported():
+                with self._conn.pipeline(), self._conn.lock:
                     assert self._execmany_returning is None
                     self._execmany_returning = returning
-                    self._conn.wait(self._executemany_gen(query, params_seq, returning))
+                    self._conn.wait(self._executemany_gen_pipeline(query, params_seq))
+            else:
+                with self._conn.lock:
+                    self._conn.wait(
+                        self._executemany_gen_no_pipeline(query, params_seq, returning)
+                    )
         except e.Error as ex:
             raise ex.with_traceback(None)
         finally:
index 61bdc1f35eb4bbf794733bb9281df92ae1b02414..27f1bb889bb52e4647032b5de36e310b1cc547c6 100644 (file)
@@ -14,6 +14,7 @@ from .abc import Query, Params
 from .copy import AsyncCopy
 from .rows import Row, RowMaker, AsyncRowFactory
 from .cursor import BaseCursor
+from ._pipeline import Pipeline
 
 if TYPE_CHECKING:
     from .connection_async import AsyncConnection
@@ -86,13 +87,17 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         returning: bool = False,
     ) -> None:
         try:
-            async with self._conn.pipeline():
-                async with self._conn.lock:
+            if Pipeline.is_supported():
+                async with self._conn.pipeline(), self._conn.lock:
                     assert self._execmany_returning is None
                     self._execmany_returning = returning
                     await self._conn.wait(
-                        self._executemany_gen(query, params_seq, returning)
+                        self._executemany_gen_pipeline(query, params_seq)
                     )
+            else:
+                await self._conn.wait(
+                    self._executemany_gen_no_pipeline(query, params_seq, returning)
+                )
         except e.Error as ex:
             raise ex.with_traceback(None)
         finally: