]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Set up row maker and loaders only once in a server-side cursor lifetime
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Feb 2021 02:21:35 +0000 (03:21 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Feb 2021 02:22:23 +0000 (03:22 +0100)
It wasn't happening once per movement, as I was fearing, but it was happening
exactly twice: once on DECLARE, once on describe_portal(). We actually
don't care about the DECLARE result: it was being set on the cursor only
to detect errors, so now that's done manually.

psycopg3/psycopg3/server_cursor.py
tests/test_server_cursor.py
tests/test_server_cursor_async.py

index f1386047f1157a9de8fb47a4bc8911c7e5f72ab3..851839ec79a1b4e8111fb948bdc702b7c22ae1e1 100644 (file)
@@ -61,11 +61,12 @@ class ServerCursorHelper(Generic[ConnectionType]):
 
         yield from cur._start_query(query)
         pgq = cur._convert_query(query, params)
-        cur._execute_send(pgq)
+        cur._execute_send(pgq, no_pqexec=True)
         results = yield from execute(conn.pgconn)
-        cur._execute_results(results)
+        if results[-1].status != pq.ExecStatus.COMMAND_OK:
+            cur._raise_from_results(results)
 
-        # The above result is an COMMAND_OK. Get the cursor result shape
+        # The above result only returned COMMAND_OK. Get the cursor shape
         yield from self._describe_gen(cur)
 
     def _describe_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
index 6e908a4078872757a7124a3d16199d9b7c074ab4..a4e6c012431c6389053aa9dcdf54194dc576a86d 100644 (file)
@@ -151,13 +151,23 @@ def test_nextset(conn):
 
 
 def test_row_factory(conn):
+    n = 0
+
     def my_row_factory(cur):
-        return lambda values: [-v for v in values]
+        nonlocal n
+        n += 1
+        return lambda values: [n] + [-v for v in values]
 
     cur = conn.cursor("foo", row_factory=my_row_factory)
-    cur.execute("select generate_series(1, 3)")
-    r = cur.fetchall()
-    assert r == [[-1], [-2], [-3]]
+    cur.execute("select generate_series(1, 3)", scrollable=True)
+    rows = cur.fetchall()
+    cur.scroll(0, "absolute")
+    while 1:
+        row = cur.fetchone()
+        if not row:
+            break
+        rows.append(row)
+    assert rows == [[1, -1], [1, -2], [1, -3]] * 2
 
 
 def test_rownumber(conn):
index 5625cd885ea7538fad43139732524f2a620d964a..191d3dd7f4de4fb3551f67df0a1886e44772130e 100644 (file)
@@ -153,13 +153,23 @@ async def test_nextset(aconn):
 
 
 async def test_row_factory(aconn):
+    n = 0
+
     def my_row_factory(cur):
-        return lambda values: [-v for v in values]
+        nonlocal n
+        n += 1
+        return lambda values: [n] + [-v for v in values]
 
     cur = aconn.cursor("foo", row_factory=my_row_factory)
-    await cur.execute("select generate_series(1, 3)")
-    r = await cur.fetchall()
-    assert r == [[-1], [-2], [-3]]
+    await cur.execute("select generate_series(1, 3)", scrollable=True)
+    rows = await cur.fetchall()
+    await cur.scroll(0, "absolute")
+    while 1:
+        row = await cur.fetchone()
+        if not row:
+            break
+        rows.append(row)
+    assert rows == [[1, -1], [1, -2], [1, -3]] * 2
 
 
 async def test_rownumber(aconn):