]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: unify code paths to set the rowcount
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 24 Jul 2022 02:35:28 +0000 (03:35 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 26 Jul 2022 12:23:46 +0000 (13:23 +0100)
psycopg/psycopg/copy.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py

index 08f1e2b213e2633cadeef237850867811d4ae369..3c54aafec1b5b0e4921a5ac90d569a843d9276db 100644 (file)
@@ -158,6 +158,11 @@ class BaseCopy(Generic[ConnectionType]):
 
         # res is the final PGresult
         self._finished = True
+
+        # This result is a COMMAND_OK which has info about the number of rows
+        # returned, but not about the columns, which is instead an information
+        # that was received on the COPY_OUT result at the beginning of COPY.
+        # So, don't replace the results in the cursor, just update the rowcount.
         nrows = res.command_tuples
         self.cursor._rowcount = nrows if nrows is not None else -1
         return memoryview(b"")
@@ -357,8 +362,7 @@ class LibpqWriter(Writer):
             bmsg = None
 
         res = self.connection.wait(copy_end(self._pgconn, bmsg))
-        nrows = res.command_tuples
-        self.cursor._rowcount = nrows if nrows is not None else -1
+        self.cursor._results = [res]
 
 
 class QueuedLibpqDriver(LibpqWriter):
@@ -570,8 +574,7 @@ class AsyncLibpqWriter(AsyncWriter):
             bmsg = None
 
         res = await self.connection.wait(copy_end(self._pgconn, bmsg))
-        nrows = res.command_tuples
-        self.cursor._rowcount = nrows if nrows is not None else -1
+        self.cursor._results = [res]
 
 
 class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
index 2444706c067699987194177a5b55b608acc8a651..5653f07c9863a972ba5ffcaa95cbd7986e42175d 100644 (file)
@@ -425,10 +425,9 @@ class BaseCursor(Generic[ConnectionType, Row]):
         if len(results) != 1:
             raise e.ProgrammingError("COPY cannot be mixed with other operations")
 
-        result = results[0]
-        self._check_copy_result(result)
-        self.pgresult = result
-        self._tx.set_pgresult(result)
+        self._check_copy_result(results[0])
+        self._results = results
+        self._select_current_result(0)
 
     def _execute_send(
         self,
@@ -529,10 +528,16 @@ class BaseCursor(Generic[ConnectionType, Row]):
         # only returns a text result.
         self._tx.set_pgresult(res, format=format)
 
-        self._make_row = self._make_row_maker()
         self._pos = 0
-        nrows = self.pgresult.command_tuples
-        self._rowcount = nrows if nrows is not None else -1
+
+        # COPY_OUT has never info about nrows. We need such result for the
+        # columns in order to return a `description`, but not overwrite the
+        # cursor rowcount (which was set by the Copy object).
+        if res.status != COPY_OUT:
+            nrows = self.pgresult.command_tuples
+            self._rowcount = nrows if nrows is not None else -1
+
+        self._make_row = self._make_row_maker()
 
     def _set_results_from_pipeline(self, results: List["PGresult"]) -> None:
         self._check_results(results)
@@ -896,6 +901,10 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         except e.Error as ex:
             raise ex.with_traceback(None)
 
+        # If a fresher result has been set on the cursor by the Copy object,
+        # read its properties (especially rowcount).
+        self._select_current_result(0)
+
     def _fetch_pipeline(self) -> None:
         if (
             self._execmany_returning is not False
index 4b4844d40f4e0a64f6dbf9060d249a7b98e18154..0632940d6ca41cf9caeeb4bc276777975086f4a7 100644 (file)
@@ -222,6 +222,8 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         except e.Error as ex:
             raise ex.with_traceback(None)
 
+        self._select_current_result(0)
+
     async def _fetch_pipeline(self) -> None:
         if (
             self._execmany_returning is not False