From: Daniele Varrazzo Date: Thu, 5 Aug 2021 19:39:57 +0000 (+0100) Subject: Add psycopg.rows.args_row and kwargs_row factories X-Git-Tag: 3.0.dev2~10^2~1 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=b5e42ad73335c736aee1ec7048019ab559c343e0;p=thirdparty%2Fpsycopg.git Add psycopg.rows.args_row and kwargs_row factories --- diff --git a/docs/advanced/rows.rst b/docs/advanced/rows.rst index 11faebdfd..3933e5bcb 100644 --- a/docs/advanced/rows.rst +++ b/docs/advanced/rows.rst @@ -31,9 +31,8 @@ callable (formally the `~psycopg.rows.RowMaker` protocol) accepting a .. autoclass:: psycopg.rows.AsyncRowFactory() - .. method:: __call__(cursor: AsyncCursor[Row]) -> RowMaker[Row] +.. autoclass:: psycopg.rows.BaseRowFactory() - Inspect the result on a cursor and return a `RowMaker` to convert rows. Note that it's easy to implement an object implementing both `!RowFactory` and `!AsyncRowFactory`: usually, everything you need to implement a row factory is @@ -89,14 +88,11 @@ The module `psycopg.rows` provides the implementation for a few row factories: .. currentmodule:: psycopg.rows .. autofunction:: tuple_row -.. autodata:: TupleRow - .. autofunction:: dict_row -.. autodata:: DictRow - .. autofunction:: namedtuple_row - .. autofunction:: class_row +.. autofunction:: args_row +.. autofunction:: kwargs_row This is not a row factory, but rather a factory of row factories. Specifying ``row_factory=class_row(MyClass)`` will create connections and @@ -162,7 +158,7 @@ Example: returning records as Pydantic models Using Pydantic_ it is possible to enforce static typing at runtime. Using a Pydantic model factory the code can be checked statically using mypy and -querying the database will raise an exception if the resultset is not +querying the database will raise an exception if the rows returned is not compatible with the model. .. _Pydantic: https://pydantic-docs.helpmanual.io/ diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py index 3bf8b6bbb..f181e8787 100644 --- a/psycopg/psycopg/rows.py +++ b/psycopg/psycopg/rows.py @@ -60,13 +60,22 @@ class RowFactory(Protocol[Row]): class AsyncRowFactory(Protocol[Row]): """ - Callable protocol taking an `~psycopg.AsyncCursor` and returning a `RowMaker`. + Like `RowFactory`, taking an async cursor as argument. """ def __call__(self, __cursor: "AsyncCursor[Row]") -> RowMaker[Row]: ... +class BaseRowFactory(Protocol[Row]): + """ + Like `RowFactory`, taking either type of cursor as argument. + """ + + def __call__(self, __cursor: "BaseCursor[Any, Row]") -> RowMaker[Row]: + ... + + TupleRow = Tuple[Any, ...] """ An alias for the type returned by `tuple_row()` (i.e. a tuple of any content). @@ -93,11 +102,7 @@ def tuple_row(cursor: "BaseCursor[Any, TupleRow]") -> RowMaker[TupleRow]: def dict_row(cursor: "BaseCursor[Any, DictRow]") -> RowMaker[DictRow]: - """Row factory to represent rows as dicts. - - Note that this is not compatible with the DBAPI, which expects the records - to be sequences. - """ + """Row factory to represent rows as dictionaries.""" desc = cursor.description if desc is None: return no_result @@ -141,8 +146,8 @@ def _make_nt(*key: str) -> Type[NamedTuple]: return namedtuple("Row", fields) # type: ignore[return-value] -def class_row(cls: Type[T]) -> Callable[["BaseCursor[Any, T]"], RowMaker[T]]: - r"""Function to generate row factory functions returning a specific class. +def class_row(cls: Type[T]) -> BaseRowFactory[T]: + r"""Generate a row factory to represent rows as instances of the class *cls*. The class must support every output column name as a keyword parameter. @@ -166,6 +171,44 @@ def class_row(cls: Type[T]) -> Callable[["BaseCursor[Any, T]"], RowMaker[T]]: return class_row_ +def args_row(func: Callable[..., T]) -> BaseRowFactory[T]: + """Generate a row factory calling *func* with positional parameters for every row. + + :param func: The function to call for each row. It must support the fields + returned by the query as positional arguments. + """ + + def args_row_(cur: "BaseCursor[Any, T]") -> RowMaker[T]: + def args_row__(values: Sequence[Any]) -> T: + return func(*values) + + return args_row__ + + return args_row_ + + +def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]: + """Generate a row factory calling *func* with keyword parameters for every row. + + :param func: The function to call for each row. It must support the fields + returned by the query as keyword arguments. + """ + + def kwargs_row_(cur: "BaseCursor[Any, T]") -> RowMaker[T]: + desc = cur.description + if desc is None: + return no_result + + names = [d.name for d in desc] + + def kwargs_row__(values: Sequence[Any]) -> T: + return func(**dict(zip(names, values))) + + return kwargs_row__ + + return kwargs_row_ + + def no_result(values: Sequence[Any]) -> NoReturn: """A `RowMaker` that always fail. diff --git a/tests/test_rows.py b/tests/test_rows.py index 83cc173f4..a345a5b5b 100644 --- a/tests/test_rows.py +++ b/tests/test_rows.py @@ -1,5 +1,6 @@ import pytest +import psycopg from psycopg import rows @@ -75,16 +76,35 @@ def test_class_row(conn): cur.fetchone() +def test_args_row(conn): + cur = conn.cursor(row_factory=rows.args_row(argf)) + cur.execute("select 'John' as first, 'Doe' as last") + assert cur.fetchone() == "JohnDoe" + + +def test_kwargs_row(conn): + cur = conn.cursor(row_factory=rows.kwargs_row(kwargf)) + cur.execute("select 'John' as first, 'Doe' as last") + (p,) = cur.fetchall() + assert isinstance(p, Person) + assert p.first == "John" + assert p.last == "Doe" + assert p.age == 42 + + @pytest.mark.parametrize( - "factory", "tuple_row dict_row namedtuple_row class_row".split() + "factory", + "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(), ) def test_no_result(factory, conn): cur = conn.cursor(row_factory=factory_from_name(factory)) cur.execute("reset search_path") + with pytest.raises(psycopg.ProgrammingError): + cur.fetchone() @pytest.mark.parametrize( - "factory", "tuple_row dict_row namedtuple_row".split() + "factory", "tuple_row dict_row namedtuple_row args_row".split() ) def test_no_column(factory, conn): cur = conn.cursor(row_factory=factory_from_name(factory)) @@ -112,6 +132,10 @@ def factory_from_name(name): factory = getattr(rows, name) if factory is rows.class_row: factory = factory(Person) + if factory is rows.args_row: + factory = factory(argf) + if factory is rows.kwargs_row: + factory = factory(argf) return factory @@ -121,3 +145,11 @@ class Person: self.first = first self.last = last self.age = age + + +def argf(*args): + return "".join(map(str, args)) + + +def kwargf(**kwargs): + return Person(**kwargs, age=42) diff --git a/tests/test_typing.py b/tests/test_typing.py index aa66f6fc0..d7df5d5c1 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -45,6 +45,10 @@ def test_typing_example(mypy, filename): "psycopg.connect(row_factory=rows.namedtuple_row)", "psycopg.Connection[NamedTuple]", ), + ( + "psycopg.connect(row_factory=rows.class_row(Thing))", + "psycopg.Connection[Thing]", + ), ( "psycopg.connect(row_factory=thing_row)", "psycopg.Connection[Thing]", @@ -91,6 +95,11 @@ def test_connection_type(conn, type, mypy, tmpdir): "conn.cursor(row_factory=rows.namedtuple_row)", "psycopg.Cursor[NamedTuple]", ), + ( + "psycopg.connect(row_factory=rows.class_row(Thing))", + "conn.cursor()", + "psycopg.Cursor[Thing]", + ), ( "psycopg.connect(row_factory=thing_row)", "conn.cursor()", diff --git a/tests/typing_example.py b/tests/typing_example.py index aeb7a7a6b..aa3882cc7 100644 --- a/tests/typing_example.py +++ b/tests/typing_example.py @@ -3,9 +3,9 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union -from psycopg import Connection, Cursor, ServerCursor, connect +from psycopg import Connection, Cursor, ServerCursor, connect, rows from psycopg import AsyncConnection, AsyncCursor, AsyncServerCursor @@ -31,6 +31,14 @@ class Person: return mkrow +def kwargsf(*, foo: int, bar: int, baz: int) -> int: + return 42 + + +def argsf(foo: int, bar: int, baz: int) -> float: + return 42.0 + + def check_row_factory_cursor() -> None: """Type-check connection.cursor(..., row_factory=) case.""" conn = connect() @@ -147,3 +155,22 @@ async def async_check_row_factory_connection() -> None: await cur3.execute("select 42") r3 = await cur3.fetchone() r3 and len(r3) + + +def check_row_factories() -> None: + conn1 = connect(row_factory=rows.tuple_row) + v1: Tuple[Any, ...] = conn1.execute("").fetchall()[0] + + conn2 = connect(row_factory=rows.dict_row) + v2: Dict[str, Any] = conn2.execute("").fetchall()[0] + + conn3 = connect(row_factory=rows.class_row(Person)) + v3: Person = conn3.execute("").fetchall()[0] + + conn4 = connect(row_factory=rows.args_row(argsf)) + v4: float = conn4.execute("").fetchall()[0] + + conn5 = connect(row_factory=rows.kwargs_row(kwargsf)) + v5: int = conn5.execute("").fetchall()[0] + + v1, v2, v3, v4, v5