]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add class_row row factory generator
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 4 Aug 2021 13:49:14 +0000 (14:49 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 5 Aug 2021 20:32:04 +0000 (21:32 +0100)
docs/advanced/rows.rst
psycopg/psycopg/rows.py
tests/test_rows.py

index 01e3d3672dc9616a0d74931ad471c8c1219726f3..11faebdfda74b7ff2f5a8e19b9fa930de52f9289 100644 (file)
@@ -96,6 +96,30 @@ The module `psycopg.rows` provides the implementation for a few row factories:
 
 .. autofunction:: namedtuple_row
 
+.. autofunction:: class_row
+
+    This is not a row factory, but rather a factory of row factories.
+    Specifying ``row_factory=class_row(MyClass)`` will create connections and
+    cursors returning `!MyClass` objects on fetch.
+
+    Example::
+
+        from dataclasses import dataclass
+        import psycopg
+        from psycopg.rows import class_row
+
+        @dataclass
+        class Person:
+            first_name: str
+            last_name: str
+            age: int = None
+
+        conn = psycopg.connect()
+        cur = conn.cursor(row_factory=class_row(Person))
+
+        cur.execute("select 'John' as first_name, 'Smith' as last_name").fetchone()
+        # Person(first_name='John', last_name='Smith', age=None)
+
 
 Use with a static analyzer
 --------------------------
index 80b4c52cd8bb16f086b1dd63cedc002f68fca404..a5137b52f9b994c8e948e727899f302cadaf9837 100644 (file)
@@ -6,7 +6,7 @@ psycopg row factories
 
 import re
 import functools
-from typing import Any, Dict, NamedTuple, NoReturn, Sequence, Tuple
+from typing import Any, Callable, Dict, NamedTuple, NoReturn, Sequence, Tuple
 from typing import TYPE_CHECKING, Type, TypeVar
 from collections import namedtuple
 
@@ -143,6 +143,32 @@ def _make_nt(*key: str) -> Type[NamedTuple]:
     return namedtuple("Row", fields)  # type: ignore[return-value]
 
 
+def class_row(cls: Type[T]) -> Callable[["BaseCursor[Any, T]"], RowMaker[T]]:
+    r"""Function to generate row factory functions returning a specific class.
+
+    The class must support every output column name as a keyword parameter.
+
+    :param cls: The class to return for each row. It must support the fields
+        returned by the query as keyword arguments.
+    :rtype: `!Callable[[Cursor],` `RowMaker`\[~T]]
+    """
+
+    def class_row_(cur: "BaseCursor[Any, T]") -> RowMaker[T]:
+        desc = cur.description
+        if desc is not None:
+            names = [d.name for d in desc]
+
+            def class_row__(values: Sequence[Any]) -> T:
+                return cls(**dict(zip(names, values)))  # type: ignore
+
+            return class_row__
+
+        else:
+            return no_result
+
+    return class_row_
+
+
 def no_result(values: Sequence[Any]) -> NoReturn:
     """A `RowMaker` that always fail.
 
index 894a362e3927497b596b8131698e1c3774bd0a8a..83cc173f4023a3b0b45385b01dcd6b518089dd08 100644 (file)
@@ -57,8 +57,26 @@ def test_namedtuple_row(conn):
     assert type(r1) is not type(r2)
 
 
+def test_class_row(conn):
+    cur = conn.cursor(row_factory=rows.class_row(Person))
+    cur.execute("select 'John' as first, 'Doe' as last")
+    (p,) = cur.fetchall()
+    assert isinstance(p, Person)
+    assert p.first == "John"
+    assert p.last == "Doe"
+    assert p.age is None
+
+    for query in (
+        "select 'John' as first",
+        "select 'John' as first, 'Doe' as last, 42 as wat",
+    ):
+        cur.execute(query)
+        with pytest.raises(TypeError):
+            cur.fetchone()
+
+
 @pytest.mark.parametrize(
-    "factory", "tuple_row dict_row namedtuple_row".split()
+    "factory", "tuple_row dict_row namedtuple_row class_row".split()
 )
 def test_no_result(factory, conn):
     cur = conn.cursor(row_factory=factory_from_name(factory))
@@ -76,6 +94,30 @@ def test_no_column(factory, conn):
     assert not recs[0]
 
 
+def test_no_column_class_row(conn):
+    class Empty:
+        def __init__(self, x=10, y=20):
+            self.x = x
+            self.y = y
+
+    cur = conn.cursor(row_factory=rows.class_row(Empty))
+    cur.execute("select")
+    x = cur.fetchone()
+    assert isinstance(x, Empty)
+    assert x.x == 10
+    assert x.y == 20
+
+
 def factory_from_name(name):
     factory = getattr(rows, name)
+    if factory is rows.class_row:
+        factory = factory(Person)
+
     return factory
+
+
+class Person:
+    def __init__(self, first, last, age=None):
+        self.first = first
+        self.last = last
+        self.age = age