From: Daniele Varrazzo Date: Mon, 19 Dec 2022 22:46:42 +0000 (+0000) Subject: fix: consider cursor description results in row factories X-Git-Tag: pool-3.1.5~4^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ea643991d041860a07be4472d6d473eb9a11b7fc;p=thirdparty%2Fpsycopg.git fix: consider cursor description results in row factories Fix #464 --- diff --git a/docs/news.rst b/docs/news.rst index 5d7837193..e4fd81ec2 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -7,6 +7,15 @@ ``psycopg`` release notes ========================= +Future releases +--------------- + +Psycopg 3.1.7 (unreleased) +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- Fix server-side cursors using row factories (:ticket:`#464`). + + Current release --------------- diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py index 4f96a1af0..cb28b57ac 100644 --- a/psycopg/psycopg/rows.py +++ b/psycopg/psycopg/rows.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: from .cursor_async import AsyncCursor from psycopg.pq.abc import PGresult +COMMAND_OK = pq.ExecStatus.COMMAND_OK TUPLES_OK = pq.ExecStatus.TUPLES_OK SINGLE_TUPLE = pq.ExecStatus.SINGLE_TUPLE @@ -244,7 +245,12 @@ def _get_nfields(res: "PGresult") -> Optional[int]: """ nfields = res.nfields - if res.status == TUPLES_OK or res.status == SINGLE_TUPLE: + if ( + res.status == TUPLES_OK + or res.status == SINGLE_TUPLE + # "describe" in named cursors + or (res.status == COMMAND_OK and nfields) + ): return nfields else: return None diff --git a/tests/test_server_cursor.py b/tests/test_server_cursor.py index e1f8ce392..f7b6c8ed6 100644 --- a/tests/test_server_cursor.py +++ b/tests/test_server_cursor.py @@ -331,6 +331,25 @@ def test_no_result(conn): assert cur.fetchall() == [] +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +def test_standard_row_factory(conn, row_factory): + if row_factory == "tuple_row": + getter = lambda r: r[0] # noqa: E731 + elif row_factory == "dict_row": + getter = lambda r: r["bar"] # noqa: E731 + elif row_factory == "namedtuple_row": + getter = lambda r: r.bar # noqa: E731 + else: + assert False, row_factory + + row_factory = getattr(rows, row_factory) + with conn.cursor("foo", row_factory=row_factory) as cur: + cur.execute("select generate_series(1, 5) as bar") + assert getter(cur.fetchone()) == 1 + assert list(map(getter, cur.fetchmany(2))) == [2, 3] + assert list(map(getter, cur.fetchall())) == [4, 5] + + @pytest.mark.crdb_skip("scroll cursor") def test_row_factory(conn): n = 0 @@ -479,13 +498,17 @@ def test_hold(conn): assert curs.fetchone() == (1,) -def test_steal_cursor(conn): +@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"]) +def test_steal_cursor(conn, row_factory): cur1 = conn.cursor() - cur1.execute("declare test cursor for select generate_series(1, 6)") + cur1.execute("declare test cursor for select generate_series(1, 6) as s") - cur2 = conn.cursor("test") + cur2 = conn.cursor("test", row_factory=getattr(rows, row_factory)) # can call fetch without execute - assert cur2.fetchone() == (1,) + rec = cur2.fetchone() + assert rec == (1,) + if row_factory == "namedtuple_row": + assert rec.s == 1 assert cur2.fetchmany(3) == [(2,), (3,), (4,)] assert cur2.fetchall() == [(5,), (6,)] cur2.close() diff --git a/tests/test_server_cursor_async.py b/tests/test_server_cursor_async.py index 0a795bb82..21b434503 100644 --- a/tests/test_server_cursor_async.py +++ b/tests/test_server_cursor_async.py @@ -342,6 +342,25 @@ async def test_no_result(aconn): assert (await cur.fetchall()) == [] +@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"]) +async def test_standard_row_factory(aconn, row_factory): + if row_factory == "tuple_row": + getter = lambda r: r[0] # noqa: E731 + elif row_factory == "dict_row": + getter = lambda r: r["bar"] # noqa: E731 + elif row_factory == "namedtuple_row": + getter = lambda r: r.bar # noqa: E731 + else: + assert False, row_factory + + row_factory = getattr(rows, row_factory) + async with aconn.cursor("foo", row_factory=row_factory) as cur: + await cur.execute("select generate_series(1, 5) as bar") + assert getter(await cur.fetchone()) == 1 + assert list(map(getter, await cur.fetchmany(2))) == [2, 3] + assert list(map(getter, await cur.fetchall())) == [4, 5] + + @pytest.mark.crdb_skip("scroll cursor") async def test_row_factory(aconn): n = 0 @@ -495,15 +514,19 @@ async def test_hold(aconn): assert await curs.fetchone() == (1,) -async def test_steal_cursor(aconn): +@pytest.mark.parametrize("row_factory", ["tuple_row", "namedtuple_row"]) +async def test_steal_cursor(aconn, row_factory): cur1 = aconn.cursor() await cur1.execute( - "declare test cursor without hold for select generate_series(1, 6)" + "declare test cursor without hold for select generate_series(1, 6) as s" ) - cur2 = aconn.cursor("test") + cur2 = aconn.cursor("test", row_factory=getattr(rows, row_factory)) # can call fetch without execute - assert await cur2.fetchone() == (1,) + rec = await cur2.fetchone() + assert rec == (1,) + if row_factory == "namedtuple_row": + assert rec.s == 1 assert await cur2.fetchmany(3) == [(2,), (3,), (4,)] assert await cur2.fetchall() == [(5,), (6,)] await cur2.close()