]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Reduce the work done on the cursor by executemany
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 29 Nov 2021 01:13:23 +0000 (02:13 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 14 Dec 2021 22:06:29 +0000 (23:06 +0100)
Certain functions were executed once per query (e.g. _make_row_maker(),
_tx.set_pgresult()) but we can run them only once.

psycopg/psycopg/cursor.py
psycopg/psycopg/server_cursor.py

index 298638649b337321cd096f1d80c388c0af502710..6a69f411e77da92e55d75a7432dc089479e8cde8 100644 (file)
@@ -6,7 +6,7 @@ psycopg cursor objects
 
 import sys
 from types import TracebackType
-from typing import Any, Callable, Generic, Iterable, Iterator, List
+from typing import Any, Generic, Iterable, Iterator, List
 from typing import Optional, NoReturn, Sequence, Type, TypeVar, TYPE_CHECKING
 from contextlib import contextmanager
 
@@ -29,8 +29,6 @@ if TYPE_CHECKING:
     from .pq.abc import PGconn, PGresult
     from .connection import Connection
 
-execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
-
 if _psycopg:
     execute = _psycopg.execute
     fetch = _psycopg.fetch
@@ -158,14 +156,8 @@ class BaseCursor(Generic[ConnectionType, Row]):
         Return `!True` if a new result is available, which will be the one
         methods `!fetch*()` will operate on.
         """
-        self._iresult += 1
-        if self._iresult < len(self._results):
-            self.pgresult = self._results[self._iresult]
-            self._tx.set_pgresult(self._results[self._iresult])
-            self._make_row = self._make_row_maker()
-            self._pos = 0
-            nrows = self.pgresult.command_tuples
-            self._rowcount = nrows if nrows is not None else -1
+        if self._iresult < len(self._results) - 1:
+            self._set_result(self._iresult + 1)
             return True
         else:
             return None
@@ -205,7 +197,8 @@ class BaseCursor(Generic[ConnectionType, Row]):
         results = yield from self._maybe_prepare_gen(
             pgq, prepare=prepare, binary=binary
         )
-        self._execute_results(results)
+        self._set_results(results)
+        self._set_result(0)
         self._last_query = query
 
         for cmd in self._conn._prepared.get_maintenance_commands():
@@ -217,6 +210,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         """Generator implementing `Cursor.executemany()`."""
         yield from self._start_query(query)
         first = True
+        nrows = 0
         for params in params_seq:
             if first:
                 pgq = self._convert_query(query, params)
@@ -226,8 +220,16 @@ class BaseCursor(Generic[ConnectionType, Row]):
                 pgq.dump(params)
 
             results = yield from self._maybe_prepare_gen(pgq, prepare=True)
-            self._execute_results(results)
+            self._set_results(results)
+
+            for res in results:
+                nrows += res.command_tuples or 0
 
+        self._set_result(0)
+        # Override rowcout for the first result. Calls to nextset() will change
+        # it to the value of that result only, but we hope nobody will notice.
+        # You haven't read this comment.
+        self._rowcount = nrows or -1
         self._last_query = query
 
         for cmd in self._conn._prepared.get_maintenance_commands():
@@ -239,7 +241,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         *,
         prepare: Optional[bool] = None,
         binary: Optional[bool] = None,
-    ) -> PQGen[Sequence["PGresult"]]:
+    ) -> PQGen[List["PGresult"]]:
         # Check if the query is prepared or needs preparing
         prep, name = self._conn._prepared.get(pgq, prepare)
         if prep is Prepare.YES:
@@ -394,13 +396,9 @@ class BaseCursor(Generic[ConnectionType, Row]):
         ExecStatus.COPY_BOTH,
     )
 
-    def _execute_results(
-        self, results: Sequence["PGresult"], format: Optional[Format] = None
-    ) -> None:
+    def _set_results(self, results: List["PGresult"]) -> None:
         """
-        Implement part of execute() after waiting common to sync and async
-
-        This is not a generator, but a normal non-blocking function.
+        Set the results from a query into the cursor state.
         """
         if not results:
             raise e.InternalError("got no result from the query")
@@ -409,23 +407,26 @@ class BaseCursor(Generic[ConnectionType, Row]):
             if res.status not in self._status_ok:
                 self._raise_from_results(results)
 
-        self._results = list(results)
-        self.pgresult = results[0]
+        self._results = results
+
+    def _set_result(self, i: int, format: Optional[Format] = None) -> None:
+        """
+        Select one of the results in the cursor as the active one.
+        """
+        self._iresult = i
+        res = self.pgresult = self._results[i]
 
         # Note: the only reason to override format is to correclty set
         # binary loaders on server-side cursors, because send_describe_portal
         # only returns a text result.
-        self._tx.set_pgresult(results[0], format=format)
+        self._tx.set_pgresult(res, format=format)
 
         self._make_row = self._make_row_maker()
+        self._pos = 0
         nrows = self.pgresult.command_tuples
-        if nrows is not None:
-            if self._rowcount < 0:
-                self._rowcount = nrows
-            else:
-                self._rowcount += nrows
+        self._rowcount = nrows if nrows is not None else -1
 
-    def _raise_from_results(self, results: Sequence["PGresult"]) -> NoReturn:
+    def _raise_from_results(self, results: List["PGresult"]) -> NoReturn:
         statuses = {res.status for res in results}
         badstats = statuses.difference(self._status_ok)
         if results[-1].status == ExecStatus.FATAL_ERROR:
index b4289a67373375f5348acfcd1338764d4a2a968c..ba6a62d4cbd84f4c68451b56ef1cf27de05f8e39 100644 (file)
@@ -87,7 +87,8 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         conn = cur._conn
         conn.pgconn.send_describe_portal(self.name.encode(cur._encoding))
         results = yield from execute(conn.pgconn)
-        cur._execute_results(results, format=self._format)
+        cur._set_results(results)
+        cur._set_result(0, format=self._format)
         self.described = True
 
     def _close_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]: