From: Denis Laxalde Date: Sat, 7 Jan 2023 07:49:16 +0000 (+0100) Subject: refactor: handle returning-executemany() in _set_results() X-Git-Tag: pool-3.2.0~125^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1871fbfd09dcde82a4e57389195ca59ce72ed290;p=thirdparty%2Fpsycopg.git refactor: handle returning-executemany() in _set_results() --- diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 30dcb9774..75c4ada93 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -253,6 +253,10 @@ class BaseCursor(Generic[ConnectionType, Row]): yield from self._start_query(query) if not returning: self._rowcount = 0 + + assert self._execmany_returning is None + self._execmany_returning = returning + first = True for params in params_seq: if first: @@ -265,16 +269,7 @@ class BaseCursor(Generic[ConnectionType, Row]): results = yield from self._maybe_prepare_gen(pgq, prepare=True) assert results is not None self._check_results(results) - if returning: - self._results.extend(results) - else: - # In non-returning case, set rowcount to the cumulated number - # of rows of executed queries. - for res in results: - self._rowcount += res.command_tuples or 0 - - if self._results: - self._select_current_result(0) + self._set_results(results) self._last_query = query @@ -528,8 +523,23 @@ class BaseCursor(Generic[ConnectionType, Row]): self._make_row = self._make_row_maker() def _set_results(self, results: List["PGresult"]) -> None: - self._results = results - self._select_current_result(0) + if self._execmany_returning is None: + # Received from execute() + self._results = results + self._select_current_result(0) + + else: + # Received from executemany() + if self._execmany_returning: + first_batch = not self._results + self._results.extend(results) + if first_batch: + self._select_current_result(0) + else: + # In non-returning case, set rowcount to the cumulated number + # of rows of executed queries. + for res in results: + self._rowcount += res.command_tuples or 0 def _set_results_from_pipeline(self, results: List["PGresult"]) -> None: self._check_results(results)