]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix type specification for `namedtuple_row()`
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 24 Feb 2021 02:11:26 +0000 (03:11 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 24 Feb 2021 02:11:26 +0000 (03:11 +0100)
psycopg3/psycopg3/rows.py

index 18f7b93df28a174a7fb0a0e677cad26fba50a10d..452d8a3f3be9a634210863db57756308cd978d9f 100644 (file)
@@ -7,36 +7,57 @@ psycopg3 row factories
 import functools
 import re
 from collections import namedtuple
-from typing import Any, Callable, Dict, Sequence, Tuple, Type
+from typing import Any, Callable, Dict, Sequence, Type, NamedTuple
+from typing import TYPE_CHECKING
 
-from .cursor import BaseCursor
-from .proto import ConnectionType
+from . import errors as e
+
+if TYPE_CHECKING:
+    from .cursor import BaseCursor
+
+
+def tuple_row(cursor: "BaseCursor[Any]") -> None:
+    """Row factory to represent rows as simple tuples.
+
+    This is the default factory.
+    """
+    # Implementation detail: just return None instead of a callable because
+    # the Transformer knows how to use this value.
+    return None
 
 
 def dict_row(
-    cursor: BaseCursor[ConnectionType],
+    cursor: "BaseCursor[Any]",
 ) -> Callable[[Sequence[Any]], Dict[str, Any]]:
-    """Row factory to represent rows as dicts."""
+    """Row factory to represent rows as dicts.
+
+    Note that this is not compatible with the DBAPI, which expects the records
+    to be sequences.
+    """
 
     def make_row(values: Sequence[Any]) -> Dict[str, Any]:
-        assert cursor.description
-        titles = (c.name for c in cursor.description)
+        desc = cursor.description
+        if desc is None:
+            raise e.InterfaceError("The cursor doesn't have a result")
+        titles = (c.name for c in desc)
         return dict(zip(titles, values))
 
     return make_row
 
 
 def namedtuple_row(
-    cursor: BaseCursor[ConnectionType],
-) -> Callable[[Sequence[Any]], Tuple[Any, ...]]:
+    cursor: "BaseCursor[Any]",
+) -> Callable[[Sequence[Any]], NamedTuple]:
     """Row factory to represent rows as `~collections.namedtuple`."""
 
-    def make_row(values: Sequence[Any]) -> Tuple[Any, ...]:
-        assert cursor.description
-        key = tuple(c.name for c in cursor.description)
+    def make_row(values: Sequence[Any]) -> NamedTuple:
+        desc = cursor.description
+        if desc is None:
+            raise e.InterfaceError("The cursor doesn't have a result")
+        key = tuple(c.name for c in desc)
         nt = _make_nt(key)
-        rv = nt._make(values)  # type: ignore[attr-defined]
-        return rv  # type: ignore[no-any-return]
+        rv = nt._make(values)
+        return rv
 
     return make_row
 
@@ -48,7 +69,7 @@ _re_clean = re.compile(
 
 
 @functools.lru_cache(512)
-def _make_nt(key: Sequence[str]) -> Type[Tuple[Any, ...]]:
+def _make_nt(key: Sequence[str]) -> Type[NamedTuple]:
     fields = []
     for s in key:
         s = _re_clean.sub("_", s)
@@ -57,4 +78,4 @@ def _make_nt(key: Sequence[str]) -> Type[Tuple[Any, ...]]:
         if s[0] == "_" or "0" <= s[0] <= "9":
             s = "f" + s
         fields.append(s)
-    return namedtuple("Row", fields)
+    return namedtuple("Row", fields)  # type: ignore[return-value]