From: Daniele Varrazzo Date: Wed, 24 Feb 2021 03:02:57 +0000 (+0100) Subject: Fix use of TypeInfo with connections using dict_row X-Git-Tag: 3.0.dev0~106^2~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b423af48ce7568f2f4d54dcbfe3d389c46d36372;p=thirdparty%2Fpsycopg.git Fix use of TypeInfo with connections using dict_row Actually always use dict_row as it makes things easier (we do create a dict out of the query to pass kwargs...) Relax the RowFactory signature to make it accept `dict_row()` and `namedtuple_row()` too. --- diff --git a/psycopg3/psycopg3/_typeinfo.py b/psycopg3/psycopg3/_typeinfo.py index 286d6c753..d9948e5b2 100644 --- a/psycopg3/psycopg3/_typeinfo.py +++ b/psycopg3/psycopg3/_typeinfo.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, Iterator, Optional from typing import Sequence, Type, TypeVar, Union, TYPE_CHECKING from . import errors as e +from .rows import dict_row from .proto import AdaptContext if TYPE_CHECKING: @@ -69,11 +70,10 @@ class TypeInfo: if isinstance(name, Composable): name = name.as_string(conn) - cur = conn.cursor(binary=True, row_factory=None) + cur = conn.cursor(binary=True, row_factory=dict_row) cur.execute(cls._info_query, {"name": name}) - recs: Sequence[Sequence[Any]] = cur.fetchall() - fields = [d[0] for d in cur.description or ()] - return cls._fetch(name, fields, recs) + recs: Sequence[Dict[str, Any]] = cur.fetchall() + return cls._fetch(name, recs) @classmethod async def fetch_async( @@ -89,21 +89,19 @@ class TypeInfo: if isinstance(name, Composable): name = name.as_string(conn) - cur = conn.cursor(binary=True, row_factory=None) + cur = conn.cursor(binary=True, row_factory=dict_row) await cur.execute(cls._info_query, {"name": name}) - recs: Sequence[Sequence[Any]] = await cur.fetchall() - fields = [d[0] for d in cur.description or ()] - return cls._fetch(name, fields, recs) + recs: Sequence[Dict[str, Any]] = await cur.fetchall() + return cls._fetch(name, recs) @classmethod def _fetch( cls: Type[T], name: str, - fields: Sequence[str], - recs: Sequence[Sequence[Any]], + recs: Sequence[Dict[str, Any]], ) -> Optional[T]: if len(recs) == 1: - return cls(**dict(zip(fields, recs[0]))) + return cls(**recs[0]) elif not recs: return None else: diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index 8786de050..acf04bdf6 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -51,14 +51,12 @@ Row = TypeVar("Row", Tuple[Any, ...], Any) class RowMaker(Protocol): - def __call__(self, __values: Sequence[Any]) -> Row: + def __call__(self, __values: Sequence[Any]) -> Any: ... class RowFactory(Protocol): - def __call__( - self, __cursor: "BaseCursor[ConnectionType]" - ) -> Optional[RowMaker]: + def __call__(self, __cursor: "BaseCursor[Any]") -> Optional[RowMaker]: ...