]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Move dict/namedtuple row factory cursor inspection outside row maker
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 4 Aug 2021 01:47:26 +0000 (02:47 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 5 Aug 2021 20:32:04 +0000 (21:32 +0100)
The list of fields names is supposed to be cached at every query result,
not at every row.

psycopg/psycopg/rows.py
tests/test_cursor.py
tests/test_rows.py

index 21a7912f06d7f1e9c5017c885405d42875216df4..80b4c52cd8bb16f086b1dd63cedc002f68fca404 100644 (file)
@@ -4,11 +4,11 @@ psycopg row factories
 
 # Copyright (C) 2021 The Psycopg Team
 
-import functools
 import re
+import functools
+from typing import Any, Dict, NamedTuple, NoReturn, Sequence, Tuple
+from typing import TYPE_CHECKING, Type, TypeVar
 from collections import namedtuple
-from typing import Any, Dict, NamedTuple, Sequence, Tuple, Type, TypeVar
-from typing import TYPE_CHECKING
 
 from . import errors as e
 from .compat import Protocol
@@ -16,6 +16,8 @@ from .compat import Protocol
 if TYPE_CHECKING:
     from .cursor import BaseCursor, Cursor, AsyncCursor
 
+T = TypeVar("T")
+
 # Row factories
 
 Row = TypeVar("Row")
@@ -71,20 +73,6 @@ An alias for the type returned by `tuple_row()` (i.e. a tuple of any content).
 """
 
 
-def tuple_row(
-    cursor: "BaseCursor[Any, TupleRow]",
-) -> RowMaker[TupleRow]:
-    r"""Row factory to represent rows as simple tuples.
-
-    This is the default factory.
-
-    :param cursor: The cursor where to read from.
-    """
-    # Implementation detail: make sure this is the tuple type itself, not an
-    # equivalent function, because the C code fast-paths on it.
-    return tuple
-
-
 DictRow = Dict[str, Any]
 """
 An alias for the type returned by `dict_row()`
@@ -94,45 +82,46 @@ database.
 """
 
 
-def dict_row(
-    cursor: "BaseCursor[Any, DictRow]",
-) -> RowMaker[DictRow]:
-    r"""Row factory to represent rows as dicts.
+def tuple_row(cursor: "BaseCursor[Any, TupleRow]") -> RowMaker[TupleRow]:
+    r"""Row factory to represent rows as simple tuples.
+
+    This is the default factory.
+    """
+    # Implementation detail: make sure this is the tuple type itself, not an
+    # equivalent function, because the C code fast-paths on it.
+    return tuple
+
+
+def dict_row(cursor: "BaseCursor[Any, DictRow]") -> RowMaker[DictRow]:
+    """Row factory to represent rows as dicts.
 
     Note that this is not compatible with the DBAPI, which expects the records
     to be sequences.
-
-    :param cursor: The cursor where to read from.
     """
+    desc = cursor.description
+    if desc is not None:
+        titles = [c.name for c in desc]
+
+        def dict_row_(values: Sequence[Any]) -> Dict[str, Any]:
+            return dict(zip(titles, values))
 
-    def make_row(values: Sequence[Any]) -> Dict[str, Any]:
-        desc = cursor.description
-        if desc is None:
-            raise e.InterfaceError("The cursor doesn't have a result")
-        titles = (c.name for c in desc)
-        return dict(zip(titles, values))
+        return dict_row_
 
-    return make_row
+    else:
+        return no_result
 
 
 def namedtuple_row(
     cursor: "BaseCursor[Any, NamedTuple]",
 ) -> RowMaker[NamedTuple]:
-    r"""Row factory to represent rows as `~collections.namedtuple`.
+    """Row factory to represent rows as `~collections.namedtuple`."""
+    desc = cursor.description
+    if desc is not None:
+        nt = _make_nt(*(c.name for c in desc))
+        return nt._make
 
-    :param cursor: The cursor where to read from.
-    """
-
-    def make_row(values: Sequence[Any]) -> NamedTuple:
-        desc = cursor.description
-        if desc is None:
-            raise e.InterfaceError("The cursor doesn't have a result")
-        key = tuple(c.name for c in desc)
-        nt = _make_nt(key)
-        rv = nt._make(values)
-        return rv
-
-    return make_row
+    else:
+        return no_result
 
 
 # ascii except alnum and underscore
@@ -142,7 +131,7 @@ _re_clean = re.compile(
 
 
 @functools.lru_cache(512)
-def _make_nt(key: Sequence[str]) -> Type[NamedTuple]:
+def _make_nt(*key: str) -> Type[NamedTuple]:
     fields = []
     for s in key:
         s = _re_clean.sub("_", s)
@@ -152,3 +141,13 @@ def _make_nt(key: Sequence[str]) -> Type[NamedTuple]:
             s = "f" + s
         fields.append(s)
     return namedtuple("Row", fields)  # type: ignore[return-value]
+
+
+def no_result(values: Sequence[Any]) -> NoReturn:
+    """A `RowMaker` that always fail.
+
+    It can be used as return value for a `RowFactory` called with no result.
+    Note that the `!RowFactory` *will* be called with no result, but the
+    resulting `!RowMaker` never should.
+    """
+    raise e.InterfaceError("the cursor doesn't have a result")
index 597e56a42af3920422ba3ccde4f9bde25c98af9d..2962afa7d0c6a8a2127a5915750710257c3cd915 100644 (file)
@@ -633,12 +633,15 @@ def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory, retries):
 
 
 def my_row_factory(cursor):
-    assert cursor.description is not None
-    titles = [c.name for c in cursor.description]
-
-    def mkrow(values):
-        return [
-            f"{value.upper()}{title}" for title, value in zip(titles, values)
-        ]
-
-    return mkrow
+    if cursor.description is not None:
+        titles = [c.name for c in cursor.description]
+
+        def mkrow(values):
+            return [
+                f"{value.upper()}{title}"
+                for title, value in zip(titles, values)
+            ]
+
+        return mkrow
+    else:
+        return rows.no_result
index a1546d3b0a67278646d4f93787e3c2aee6365ca3..894a362e3927497b596b8131698e1c3774bd0a8a 100644 (file)
@@ -1,3 +1,5 @@
+import pytest
+
 from psycopg import rows
 
 
@@ -53,3 +55,27 @@ def test_namedtuple_row(conn):
     assert r2.number == 1
     assert not cur.nextset()
     assert type(r1) is not type(r2)
+
+
+@pytest.mark.parametrize(
+    "factory", "tuple_row dict_row namedtuple_row".split()
+)
+def test_no_result(factory, conn):
+    cur = conn.cursor(row_factory=factory_from_name(factory))
+    cur.execute("reset search_path")
+
+
+@pytest.mark.parametrize(
+    "factory", "tuple_row dict_row namedtuple_row".split()
+)
+def test_no_column(factory, conn):
+    cur = conn.cursor(row_factory=factory_from_name(factory))
+    cur.execute("select")
+    recs = cur.fetchall()
+    assert len(recs) == 1
+    assert not recs[0]
+
+
+def factory_from_name(name):
+    factory = getattr(rows, name)
+    return factory