From: Daniele Varrazzo Date: Fri, 12 Feb 2021 01:44:06 +0000 (+0100) Subject: Fix setting row maker on nextset() X-Git-Tag: 3.0.dev0~106^2~17 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7d341d0e363cf3e38f646468cf3c26520c22af97;p=thirdparty%2Fpsycopg.git Fix setting row maker on nextset() The test was broken and didn't test that the row_factory function was really called. --- diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index db87c83d2..e281cb669 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -177,6 +177,8 @@ class BaseCursor(Generic[ConnectionType]): if self._iresult < len(self._results): self.pgresult = self._results[self._iresult] self._tx.set_pgresult(self._results[self._iresult]) + if self._row_factory: + self._tx.make_row = self._row_factory(self) self._pos = 0 nrows = self.pgresult.command_tuples self._rowcount = nrows if nrows is not None else -1 diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 7ff25f0a3..e6e39ccd5 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -287,18 +287,15 @@ def test_iter_stop(conn): def test_row_factory(conn): - def my_row_factory(cur): - return lambda values: [-v for v in values] - cur = conn.cursor(row_factory=my_row_factory) - cur.execute("select generate_series(1, 3)") - r = cur.fetchall() - assert r == [[-1], [-2], [-3]] + cur.execute("select 'foo' as bar") + (r,) = cur.fetchone() + assert r == "FOObar" - cur.execute("select 42; select generate_series(1,3)") - assert cur.fetchall() == [[-42]] + cur.execute("select 'x' as x; select 'y' as y, 'z' as z") + assert cur.fetchall() == [["Xx"]] assert cur.nextset() - assert cur.fetchall() == [[-1], [-2], [-3]] + assert cur.fetchall() == [["Yy", "Zz"]] assert cur.nextset() is None @@ -569,3 +566,15 @@ def test_leak(dsn, faker, fmt, fetch): assert ( n[0] == n[1] == n[2] ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}" + + +def my_row_factory(cursor): + assert cursor.description is not None + titles = [c.name for c in cursor.description] + + def mkrow(values): + return [ + f"{value.upper()}{title}" for title, value in zip(titles, values) + ] + + return mkrow diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index e56777323..88835b1c7 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -6,6 +6,7 @@ import datetime as dt import psycopg3 from psycopg3 import sql from psycopg3.adapt import Format +from .test_cursor import my_row_factory pytestmark = pytest.mark.asyncio @@ -292,17 +293,6 @@ async def test_iter_stop(aconn): async def test_row_factory(aconn): - def my_row_factory(cursor): - def mkrow(values): - assert cursor.description is not None - titles = [c.name for c in cursor.description] - return [ - f"{value.upper()}{title}" - for title, value in zip(titles, values) - ] - - return mkrow - cur = aconn.cursor(row_factory=my_row_factory) await cur.execute("select 'foo' as bar") (r,) = await cur.fetchone()