]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: handle returning-executemany() in _set_results()
authorDenis Laxalde <denis@laxalde.org>
Sat, 7 Jan 2023 07:49:16 +0000 (08:49 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 4 Feb 2023 11:16:29 +0000 (12:16 +0100)
psycopg/psycopg/cursor.py

index 30dcb977484a017eeaebcbc74c1c0ba4ca2e4dcd..75c4ada9332c70298a7f1722cfe2a6d423cceb9c 100644 (file)
@@ -253,6 +253,10 @@ class BaseCursor(Generic[ConnectionType, Row]):
         yield from self._start_query(query)
         if not returning:
             self._rowcount = 0
+
+        assert self._execmany_returning is None
+        self._execmany_returning = returning
+
         first = True
         for params in params_seq:
             if first:
@@ -265,16 +269,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
             results = yield from self._maybe_prepare_gen(pgq, prepare=True)
             assert results is not None
             self._check_results(results)
-            if returning:
-                self._results.extend(results)
-            else:
-                # In non-returning case, set rowcount to the cumulated number
-                # of rows of executed queries.
-                for res in results:
-                    self._rowcount += res.command_tuples or 0
-
-        if self._results:
-            self._select_current_result(0)
+            self._set_results(results)
 
         self._last_query = query
 
@@ -528,8 +523,23 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._make_row = self._make_row_maker()
 
     def _set_results(self, results: List["PGresult"]) -> None:
-        self._results = results
-        self._select_current_result(0)
+        if self._execmany_returning is None:
+            # Received from execute()
+            self._results = results
+            self._select_current_result(0)
+
+        else:
+            # Received from executemany()
+            if self._execmany_returning:
+                first_batch = not self._results
+                self._results.extend(results)
+                if first_batch:
+                    self._select_current_result(0)
+            else:
+                # In non-returning case, set rowcount to the cumulated number
+                # of rows of executed queries.
+                for res in results:
+                    self._rowcount += res.command_tuples or 0
 
     def _set_results_from_pipeline(self, results: List["PGresult"]) -> None:
         self._check_results(results)