From: Daniele Varrazzo Date: Wed, 24 Feb 2021 15:14:13 +0000 (+0100) Subject: Make cursor.row_factory writable X-Git-Tag: 3.0.dev0~106^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=721a0d1010e8e352805dc3c562520ba96005d548;p=thirdparty%2Fpsycopg.git Make cursor.row_factory writable --- diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index dd6e2661b..9d2b75b29 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -51,7 +51,7 @@ class BaseCursor(Generic[ConnectionType]): # https://bugs.python.org/issue41451 if sys.version_info >= (3, 7): __slots__ = """ - _conn format _adapters arraysize _closed _results _pgresult _pos + _conn format _adapters arraysize _closed _results pgresult _pos _iresult _rowcount _pgq _tx _last_query _row_factory __weakref__ """.split() @@ -78,7 +78,8 @@ class BaseCursor(Generic[ConnectionType]): def _reset(self) -> None: self._results: List["PGresult"] = [] - self._pgresult: Optional["PGresult"] = None + self.pgresult: Optional["PGresult"] = None + """The `~psycopg3.pq.PGresult` exposed by the cursor.""" self._pos = 0 self._iresult = 0 self._rowcount = -1 @@ -89,10 +90,10 @@ class BaseCursor(Generic[ConnectionType]): info = pq.misc.connection_summary(self._conn.pgconn) if self._closed: status = "closed" - elif not self._pgresult: - status = "no result" + elif self.pgresult: + status = pq.ExecStatus(self.pgresult.status).name else: - status = pq.ExecStatus(self._pgresult.status).name + status = "no result" return f"<{cls} [{status}] {info} at 0x{id(self):x}>" @property @@ -125,15 +126,6 @@ class BaseCursor(Generic[ConnectionType]): """The last set of parameters sent to the server, if available.""" return self._pgq.params if self._pgq else None - @property - def pgresult(self) -> Optional["PGresult"]: - """The `~psycopg3.pq.PGresult` exposed by the cursor.""" - return self._pgresult - - @pgresult.setter - def pgresult(self, result: Optional["PGresult"]) -> None: - self._pgresult = result - @property def description(self) -> Optional[List[Column]]: """ @@ -157,7 +149,7 @@ class BaseCursor(Generic[ConnectionType]): `!None` if there is no result to fetch. """ - return self._pos if self._pgresult else None + return self._pos if self.pgresult else None def setinputsizes(self, sizes: Sequence[Any]) -> None: # no-op @@ -186,6 +178,16 @@ class BaseCursor(Generic[ConnectionType]): else: return None + @property + def row_factory(self) -> RowFactory: + return self._row_factory + + @row_factory.setter + def row_factory(self, row_factory: RowFactory) -> None: + self._row_factory = row_factory + if self.pgresult: + self._tx.make_row = row_factory(self) + # # Generators for the high level operations on the cursor # diff --git a/psycopg3/psycopg3/server_cursor.py b/psycopg3/psycopg3/server_cursor.py index 6ea9b8819..1ce78a5cf 100644 --- a/psycopg3/psycopg3/server_cursor.py +++ b/psycopg3/psycopg3/server_cursor.py @@ -40,10 +40,10 @@ class ServerCursorHelper(Generic[ConnectionType]): info = pq.misc.connection_summary(cur._conn.pgconn) if cur._closed: status = "closed" - elif not cur._pgresult: + elif not cur.pgresult: status = "no result" else: - status = pq.ExecStatus(cur._pgresult.status).name + status = pq.ExecStatus(cur.pgresult.status).name return f"<{cls} {self.name!r} [{status}] {info} at 0x{id(cur):x}>" def _declare_gen( diff --git a/tests/test_cursor.py b/tests/test_cursor.py index e6e39ccd5..60924c32c 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -8,6 +8,7 @@ import pytest import psycopg3 from psycopg3 import sql from psycopg3.oids import postgres_types as builtins +from psycopg3.rows import dict_row from psycopg3.adapt import Format @@ -296,7 +297,10 @@ def test_row_factory(conn): assert cur.fetchall() == [["Xx"]] assert cur.nextset() assert cur.fetchall() == [["Yy", "Zz"]] - assert cur.nextset() is None + + cur.scroll(-1) + cur.row_factory = dict_row + assert cur.fetchone() == {"y": "y", "z": "z"} def test_scroll(conn): diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 88835b1c7..7cf689a40 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -5,6 +5,7 @@ import datetime as dt import psycopg3 from psycopg3 import sql +from psycopg3.rows import dict_row from psycopg3.adapt import Format from .test_cursor import my_row_factory @@ -302,7 +303,10 @@ async def test_row_factory(aconn): assert await cur.fetchall() == [["Xx"]] assert cur.nextset() assert await cur.fetchall() == [["Yy", "Zz"]] - assert cur.nextset() is None + + await cur.scroll(-1) + cur.row_factory = dict_row + assert await cur.fetchone() == {"y": "y", "z": "z"} async def test_scroll(aconn): diff --git a/tests/test_server_cursor.py b/tests/test_server_cursor.py index a4e6c0124..d1d73eab5 100644 --- a/tests/test_server_cursor.py +++ b/tests/test_server_cursor.py @@ -2,6 +2,7 @@ import pytest from psycopg3 import errors as e from psycopg3.pq import Format +from psycopg3.rows import dict_row def test_funny_name(conn): @@ -159,7 +160,7 @@ def test_row_factory(conn): return lambda values: [n] + [-v for v in values] cur = conn.cursor("foo", row_factory=my_row_factory) - cur.execute("select generate_series(1, 3)", scrollable=True) + cur.execute("select generate_series(1, 3) as x", scrollable=True) rows = cur.fetchall() cur.scroll(0, "absolute") while 1: @@ -169,6 +170,10 @@ def test_row_factory(conn): rows.append(row) assert rows == [[1, -1], [1, -2], [1, -3]] * 2 + cur.scroll(0, "absolute") + cur.row_factory = dict_row + assert cur.fetchone() == {"x": 1} + def test_rownumber(conn): cur = conn.cursor("foo") diff --git a/tests/test_server_cursor_async.py b/tests/test_server_cursor_async.py index 191d3dd7f..eed0c791f 100644 --- a/tests/test_server_cursor_async.py +++ b/tests/test_server_cursor_async.py @@ -1,6 +1,7 @@ import pytest from psycopg3 import errors as e +from psycopg3.rows import dict_row from psycopg3.pq import Format pytestmark = pytest.mark.asyncio @@ -161,7 +162,7 @@ async def test_row_factory(aconn): return lambda values: [n] + [-v for v in values] cur = aconn.cursor("foo", row_factory=my_row_factory) - await cur.execute("select generate_series(1, 3)", scrollable=True) + await cur.execute("select generate_series(1, 3) as x", scrollable=True) rows = await cur.fetchall() await cur.scroll(0, "absolute") while 1: @@ -171,6 +172,10 @@ async def test_row_factory(aconn): rows.append(row) assert rows == [[1, -1], [1, -2], [1, -3]] * 2 + await cur.scroll(0, "absolute") + cur.row_factory = dict_row + assert await cur.fetchone() == {"x": 1} + async def test_rownumber(aconn): cur = aconn.cursor("foo")