]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add psycopg.rows.args_row and kwargs_row factories
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 5 Aug 2021 19:39:57 +0000 (20:39 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 5 Aug 2021 20:32:04 +0000 (21:32 +0100)
docs/advanced/rows.rst
psycopg/psycopg/rows.py
tests/test_rows.py
tests/test_typing.py
tests/typing_example.py

index 11faebdfda74b7ff2f5a8e19b9fa930de52f9289..3933e5bcb279a70bd90235ed6d3c63c4f8cd1b4b 100644 (file)
@@ -31,9 +31,8 @@ callable (formally the `~psycopg.rows.RowMaker` protocol) accepting a
 
 .. autoclass:: psycopg.rows.AsyncRowFactory()
 
-   .. method:: __call__(cursor: AsyncCursor[Row]) -> RowMaker[Row]
+.. autoclass:: psycopg.rows.BaseRowFactory()
 
-        Inspect the result on a cursor and return a `RowMaker` to convert rows.
 
 Note that it's easy to implement an object implementing both `!RowFactory` and
 `!AsyncRowFactory`: usually, everything you need to implement a row factory is
@@ -89,14 +88,11 @@ The module `psycopg.rows` provides the implementation for a few row factories:
 .. currentmodule:: psycopg.rows
 
 .. autofunction:: tuple_row
-.. autodata:: TupleRow
-
 .. autofunction:: dict_row
-.. autodata:: DictRow
-
 .. autofunction:: namedtuple_row
-
 .. autofunction:: class_row
+.. autofunction:: args_row
+.. autofunction:: kwargs_row
 
     This is not a row factory, but rather a factory of row factories.
     Specifying ``row_factory=class_row(MyClass)`` will create connections and
@@ -162,7 +158,7 @@ Example: returning records as Pydantic models
 
 Using Pydantic_ it is possible to enforce static typing at runtime. Using a
 Pydantic model factory the code can be checked statically using mypy and
-querying the database will raise an exception if the resultset is not
+querying the database will raise an exception if the rows returned is not
 compatible with the model.
 
 .. _Pydantic: https://pydantic-docs.helpmanual.io/
index 3bf8b6bbb06599f086feefa7d99c63a76730f427..f181e87870d06d805810ffefd398a2e89d992a7a 100644 (file)
@@ -60,13 +60,22 @@ class RowFactory(Protocol[Row]):
 
 class AsyncRowFactory(Protocol[Row]):
     """
-    Callable protocol taking an `~psycopg.AsyncCursor` and returning a `RowMaker`.
+    Like `RowFactory`, taking an async cursor as argument.
     """
 
     def __call__(self, __cursor: "AsyncCursor[Row]") -> RowMaker[Row]:
         ...
 
 
+class BaseRowFactory(Protocol[Row]):
+    """
+    Like `RowFactory`, taking either type of cursor as argument.
+    """
+
+    def __call__(self, __cursor: "BaseCursor[Any, Row]") -> RowMaker[Row]:
+        ...
+
+
 TupleRow = Tuple[Any, ...]
 """
 An alias for the type returned by `tuple_row()` (i.e. a tuple of any content).
@@ -93,11 +102,7 @@ def tuple_row(cursor: "BaseCursor[Any, TupleRow]") -> RowMaker[TupleRow]:
 
 
 def dict_row(cursor: "BaseCursor[Any, DictRow]") -> RowMaker[DictRow]:
-    """Row factory to represent rows as dicts.
-
-    Note that this is not compatible with the DBAPI, which expects the records
-    to be sequences.
-    """
+    """Row factory to represent rows as dictionaries."""
     desc = cursor.description
     if desc is None:
         return no_result
@@ -141,8 +146,8 @@ def _make_nt(*key: str) -> Type[NamedTuple]:
     return namedtuple("Row", fields)  # type: ignore[return-value]
 
 
-def class_row(cls: Type[T]) -> Callable[["BaseCursor[Any, T]"], RowMaker[T]]:
-    r"""Function to generate row factory functions returning a specific class.
+def class_row(cls: Type[T]) -> BaseRowFactory[T]:
+    r"""Generate a row factory to represent rows as instances of the class *cls*.
 
     The class must support every output column name as a keyword parameter.
 
@@ -166,6 +171,44 @@ def class_row(cls: Type[T]) -> Callable[["BaseCursor[Any, T]"], RowMaker[T]]:
     return class_row_
 
 
+def args_row(func: Callable[..., T]) -> BaseRowFactory[T]:
+    """Generate a row factory calling *func* with positional parameters for every row.
+
+    :param func: The function to call for each row. It must support the fields
+        returned by the query as positional arguments.
+    """
+
+    def args_row_(cur: "BaseCursor[Any, T]") -> RowMaker[T]:
+        def args_row__(values: Sequence[Any]) -> T:
+            return func(*values)
+
+        return args_row__
+
+    return args_row_
+
+
+def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]:
+    """Generate a row factory calling *func* with keyword parameters for every row.
+
+    :param func: The function to call for each row. It must support the fields
+        returned by the query as keyword arguments.
+    """
+
+    def kwargs_row_(cur: "BaseCursor[Any, T]") -> RowMaker[T]:
+        desc = cur.description
+        if desc is None:
+            return no_result
+
+        names = [d.name for d in desc]
+
+        def kwargs_row__(values: Sequence[Any]) -> T:
+            return func(**dict(zip(names, values)))
+
+        return kwargs_row__
+
+    return kwargs_row_
+
+
 def no_result(values: Sequence[Any]) -> NoReturn:
     """A `RowMaker` that always fail.
 
index 83cc173f4023a3b0b45385b01dcd6b518089dd08..a345a5b5b4afc801081176616ce324b943874f51 100644 (file)
@@ -1,5 +1,6 @@
 import pytest
 
+import psycopg
 from psycopg import rows
 
 
@@ -75,16 +76,35 @@ def test_class_row(conn):
             cur.fetchone()
 
 
+def test_args_row(conn):
+    cur = conn.cursor(row_factory=rows.args_row(argf))
+    cur.execute("select 'John' as first, 'Doe' as last")
+    assert cur.fetchone() == "JohnDoe"
+
+
+def test_kwargs_row(conn):
+    cur = conn.cursor(row_factory=rows.kwargs_row(kwargf))
+    cur.execute("select 'John' as first, 'Doe' as last")
+    (p,) = cur.fetchall()
+    assert isinstance(p, Person)
+    assert p.first == "John"
+    assert p.last == "Doe"
+    assert p.age == 42
+
+
 @pytest.mark.parametrize(
-    "factory", "tuple_row dict_row namedtuple_row class_row".split()
+    "factory",
+    "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(),
 )
 def test_no_result(factory, conn):
     cur = conn.cursor(row_factory=factory_from_name(factory))
     cur.execute("reset search_path")
+    with pytest.raises(psycopg.ProgrammingError):
+        cur.fetchone()
 
 
 @pytest.mark.parametrize(
-    "factory", "tuple_row dict_row namedtuple_row".split()
+    "factory", "tuple_row dict_row namedtuple_row args_row".split()
 )
 def test_no_column(factory, conn):
     cur = conn.cursor(row_factory=factory_from_name(factory))
@@ -112,6 +132,10 @@ def factory_from_name(name):
     factory = getattr(rows, name)
     if factory is rows.class_row:
         factory = factory(Person)
+    if factory is rows.args_row:
+        factory = factory(argf)
+    if factory is rows.kwargs_row:
+        factory = factory(argf)
 
     return factory
 
@@ -121,3 +145,11 @@ class Person:
         self.first = first
         self.last = last
         self.age = age
+
+
+def argf(*args):
+    return "".join(map(str, args))
+
+
+def kwargf(**kwargs):
+    return Person(**kwargs, age=42)
index aa66f6fc057a11793c79df01b41111221fbdea01..d7df5d5c125d1d5d58e9d0f5dc5abc15903ba3b8 100644 (file)
@@ -45,6 +45,10 @@ def test_typing_example(mypy, filename):
             "psycopg.connect(row_factory=rows.namedtuple_row)",
             "psycopg.Connection[NamedTuple]",
         ),
+        (
+            "psycopg.connect(row_factory=rows.class_row(Thing))",
+            "psycopg.Connection[Thing]",
+        ),
         (
             "psycopg.connect(row_factory=thing_row)",
             "psycopg.Connection[Thing]",
@@ -91,6 +95,11 @@ def test_connection_type(conn, type, mypy, tmpdir):
             "conn.cursor(row_factory=rows.namedtuple_row)",
             "psycopg.Cursor[NamedTuple]",
         ),
+        (
+            "psycopg.connect(row_factory=rows.class_row(Thing))",
+            "conn.cursor()",
+            "psycopg.Cursor[Thing]",
+        ),
         (
             "psycopg.connect(row_factory=thing_row)",
             "conn.cursor()",
index aeb7a7a6bdc0873ac2b5e236e67b9fd7889f3960..aa3882cc76e493db4779b224fb9c6d08aab3f2f0 100644 (file)
@@ -3,9 +3,9 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Callable, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
-from psycopg import Connection, Cursor, ServerCursor, connect
+from psycopg import Connection, Cursor, ServerCursor, connect, rows
 from psycopg import AsyncConnection, AsyncCursor, AsyncServerCursor
 
 
@@ -31,6 +31,14 @@ class Person:
         return mkrow
 
 
+def kwargsf(*, foo: int, bar: int, baz: int) -> int:
+    return 42
+
+
+def argsf(foo: int, bar: int, baz: int) -> float:
+    return 42.0
+
+
 def check_row_factory_cursor() -> None:
     """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case."""
     conn = connect()
@@ -147,3 +155,22 @@ async def async_check_row_factory_connection() -> None:
         await cur3.execute("select 42")
         r3 = await cur3.fetchone()
         r3 and len(r3)
+
+
+def check_row_factories() -> None:
+    conn1 = connect(row_factory=rows.tuple_row)
+    v1: Tuple[Any, ...] = conn1.execute("").fetchall()[0]
+
+    conn2 = connect(row_factory=rows.dict_row)
+    v2: Dict[str, Any] = conn2.execute("").fetchall()[0]
+
+    conn3 = connect(row_factory=rows.class_row(Person))
+    v3: Person = conn3.execute("").fetchall()[0]
+
+    conn4 = connect(row_factory=rows.args_row(argsf))
+    v4: float = conn4.execute("").fetchall()[0]
+
+    conn5 = connect(row_factory=rows.kwargs_row(kwargsf))
+    v5: int = conn5.execute("").fetchall()[0]
+
+    v1, v2, v3, v4, v5