]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: consider cursor description results in row factories
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 19 Dec 2022 22:46:42 +0000 (22:46 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 19 Dec 2022 22:56:02 +0000 (22:56 +0000)
Fix #464

docs/news.rst
psycopg/psycopg/rows.py
tests/test_server_cursor.py
tests/test_server_cursor_async.py

index 5d7837193562a9a674aaafefb1db5f2b8c97ee5a..e4fd81ec21f1fead40bef14b77650cccc02dabf7 100644 (file)
@@ -7,6 +7,15 @@
 ``psycopg`` release notes
 =========================
 
+Future releases
+---------------
+
+Psycopg 3.1.7 (unreleased)
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+- Fix server-side cursors using row factories (:ticket:`#464`).
+
+
 Current release
 ---------------
 
index 4f96a1af00ad37e6cf56db135731e1880ecb4144..cb28b57ac9d9e131940488136a15d8c0d5bff5fa 100644 (file)
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
     from .cursor_async import AsyncCursor
     from psycopg.pq.abc import PGresult
 
+COMMAND_OK = pq.ExecStatus.COMMAND_OK
 TUPLES_OK = pq.ExecStatus.TUPLES_OK
 SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE
 
@@ -244,7 +245,12 @@ def _get_nfields(res: "PGresult") -> Optional[int]:
     """
     nfields = res.nfields
 
-    if res.status == TUPLES_OK or res.status == SINGLE_TUPLE:
+    if (
+        res.status == TUPLES_OK
+        or res.status == SINGLE_TUPLE
+        # "describe" in named cursors
+        or (res.status == COMMAND_OK and nfields)
+    ):
         return nfields
     else:
         return None
index e1f8ce392077bc27de46722a6a2ca043f13c92a5..f7b6c8ed63cdb093ee3cdc17cdc33cb82da1a2bb 100644 (file)
@@ -331,6 +331,25 @@ def test_no_result(conn):
         assert cur.fetchall() == []
 
 
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+def test_standard_row_factory(conn, row_factory):
+    if row_factory == "tuple_row":
+        getter = lambda r: r[0]  # noqa: E731
+    elif row_factory == "dict_row":
+        getter = lambda r: r["bar"]  # noqa: E731
+    elif row_factory == "namedtuple_row":
+        getter = lambda r: r.bar  # noqa: E731
+    else:
+        assert False, row_factory
+
+    row_factory = getattr(rows, row_factory)
+    with conn.cursor("foo", row_factory=row_factory) as cur:
+        cur.execute("select generate_series(1, 5) as bar")
+        assert getter(cur.fetchone()) == 1
+        assert list(map(getter, cur.fetchmany(2))) == [2, 3]
+        assert list(map(getter, cur.fetchall())) == [4, 5]
+
+
 @pytest.mark.crdb_skip("scroll cursor")
 def test_row_factory(conn):
     n = 0
@@ -479,13 +498,17 @@ def test_hold(conn):
         assert curs.fetchone() == (1,)
 
 
-def test_steal_cursor(conn):
+@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"])
+def test_steal_cursor(conn, row_factory):
     cur1 = conn.cursor()
-    cur1.execute("declare test cursor for select generate_series(1, 6)")
+    cur1.execute("declare test cursor for select generate_series(1, 6) as s")
 
-    cur2 = conn.cursor("test")
+    cur2 = conn.cursor("test", row_factory=getattr(rows, row_factory))
     # can call fetch without execute
-    assert cur2.fetchone() == (1,)
+    rec = cur2.fetchone()
+    assert rec == (1,)
+    if row_factory == "namedtuple_row":
+        assert rec.s == 1
     assert cur2.fetchmany(3) == [(2,), (3,), (4,)]
     assert cur2.fetchall() == [(5,), (6,)]
     cur2.close()
index 0a795bb8253ea076b2a3066d8c1234ba27c2ef91..21b434503d9ecdda550c7dae521ef7ad9f1004f1 100644 (file)
@@ -342,6 +342,25 @@ async def test_no_result(aconn):
         assert (await cur.fetchall()) == []
 
 
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
+async def test_standard_row_factory(aconn, row_factory):
+    if row_factory == "tuple_row":
+        getter = lambda r: r[0]  # noqa: E731
+    elif row_factory == "dict_row":
+        getter = lambda r: r["bar"]  # noqa: E731
+    elif row_factory == "namedtuple_row":
+        getter = lambda r: r.bar  # noqa: E731
+    else:
+        assert False, row_factory
+
+    row_factory = getattr(rows, row_factory)
+    async with aconn.cursor("foo", row_factory=row_factory) as cur:
+        await cur.execute("select generate_series(1, 5) as bar")
+        assert getter(await cur.fetchone()) == 1
+        assert list(map(getter, await cur.fetchmany(2))) == [2, 3]
+        assert list(map(getter, await cur.fetchall())) == [4, 5]
+
+
 @pytest.mark.crdb_skip("scroll cursor")
 async def test_row_factory(aconn):
     n = 0
@@ -495,15 +514,19 @@ async def test_hold(aconn):
         assert await curs.fetchone() == (1,)
 
 
-async def test_steal_cursor(aconn):
+@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"])
+async def test_steal_cursor(aconn, row_factory):
     cur1 = aconn.cursor()
     await cur1.execute(
-        "declare test cursor without hold for select generate_series(1, 6)"
+        "declare test cursor without hold for select generate_series(1, 6) as s"
     )
 
-    cur2 = aconn.cursor("test")
+    cur2 = aconn.cursor("test", row_factory=getattr(rows, row_factory))
     # can call fetch without execute
-    assert await cur2.fetchone() == (1,)
+    rec = await cur2.fetchone()
+    assert rec == (1,)
+    if row_factory == "namedtuple_row":
+        assert rec.s == 1
     assert await cur2.fetchmany(3) == [(2,), (3,), (4,)]
     assert await cur2.fetchall() == [(5,), (6,)]
     await cur2.close()