]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Introduce row_factory option in connection.cursor()
authorDenis Laxalde <denis.laxalde@dalibo.com>
Tue, 9 Feb 2021 15:51:11 +0000 (16:51 +0100)
committerDenis Laxalde <denis.laxalde@dalibo.com>
Wed, 10 Feb 2021 09:36:37 +0000 (10:36 +0100)
We add a row_factory keyword argument in connection.cursor() and cursor
classes that will be used to produce individual rows of the result set.

A RowFactory can be implemented as a class with a __call__ method
accepting raw values and initialized with a cursor instance; the
RowFactory instance is created when results are available. Type
definitions for RowFactory (and its respective RowMaker) are defined as
callback protocols so as to allow user to define a row factory without
the need for writing a class.

The default row factory returns values unchanged.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/proto.py
tests/test_cursor.py
tests/test_cursor_async.py

index 0727d415c94f298dbcd603413c61622cdde8ab7c..79842ae4a59f5e3081a8d98a49c8b35824cad95e 100644 (file)
@@ -29,8 +29,8 @@ from . import waiting
 from . import encodings
 from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
 from .sql import Composable
-from .proto import PQGen, PQGenConn, RV, Query, Params, AdaptContext
-from .proto import ConnectionType
+from .proto import PQGen, PQGenConn, RV, RowFactory, Query, Params
+from .proto import AdaptContext, ConnectionType
 from .conninfo import make_conninfo
 from .generators import notifies
 from .transaction import Transaction, AsyncTransaction
@@ -448,7 +448,12 @@ class Connection(BaseConnection):
         """Close the database connection."""
         self.pgconn.finish()
 
-    def cursor(self, name: str = "", binary: bool = False) -> "Cursor":
+    def cursor(
+        self,
+        name: str = "",
+        binary: bool = False,
+        row_factory: RowFactory = cursor.default_row_factory,
+    ) -> "Cursor":
         """
         Return a new `Cursor` to send commands and queries to the connection.
         """
@@ -456,7 +461,7 @@ class Connection(BaseConnection):
             raise NotImplementedError
 
         format = Format.BINARY if binary else Format.TEXT
-        return self.cursor_factory(self, format=format)
+        return self.cursor_factory(self, row_factory, format=format)
 
     def execute(
         self,
@@ -584,7 +589,10 @@ class AsyncConnection(BaseConnection):
         self.pgconn.finish()
 
     async def cursor(
-        self, name: str = "", binary: bool = False
+        self,
+        name: str = "",
+        binary: bool = False,
+        row_factory: RowFactory = cursor.default_row_factory,
     ) -> "AsyncCursor":
         """
         Return a new `AsyncCursor` to send commands and queries to the connection.
@@ -593,7 +601,7 @@ class AsyncConnection(BaseConnection):
             raise NotImplementedError
 
         format = Format.BINARY if binary else Format.TEXT
-        return self.cursor_factory(self, format=format)
+        return self.cursor_factory(self, row_factory, format=format)
 
     async def execute(
         self,
index abc5d2cd4d51040401fb8ef26131458c0aba1d5d..f7bdd7c20968a07e8bcc48b3f626d940eb99049c 100644 (file)
@@ -18,6 +18,7 @@ from . import generators
 from .pq import ExecStatus, Format
 from .copy import Copy, AsyncCopy
 from .proto import ConnectionType, Query, Params, PQGen
+from .proto import Row, RowFactory, RowMaker
 from ._column import Column
 from ._queries import PostgresQuery
 from ._preparing import Prepare
@@ -43,13 +44,17 @@ else:
     execute = generators.execute
 
 
+def default_row_factory(cursor: Any) -> RowMaker:
+    return lambda values: values
+
+
 class BaseCursor(Generic[ConnectionType]):
     # Slots with __weakref__ and generic bases don't work on Py 3.6
     # https://bugs.python.org/issue41451
     if sys.version_info >= (3, 7):
         __slots__ = """
             _conn format _adapters arraysize _closed _results _pgresult _pos
-            _iresult _rowcount _pgq _tx _last_query
+            _iresult _rowcount _pgq _tx _last_query _row_factory _make_row
             __weakref__
             """.split()
 
@@ -60,11 +65,13 @@ class BaseCursor(Generic[ConnectionType]):
     def __init__(
         self,
         connection: ConnectionType,
+        row_factory: RowFactory,
         format: Format = Format.TEXT,
     ):
         self._conn = connection
         self.format = format
         self._adapters = adapt.AdaptersMap(connection.adapters)
+        self._row_factory = row_factory
         self.arraysize = 1
         self._closed = False
         self._last_query: Optional[Query] = None
@@ -73,6 +80,7 @@ class BaseCursor(Generic[ConnectionType]):
     def _reset(self) -> None:
         self._results: List["PGresult"] = []
         self._pgresult: Optional["PGresult"] = None
+        self._make_row: Optional[RowMaker] = None
         self._pos = 0
         self._iresult = 0
         self._rowcount = -1
@@ -261,6 +269,7 @@ class BaseCursor(Generic[ConnectionType]):
             return None
 
         elif res.status == ExecStatus.SINGLE_TUPLE:
+            self._make_row = self._row_factory(self)
             self.pgresult = res  # will set it on the transformer too
             # TODO: the transformer may do excessive work here: create a
             # path that doesn't clear the loaders every time.
@@ -364,6 +373,7 @@ class BaseCursor(Generic[ConnectionType]):
 
         self._results = list(results)
         self.pgresult = results[0]
+        self._make_row = self._row_factory(self)
         nrows = self.pgresult.command_tuples
         if nrows is not None:
             if self._rowcount < 0:
@@ -478,7 +488,7 @@ class Cursor(BaseCursor["Connection"]):
 
     def stream(
         self, query: Query, params: Optional[Params] = None
-    ) -> Iterator[Sequence[Any]]:
+    ) -> Iterator[Row]:
         """
         Iterate row-by-row on a result from the database.
         """
@@ -487,9 +497,10 @@ class Cursor(BaseCursor["Connection"]):
             while self._conn.wait(self._stream_fetchone_gen()):
                 rec = self._tx.load_row(0)
                 assert rec is not None
-                yield rec
+                assert self._make_row is not None
+                yield self._make_row(rec)
 
-    def fetchone(self) -> Optional[Sequence[Any]]:
+    def fetchone(self) -> Optional[Row]:
         """
         Return the next record from the current recordset.
 
@@ -499,9 +510,11 @@ class Cursor(BaseCursor["Connection"]):
         record = self._tx.load_row(self._pos)
         if record is not None:
             self._pos += 1
+            assert self._make_row is not None
+            return self._make_row(record)
         return record
 
-    def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
+    def fetchmany(self, size: int = 0) -> Sequence[Row]:
         """
         Return the next *size* records from the current recordset.
 
@@ -516,9 +529,10 @@ class Cursor(BaseCursor["Connection"]):
             self._pos, min(self._pos + size, self.pgresult.ntuples)
         )
         self._pos += len(records)
-        return records
+        assert self._make_row is not None
+        return [self._make_row(r) for r in records]
 
-    def fetchall(self) -> Sequence[Sequence[Any]]:
+    def fetchall(self) -> Sequence[Row]:
         """
         Return all the remaining records from the current recordset.
         """
@@ -526,9 +540,10 @@ class Cursor(BaseCursor["Connection"]):
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
         self._pos += self.pgresult.ntuples
-        return records
+        assert self._make_row is not None
+        return [self._make_row(r) for r in records]
 
-    def __iter__(self) -> Iterator[Sequence[Any]]:
+    def __iter__(self) -> Iterator[Row]:
         self._check_result()
 
         load = self._tx.load_row
@@ -538,7 +553,8 @@ class Cursor(BaseCursor["Connection"]):
             if row is None:
                 break
             self._pos += 1
-            yield row
+            assert self._make_row is not None
+            yield self._make_row(row)
 
     @contextmanager
     def copy(self, statement: Query) -> Iterator[Copy]:
@@ -591,22 +607,25 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
 
     async def stream(
         self, query: Query, params: Optional[Params] = None
-    ) -> AsyncIterator[Sequence[Any]]:
+    ) -> AsyncIterator[Row]:
         async with self._conn.lock:
             await self._conn.wait(self._stream_send_gen(query, params))
             while await self._conn.wait(self._stream_fetchone_gen()):
                 rec = self._tx.load_row(0)
                 assert rec is not None
-                yield rec
+                assert self._make_row is not None
+                yield self._make_row(rec)
 
-    async def fetchone(self) -> Optional[Sequence[Any]]:
+    async def fetchone(self) -> Optional[Row]:
         self._check_result()
         rv = self._tx.load_row(self._pos)
         if rv is not None:
             self._pos += 1
+            assert self._make_row is not None
+            return self._make_row(rv)
         return rv
 
-    async def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]:
+    async def fetchmany(self, size: int = 0) -> List[Row]:
         self._check_result()
         assert self.pgresult
 
@@ -616,16 +635,18 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             self._pos, min(self._pos + size, self.pgresult.ntuples)
         )
         self._pos += len(records)
-        return records
+        assert self._make_row is not None
+        return [self._make_row(r) for r in records]
 
-    async def fetchall(self) -> Sequence[Sequence[Any]]:
+    async def fetchall(self) -> List[Row]:
         self._check_result()
         assert self.pgresult
         records = self._tx.load_rows(self._pos, self.pgresult.ntuples)
         self._pos += self.pgresult.ntuples
-        return records
+        assert self._make_row is not None
+        return [self._make_row(r) for r in records]
 
-    async def __aiter__(self) -> AsyncIterator[Sequence[Any]]:
+    async def __aiter__(self) -> AsyncIterator[Row]:
         self._check_result()
 
         load = self._tx.load_row
@@ -635,7 +656,8 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             if row is None:
                 break
             self._pos += 1
-            yield row
+            assert self._make_row is not None
+            yield self._make_row(row)
 
     @asynccontextmanager
     async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]:
index b1e966cb5d3837d9d70cccfe96124db1fa32eca9..7edaf9e5c7b4523d4eb4a4ca01b5f9938d73c96b 100644 (file)
@@ -14,6 +14,7 @@ from ._enums import Format
 
 if TYPE_CHECKING:
     from .connection import BaseConnection
+    from .cursor import BaseCursor
     from .adapt import Dumper, Loader, AdaptersMap
     from .waiting import Wait, Ready
     from .sql import Composable
@@ -115,3 +116,18 @@ class Transformer(Protocol):
 
     def get_loader(self, oid: int, format: pq.Format) -> "Loader":
         ...
+
+
+# Row factories
+
+Row = TypeVar("Row")
+
+
+class RowMaker(Protocol):
+    def __call__(self, __values: Sequence[Any]) -> Row:
+        ...
+
+
+class RowFactory(Protocol):
+    def __call__(self, __cursor: "BaseCursor[ConnectionType]") -> RowMaker:
+        ...
index 481dbf5e496098c3208bc8d038714a862d1bc938..cdd9f6de1901d10d0b1d7249703a1788c66f8d88 100644 (file)
@@ -263,6 +263,16 @@ def test_iter_stop(conn):
     assert list(cur) == []
 
 
+def test_row_factory(conn):
+    def my_row_factory(cur):
+        return lambda values: [-v for v in values]
+
+    cur = conn.cursor(row_factory=my_row_factory)
+    cur.execute("select generate_series(1, 3)")
+    r = cur.fetchall()
+    assert r == [[-1], [-2], [-3]]
+
+
 def test_query_params_execute(conn):
     cur = conn.cursor()
     assert cur.query is None
index 6285aa5b57157ebc9375d0e0741f465a1e747ed7..1aa4f219837ea685f6a9d1fc7df4c7dda161e5f5 100644 (file)
@@ -268,6 +268,25 @@ async def test_iter_stop(aconn):
         assert False
 
 
+async def test_row_factory(aconn):
+    def my_row_factory(cursor):
+        assert cursor.description is not None
+        titles = [c.name for c in cursor.description]
+
+        def mkrow(values):
+            return [
+                f"{value.upper()}{title}"
+                for title, value in zip(titles, values)
+            ]
+
+        return mkrow
+
+    cur = await aconn.cursor(row_factory=my_row_factory)
+    await cur.execute("select 'foo' as bar")
+    (r,) = await cur.fetchone()
+    assert r == "FOObar"
+
+
 async def test_query_params_execute(aconn):
     cur = await aconn.cursor()
     assert cur.query is None