From: Denis Laxalde Date: Tue, 9 Feb 2021 15:51:11 +0000 (+0100) Subject: Introduce row_factory option in connection.cursor() X-Git-Tag: 3.0.dev0~106^2~28 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=771d79321126e426a3878957c486619c9e0a9405;p=thirdparty%2Fpsycopg.git Introduce row_factory option in connection.cursor() We add a row_factory keyword argument in connection.cursor() and cursor classes that will be used to produce individual rows of the result set. A RowFactory can be implemented as a class with a __call__ method accepting raw values and initialized with a cursor instance; the RowFactory instance is created when results are available. Type definitions for RowFactory (and its respective RowMaker) are defined as callback protocols so as to allow users to define a row factory without the need for writing a class. The default row factory returns values unchanged. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 0727d415c..79842ae4a 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -29,8 +29,8 @@ from . import waiting from . 
import encodings from .pq import ConnStatus, ExecStatus, TransactionStatus, Format from .sql import Composable -from .proto import PQGen, PQGenConn, RV, Query, Params, AdaptContext -from .proto import ConnectionType +from .proto import PQGen, PQGenConn, RV, RowFactory, Query, Params +from .proto import AdaptContext, ConnectionType from .conninfo import make_conninfo from .generators import notifies from .transaction import Transaction, AsyncTransaction @@ -448,7 +448,12 @@ class Connection(BaseConnection): """Close the database connection.""" self.pgconn.finish() - def cursor(self, name: str = "", binary: bool = False) -> "Cursor": + def cursor( + self, + name: str = "", + binary: bool = False, + row_factory: RowFactory = cursor.default_row_factory, + ) -> "Cursor": """ Return a new `Cursor` to send commands and queries to the connection. """ @@ -456,7 +461,7 @@ class Connection(BaseConnection): raise NotImplementedError format = Format.BINARY if binary else Format.TEXT - return self.cursor_factory(self, format=format) + return self.cursor_factory(self, row_factory, format=format) def execute( self, @@ -584,7 +589,10 @@ class AsyncConnection(BaseConnection): self.pgconn.finish() async def cursor( - self, name: str = "", binary: bool = False + self, + name: str = "", + binary: bool = False, + row_factory: RowFactory = cursor.default_row_factory, ) -> "AsyncCursor": """ Return a new `AsyncCursor` to send commands and queries to the connection. @@ -593,7 +601,7 @@ class AsyncConnection(BaseConnection): raise NotImplementedError format = Format.BINARY if binary else Format.TEXT - return self.cursor_factory(self, format=format) + return self.cursor_factory(self, row_factory, format=format) async def execute( self, diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index abc5d2cd4..f7bdd7c20 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -18,6 +18,7 @@ from . 
import generators from .pq import ExecStatus, Format from .copy import Copy, AsyncCopy from .proto import ConnectionType, Query, Params, PQGen +from .proto import Row, RowFactory, RowMaker from ._column import Column from ._queries import PostgresQuery from ._preparing import Prepare @@ -43,13 +44,17 @@ else: execute = generators.execute +def default_row_factory(cursor: Any) -> RowMaker: + return lambda values: values + + class BaseCursor(Generic[ConnectionType]): # Slots with __weakref__ and generic bases don't work on Py 3.6 # https://bugs.python.org/issue41451 if sys.version_info >= (3, 7): __slots__ = """ _conn format _adapters arraysize _closed _results _pgresult _pos - _iresult _rowcount _pgq _tx _last_query + _iresult _rowcount _pgq _tx _last_query _row_factory _make_row __weakref__ """.split() @@ -60,11 +65,13 @@ class BaseCursor(Generic[ConnectionType]): def __init__( self, connection: ConnectionType, + row_factory: RowFactory, format: Format = Format.TEXT, ): self._conn = connection self.format = format self._adapters = adapt.AdaptersMap(connection.adapters) + self._row_factory = row_factory self.arraysize = 1 self._closed = False self._last_query: Optional[Query] = None @@ -73,6 +80,7 @@ class BaseCursor(Generic[ConnectionType]): def _reset(self) -> None: self._results: List["PGresult"] = [] self._pgresult: Optional["PGresult"] = None + self._make_row: Optional[RowMaker] = None self._pos = 0 self._iresult = 0 self._rowcount = -1 @@ -261,6 +269,7 @@ class BaseCursor(Generic[ConnectionType]): return None elif res.status == ExecStatus.SINGLE_TUPLE: + self._make_row = self._row_factory(self) self.pgresult = res # will set it on the transformer too # TODO: the transformer may do excessive work here: create a # path that doesn't clear the loaders every time. 
@@ -364,6 +373,7 @@ class BaseCursor(Generic[ConnectionType]): self._results = list(results) self.pgresult = results[0] + self._make_row = self._row_factory(self) nrows = self.pgresult.command_tuples if nrows is not None: if self._rowcount < 0: @@ -478,7 +488,7 @@ class Cursor(BaseCursor["Connection"]): def stream( self, query: Query, params: Optional[Params] = None - ) -> Iterator[Sequence[Any]]: + ) -> Iterator[Row]: """ Iterate row-by-row on a result from the database. """ @@ -487,9 +497,10 @@ class Cursor(BaseCursor["Connection"]): while self._conn.wait(self._stream_fetchone_gen()): rec = self._tx.load_row(0) assert rec is not None - yield rec + assert self._make_row is not None + yield self._make_row(rec) - def fetchone(self) -> Optional[Sequence[Any]]: + def fetchone(self) -> Optional[Row]: """ Return the next record from the current recordset. @@ -499,9 +510,11 @@ class Cursor(BaseCursor["Connection"]): record = self._tx.load_row(self._pos) if record is not None: self._pos += 1 + assert self._make_row is not None + return self._make_row(record) return record - def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]: + def fetchmany(self, size: int = 0) -> Sequence[Row]: """ Return the next *size* records from the current recordset. @@ -516,9 +529,10 @@ class Cursor(BaseCursor["Connection"]): self._pos, min(self._pos + size, self.pgresult.ntuples) ) self._pos += len(records) - return records + assert self._make_row is not None + return [self._make_row(r) for r in records] - def fetchall(self) -> Sequence[Sequence[Any]]: + def fetchall(self) -> Sequence[Row]: """ Return all the remaining records from the current recordset. 
""" @@ -526,9 +540,10 @@ class Cursor(BaseCursor["Connection"]): assert self.pgresult records = self._tx.load_rows(self._pos, self.pgresult.ntuples) self._pos += self.pgresult.ntuples - return records + assert self._make_row is not None + return [self._make_row(r) for r in records] - def __iter__(self) -> Iterator[Sequence[Any]]: + def __iter__(self) -> Iterator[Row]: self._check_result() load = self._tx.load_row @@ -538,7 +553,8 @@ class Cursor(BaseCursor["Connection"]): if row is None: break self._pos += 1 - yield row + assert self._make_row is not None + yield self._make_row(row) @contextmanager def copy(self, statement: Query) -> Iterator[Copy]: @@ -591,22 +607,25 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): async def stream( self, query: Query, params: Optional[Params] = None - ) -> AsyncIterator[Sequence[Any]]: + ) -> AsyncIterator[Row]: async with self._conn.lock: await self._conn.wait(self._stream_send_gen(query, params)) while await self._conn.wait(self._stream_fetchone_gen()): rec = self._tx.load_row(0) assert rec is not None - yield rec + assert self._make_row is not None + yield self._make_row(rec) - async def fetchone(self) -> Optional[Sequence[Any]]: + async def fetchone(self) -> Optional[Row]: self._check_result() rv = self._tx.load_row(self._pos) if rv is not None: self._pos += 1 + assert self._make_row is not None + return self._make_row(rv) return rv - async def fetchmany(self, size: int = 0) -> Sequence[Sequence[Any]]: + async def fetchmany(self, size: int = 0) -> List[Row]: self._check_result() assert self.pgresult @@ -616,16 +635,18 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self._pos, min(self._pos + size, self.pgresult.ntuples) ) self._pos += len(records) - return records + assert self._make_row is not None + return [self._make_row(r) for r in records] - async def fetchall(self) -> Sequence[Sequence[Any]]: + async def fetchall(self) -> List[Row]: self._check_result() assert self.pgresult records = 
self._tx.load_rows(self._pos, self.pgresult.ntuples) self._pos += self.pgresult.ntuples - return records + assert self._make_row is not None + return [self._make_row(r) for r in records] - async def __aiter__(self) -> AsyncIterator[Sequence[Any]]: + async def __aiter__(self) -> AsyncIterator[Row]: self._check_result() load = self._tx.load_row @@ -635,7 +656,8 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): if row is None: break self._pos += 1 - yield row + assert self._make_row is not None + yield self._make_row(row) @asynccontextmanager async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index b1e966cb5..7edaf9e5c 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -14,6 +14,7 @@ from ._enums import Format if TYPE_CHECKING: from .connection import BaseConnection + from .cursor import BaseCursor from .adapt import Dumper, Loader, AdaptersMap from .waiting import Wait, Ready from .sql import Composable @@ -115,3 +116,18 @@ class Transformer(Protocol): def get_loader(self, oid: int, format: pq.Format) -> "Loader": ... + + +# Row factories + +Row = TypeVar("Row") + + +class RowMaker(Protocol): + def __call__(self, __values: Sequence[Any]) -> Row: + ... + + +class RowFactory(Protocol): + def __call__(self, __cursor: "BaseCursor[ConnectionType]") -> RowMaker: + ... 
diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 481dbf5e4..cdd9f6de1 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -263,6 +263,16 @@ def test_iter_stop(conn): assert list(cur) == [] +def test_row_factory(conn): + def my_row_factory(cur): + return lambda values: [-v for v in values] + + cur = conn.cursor(row_factory=my_row_factory) + cur.execute("select generate_series(1, 3)") + r = cur.fetchall() + assert r == [[-1], [-2], [-3]] + + def test_query_params_execute(conn): cur = conn.cursor() assert cur.query is None diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index 6285aa5b5..1aa4f2198 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -268,6 +268,25 @@ async def test_iter_stop(aconn): assert False +async def test_row_factory(aconn): + def my_row_factory(cursor): + assert cursor.description is not None + titles = [c.name for c in cursor.description] + + def mkrow(values): + return [ + f"{value.upper()}{title}" + for title, value in zip(titles, values) + ] + + return mkrow + + cur = await aconn.cursor(row_factory=my_row_factory) + await cur.execute("select 'foo' as bar") + (r,) = await cur.fetchone() + assert r == "FOObar" + + async def test_query_params_execute(aconn): cur = await aconn.cursor() assert cur.query is None