]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix use of TypeInfo with connections using dict_row
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 24 Feb 2021 03:02:57 +0000 (04:02 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 24 Feb 2021 03:11:51 +0000 (04:11 +0100)
Actually always use dict_row as it makes things easier (we do create a
dict out of the query to pass kwargs...)

Relax the RowFactory signature to make it accept `dict_row()` and
`namedtuple_row()` too.

psycopg3/psycopg3/_typeinfo.py
psycopg3/psycopg3/proto.py

index 286d6c753ceb22bbf0f3ad9c66a8419d20ebe9fe..d9948e5b2fcf186ac7f9c40d23df7423020b31b3 100644 (file)
@@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, Iterator, Optional
 from typing import Sequence, Type, TypeVar, Union, TYPE_CHECKING
 
 from . import errors as e
+from .rows import dict_row
 from .proto import AdaptContext
 
 if TYPE_CHECKING:
@@ -69,11 +70,10 @@ class TypeInfo:
 
         if isinstance(name, Composable):
             name = name.as_string(conn)
-        cur = conn.cursor(binary=True, row_factory=None)
+        cur = conn.cursor(binary=True, row_factory=dict_row)
         cur.execute(cls._info_query, {"name": name})
-        recs: Sequence[Sequence[Any]] = cur.fetchall()
-        fields = [d[0] for d in cur.description or ()]
-        return cls._fetch(name, fields, recs)
+        recs: Sequence[Dict[str, Any]] = cur.fetchall()
+        return cls._fetch(name, recs)
 
     @classmethod
     async def fetch_async(
@@ -89,21 +89,19 @@ class TypeInfo:
         if isinstance(name, Composable):
             name = name.as_string(conn)
 
-        cur = conn.cursor(binary=True, row_factory=None)
+        cur = conn.cursor(binary=True, row_factory=dict_row)
         await cur.execute(cls._info_query, {"name": name})
-        recs: Sequence[Sequence[Any]] = await cur.fetchall()
-        fields = [d[0] for d in cur.description or ()]
-        return cls._fetch(name, fields, recs)
+        recs: Sequence[Dict[str, Any]] = await cur.fetchall()
+        return cls._fetch(name, recs)
 
     @classmethod
     def _fetch(
         cls: Type[T],
         name: str,
-        fields: Sequence[str],
-        recs: Sequence[Sequence[Any]],
+        recs: Sequence[Dict[str, Any]],
     ) -> Optional[T]:
         if len(recs) == 1:
-            return cls(**dict(zip(fields, recs[0])))
+            return cls(**recs[0])
         elif not recs:
             return None
         else:
index 8786de0506c7201ef34cbc2b55295b7c3e9bc266..acf04bdf68aed877ca3f724f6acc419b4e44ff2f 100644 (file)
@@ -51,14 +51,12 @@ Row = TypeVar("Row", Tuple[Any, ...], Any)
 
 
 class RowMaker(Protocol):
-    def __call__(self, __values: Sequence[Any]) -> Row:
+    def __call__(self, __values: Sequence[Any]) -> Any:
         ...
 
 
 class RowFactory(Protocol):
-    def __call__(
-        self, __cursor: "BaseCursor[ConnectionType]"
-    ) -> Optional[RowMaker]:
+    def __call__(self, __cursor: "BaseCursor[Any]") -> Optional[RowMaker]:
         ...