]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make cursor.row_factory writable
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 24 Feb 2021 15:14:13 +0000 (16:14 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 24 Feb 2021 15:14:13 +0000 (16:14 +0100)
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/server_cursor.py
tests/test_cursor.py
tests/test_cursor_async.py
tests/test_server_cursor.py
tests/test_server_cursor_async.py

index dd6e2661b5201f71f0d9ad3bcba896fede64149e..9d2b75b2932600781225ebdfedf7d80d51afd2fd 100644 (file)
@@ -51,7 +51,7 @@ class BaseCursor(Generic[ConnectionType]):
     # https://bugs.python.org/issue41451
     if sys.version_info >= (3, 7):
         __slots__ = """
-            _conn format _adapters arraysize _closed _results _pgresult _pos
+            _conn format _adapters arraysize _closed _results pgresult _pos
             _iresult _rowcount _pgq _tx _last_query _row_factory
             __weakref__
             """.split()
@@ -78,7 +78,8 @@ class BaseCursor(Generic[ConnectionType]):
 
     def _reset(self) -> None:
         self._results: List["PGresult"] = []
-        self._pgresult: Optional["PGresult"] = None
+        self.pgresult: Optional["PGresult"] = None
+        """The `~psycopg3.pq.PGresult` exposed by the cursor."""
         self._pos = 0
         self._iresult = 0
         self._rowcount = -1
@@ -89,10 +90,10 @@ class BaseCursor(Generic[ConnectionType]):
         info = pq.misc.connection_summary(self._conn.pgconn)
         if self._closed:
             status = "closed"
-        elif not self._pgresult:
-            status = "no result"
+        elif self.pgresult:
+            status = pq.ExecStatus(self.pgresult.status).name
         else:
-            status = pq.ExecStatus(self._pgresult.status).name
+            status = "no result"
         return f"<{cls} [{status}] {info} at 0x{id(self):x}>"
 
     @property
@@ -125,15 +126,6 @@ class BaseCursor(Generic[ConnectionType]):
         """The last set of parameters sent to the server, if available."""
         return self._pgq.params if self._pgq else None
 
-    @property
-    def pgresult(self) -> Optional["PGresult"]:
-        """The `~psycopg3.pq.PGresult` exposed by the cursor."""
-        return self._pgresult
-
-    @pgresult.setter
-    def pgresult(self, result: Optional["PGresult"]) -> None:
-        self._pgresult = result
-
     @property
     def description(self) -> Optional[List[Column]]:
         """
@@ -157,7 +149,7 @@ class BaseCursor(Generic[ConnectionType]):
 
         `!None` if there is no result to fetch.
         """
-        return self._pos if self._pgresult else None
+        return self._pos if self.pgresult else None
 
     def setinputsizes(self, sizes: Sequence[Any]) -> None:
         # no-op
@@ -186,6 +178,16 @@ class BaseCursor(Generic[ConnectionType]):
         else:
             return None
 
+    @property
+    def row_factory(self) -> RowFactory:
+        return self._row_factory
+
+    @row_factory.setter
+    def row_factory(self, row_factory: RowFactory) -> None:
+        self._row_factory = row_factory
+        if self.pgresult:
+            self._tx.make_row = row_factory(self)
+
     #
     # Generators for the high level operations on the cursor
     #
index 6ea9b8819750004f5604efc1b7984015dc5431d9..1ce78a5cf1f86979b7663ac7693f15954ead1119 100644 (file)
@@ -40,10 +40,10 @@ class ServerCursorHelper(Generic[ConnectionType]):
         info = pq.misc.connection_summary(cur._conn.pgconn)
         if cur._closed:
             status = "closed"
-        elif not cur._pgresult:
+        elif not cur.pgresult:
             status = "no result"
         else:
-            status = pq.ExecStatus(cur._pgresult.status).name
+            status = pq.ExecStatus(cur.pgresult.status).name
         return f"<{cls} {self.name!r} [{status}] {info} at 0x{id(cur):x}>"
 
     def _declare_gen(
index e6e39ccd539668b7dc4fe5a6a8286d75c2d5ee55..60924c32c6274f951dbb2d5817da461664e543e9 100644 (file)
@@ -8,6 +8,7 @@ import pytest
 import psycopg3
 from psycopg3 import sql
 from psycopg3.oids import postgres_types as builtins
+from psycopg3.rows import dict_row
 from psycopg3.adapt import Format
 
 
@@ -296,7 +297,10 @@ def test_row_factory(conn):
     assert cur.fetchall() == [["Xx"]]
     assert cur.nextset()
     assert cur.fetchall() == [["Yy", "Zz"]]
-    assert cur.nextset() is None
+
+    cur.scroll(-1)
+    cur.row_factory = dict_row
+    assert cur.fetchone() == {"y": "y", "z": "z"}
 
 
 def test_scroll(conn):
index 88835b1c734adbec462f351ec03f3dfe4b8c2922..7cf689a40da20aad6c0516cca68a6db795dd432f 100644 (file)
@@ -5,6 +5,7 @@ import datetime as dt
 
 import psycopg3
 from psycopg3 import sql
+from psycopg3.rows import dict_row
 from psycopg3.adapt import Format
 from .test_cursor import my_row_factory
 
@@ -302,7 +303,10 @@ async def test_row_factory(aconn):
     assert await cur.fetchall() == [["Xx"]]
     assert cur.nextset()
     assert await cur.fetchall() == [["Yy", "Zz"]]
-    assert cur.nextset() is None
+
+    await cur.scroll(-1)
+    cur.row_factory = dict_row
+    assert await cur.fetchone() == {"y": "y", "z": "z"}
 
 
 async def test_scroll(aconn):
index a4e6c012431c6389053aa9dcdf54194dc576a86d..d1d73eab56d1dc7ab3cc4d540b298ac5e7df47db 100644 (file)
@@ -2,6 +2,7 @@ import pytest
 
 from psycopg3 import errors as e
 from psycopg3.pq import Format
+from psycopg3.rows import dict_row
 
 
 def test_funny_name(conn):
@@ -159,7 +160,7 @@ def test_row_factory(conn):
         return lambda values: [n] + [-v for v in values]
 
     cur = conn.cursor("foo", row_factory=my_row_factory)
-    cur.execute("select generate_series(1, 3)", scrollable=True)
+    cur.execute("select generate_series(1, 3) as x", scrollable=True)
     rows = cur.fetchall()
     cur.scroll(0, "absolute")
     while 1:
@@ -169,6 +170,10 @@ def test_row_factory(conn):
         rows.append(row)
     assert rows == [[1, -1], [1, -2], [1, -3]] * 2
 
+    cur.scroll(0, "absolute")
+    cur.row_factory = dict_row
+    assert cur.fetchone() == {"x": 1}
+
 
 def test_rownumber(conn):
     cur = conn.cursor("foo")
index 191d3dd7f4de4fb3551f67df0a1886e44772130e..eed0c791fb2e68e34c4881c898f41216eff7537c 100644 (file)
@@ -1,6 +1,7 @@
 import pytest
 
 from psycopg3 import errors as e
+from psycopg3.rows import dict_row
 from psycopg3.pq import Format
 
 pytestmark = pytest.mark.asyncio
@@ -161,7 +162,7 @@ async def test_row_factory(aconn):
         return lambda values: [n] + [-v for v in values]
 
     cur = aconn.cursor("foo", row_factory=my_row_factory)
-    await cur.execute("select generate_series(1, 3)", scrollable=True)
+    await cur.execute("select generate_series(1, 3) as x", scrollable=True)
     rows = await cur.fetchall()
     await cur.scroll(0, "absolute")
     while 1:
@@ -171,6 +172,10 @@ async def test_row_factory(aconn):
         rows.append(row)
     assert rows == [[1, -1], [1, -2], [1, -3]] * 2
 
+    await cur.scroll(0, "absolute")
+    cur.row_factory = dict_row
+    assert await cur.fetchone() == {"x": 1}
+
 
 async def test_rownumber(aconn):
     cur = aconn.cursor("foo")