From: Denis Laxalde Date: Thu, 15 Apr 2021 15:49:04 +0000 (+0200) Subject: Make Cursor generic on Row X-Git-Tag: 3.0.dev0~63^2~17 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=749f3ff784320edd59af2098c59397cd2958bccd;p=thirdparty%2Fpsycopg.git Make Cursor generic on Row We make RowMaker, RowFactory and Cursor types generic on a Row type variable thus making type inference work on cursor's fetch*() methods. For example: R = TypeVar("R") def my_row_factory(cursor: BaseCursor[Any, R]) -> Callable[[Sequence[Any]], R]: ... with conn.cursor(row_factory=my_row_factory) as cur: cur.execute("select 1") reveal_type(cur) # Revealed type is 'psycopg3.cursor.Cursor[R`-1]' r = cur.fetchone() reveal_type(r) # Revealed type is 'Union[R`-1, None]' The definition of RowMaker and RowFactory protocols needs two distinct type variable because the former is covariant on Row (using 'Row_co' type variable) and the latter is invariant on Row. In Cursor.__init__(), row_factory argument is now required as we remove its default value 'tuple_row'; this is helpful in order to keep Cursor definition generic on Row, which would be more difficult when specifying a concrete RowFactory by default binding Row to Tuple. The Connection is not (yet) generic on Row, so we use RowFactory[Any]. Still, in cursor() methods, we get a fully typed Cursor value when a row_factory argument is passed. We add two overloaded variants of these cursor() methods depending on whether row_factory is passed or not (in the former case, we return a Cursor[Row], in the latter case, a Cursor[Any]). A noticeable improvement is that we no longer need to explicitly declare or ignore types in Transformer's load_row() and load_rows() as this is not correctly inferred. Similarly, type annotations are not needed anymore in callers of these methods (Cursor's fetch*() methods). In TypeInfo's fetch*() method, we can drop superfluous type annotations. --- diff --git a/psycopg3/psycopg3/_column.py b/psycopg3/psycopg3/_column.py index 260c39913..caeddc563 100644 --- a/psycopg3/psycopg3/_column.py +++ b/psycopg3/psycopg3/_column.py @@ -23,7 +23,7 @@ class Column(Sequence[Any]): __module__ = "psycopg3" - def __init__(self, cursor: "BaseCursor[Any]", index: int): + def __init__(self, cursor: "BaseCursor[Any, Any]", index: int): res = cursor.pgresult assert res diff --git a/psycopg3/psycopg3/_transform.py b/psycopg3/psycopg3/_transform.py index 6f15f6d55..ec0119790 100644 --- a/psycopg3/psycopg3/_transform.py +++ b/psycopg3/psycopg3/_transform.py @@ -161,7 +161,9 @@ class Transformer(AdaptContext): dumper = cache[key1] = dumper.upgrade(obj, format) return dumper - def load_rows(self, row0: int, row1: int, make_row: RowMaker) -> List[Row]: + def load_rows( + self, row0: int, row1: int, make_row: RowMaker[Row] + ) -> List[Row]: res = self._pgresult if not res: raise e.InterfaceError("result not set") @@ -171,7 +173,7 @@ class Transformer(AdaptContext): f"rows must be included between 0 and {self._ntuples}" ) - records: List[Row] = [] + records = [] for row in range(row0, row1): record: List[Any] = [None] * self._nfields for col in range(self._nfields): @@ -182,7 +184,7 @@ class Transformer(AdaptContext): return records - def load_row(self, row: int, make_row: RowMaker) -> Optional[Row]: + def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: res = self._pgresult if not res: return None @@ -196,7 +198,7 @@ class Transformer(AdaptContext): if val is not None: record[col] = self._row_loaders[col](val) - return make_row(record) # type: ignore[no-any-return] + return make_row(record) def load_sequence( self, record: Sequence[Optional[bytes]] diff --git a/psycopg3/psycopg3/_typeinfo.py b/psycopg3/psycopg3/_typeinfo.py index d9948e5b2..338b9d57d 100644 --- a/psycopg3/psycopg3/_typeinfo.py +++ b/psycopg3/psycopg3/_typeinfo.py @@ -72,7 +72,7 @@ class TypeInfo: name = name.as_string(conn) cur = conn.cursor(binary=True, row_factory=dict_row) cur.execute(cls._info_query, {"name": name}) - recs: Sequence[Dict[str, Any]] = cur.fetchall() + recs = cur.fetchall() return cls._fetch(name, recs) @classmethod @@ -91,7 +91,7 @@ class TypeInfo: cur = conn.cursor(binary=True, row_factory=dict_row) await cur.execute(cls._info_query, {"name": name}) - recs: Sequence[Dict[str, Any]] = await cur.fetchall() + recs = await cur.fetchall() return cls._fetch(name, recs) @classmethod diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 7f2b5b17a..204c460b9 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -23,7 +23,7 @@ from . import encodings from .pq import ConnStatus, ExecStatus, TransactionStatus, Format from .sql import Composable from .rows import tuple_row -from .proto import PQGen, PQGenConn, RV, RowFactory, Query, Params +from .proto import PQGen, PQGenConn, RV, Row, RowFactory, Query, Params from .proto import AdaptContext, ConnectionType from .cursor import Cursor, AsyncCursor from .conninfo import make_conninfo, ConnectionInfo @@ -98,7 +98,7 @@ class BaseConnection(AdaptContext): ConnStatus = pq.ConnStatus TransactionStatus = pq.TransactionStatus - row_factory: RowFactory = tuple_row + row_factory: RowFactory[Any] = tuple_row def __init__(self, pgconn: "PGconn"): self.pgconn = pgconn # TODO: document this @@ -344,7 +344,7 @@ class BaseConnection(AdaptContext): conninfo: str = "", *, autocommit: bool = False, - row_factory: RowFactory, + row_factory: RowFactory[Any], **kwargs: Any, ) -> PQGenConn[ConnectionType]: """Generator to connect to the database and create a new instance.""" @@ -443,7 +443,7 @@ class Connection(BaseConnection): conninfo: str = "", *, autocommit: bool = False, - row_factory: RowFactory = tuple_row, + row_factory: RowFactory[Any] = tuple_row, **kwargs: Any, ) -> "Connection": """ @@ -496,20 +496,24 @@ class Connection(BaseConnection): self._closed = True self.pgconn.finish() + @overload + def cursor(self, *, binary: bool = False) -> Cursor[Any]: + ... + @overload def cursor( - self, *, binary: bool = False, row_factory: Optional[RowFactory] = None - ) -> Cursor: + self, *, binary: bool = False, row_factory: RowFactory[Row] + ) -> Cursor[Row]: + ... + + @overload + def cursor(self, name: str, *, binary: bool = False) -> ServerCursor[Any]: ... @overload def cursor( - self, - name: str, - *, - binary: bool = False, - row_factory: Optional[RowFactory] = None, - ) -> ServerCursor: + self, name: str, *, binary: bool = False, row_factory: RowFactory[Row] + ) -> ServerCursor[Row]: ... def cursor( @@ -517,8 +521,8 @@ class Connection(BaseConnection): name: str = "", *, binary: bool = False, - row_factory: Optional[RowFactory] = None, - ) -> Union[Cursor, ServerCursor]: + row_factory: Optional[RowFactory[Any]] = None, + ) -> Union[Cursor[Any], ServerCursor[Any]]: """ Return a new cursor to send commands and queries to the connection. """ @@ -538,9 +542,9 @@ class Connection(BaseConnection): params: Optional[Params] = None, *, prepare: Optional[bool] = None, - ) -> Cursor: + ) -> Cursor[Any]: """Execute a query and return a cursor to read its results.""" - cur = self.cursor() + cur: Cursor[Any] = self.cursor() try: return cur.execute(query, params, prepare=prepare) except e.Error as ex: @@ -630,7 +634,7 @@ class AsyncConnection(BaseConnection): conninfo: str = "", *, autocommit: bool = False, - row_factory: RowFactory = tuple_row, + row_factory: RowFactory[Any] = tuple_row, **kwargs: Any, ) -> "AsyncConnection": return await cls._wait_conn( @@ -677,20 +681,26 @@ class AsyncConnection(BaseConnection): self._closed = True self.pgconn.finish() + @overload + def cursor(self, *, binary: bool = False) -> AsyncCursor[Any]: + ... + @overload def cursor( - self, *, binary: bool = False, row_factory: Optional[RowFactory] = None - ) -> AsyncCursor: + self, *, binary: bool = False, row_factory: RowFactory[Row] + ) -> AsyncCursor[Row]: ... @overload def cursor( - self, - name: str, - *, - binary: bool = False, - row_factory: Optional[RowFactory] = None, - ) -> AsyncServerCursor: + self, name: str, *, binary: bool = False + ) -> AsyncServerCursor[Any]: + ... + + @overload + def cursor( + self, name: str, *, binary: bool = False, row_factory: RowFactory[Row] + ) -> AsyncServerCursor[Row]: ... def cursor( @@ -698,8 +708,8 @@ class AsyncConnection(BaseConnection): name: str = "", *, binary: bool = False, - row_factory: Optional[RowFactory] = None, - ) -> Union[AsyncCursor, AsyncServerCursor]: + row_factory: Optional[RowFactory[Any]] = None, + ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]: """ Return a new `AsyncCursor` to send commands and queries to the connection. """ @@ -719,8 +729,8 @@ class AsyncConnection(BaseConnection): params: Optional[Params] = None, *, prepare: Optional[bool] = None, - ) -> AsyncCursor: - cur = self.cursor() + ) -> AsyncCursor[Any]: + cur: AsyncCursor[Any] = self.cursor() try: return await cur.execute(query, params, prepare=prepare) except e.Error as ex: diff --git a/psycopg3/psycopg3/copy.py b/psycopg3/psycopg3/copy.py index 28e4be484..ed9fede9e 100644 --- a/psycopg3/psycopg3/copy.py +++ b/psycopg3/psycopg3/copy.py @@ -52,7 +52,7 @@ class BaseCopy(Generic[ConnectionType]): formatter: "Formatter" - def __init__(self, cursor: "BaseCursor[ConnectionType]"): + def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"): self.cursor = cursor self.connection = cursor.connection self._pgconn = self.connection.pgconn @@ -153,7 +153,7 @@ class Copy(BaseCopy["Connection"]): __module__ = "psycopg3" - def __init__(self, cursor: "Cursor"): + def __init__(self, cursor: "Cursor[Any]"): super().__init__(cursor) self._queue: queue.Queue[Optional[bytes]] = queue.Queue( maxsize=self.QUEUE_SIZE @@ -285,7 +285,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]): __module__ = "psycopg3" - def __init__(self, cursor: "AsyncCursor"): + def __init__(self, cursor: "AsyncCursor[Any]"): super().__init__(cursor) self._queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue( maxsize=self.QUEUE_SIZE diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index e6e981b06..1e200ce68 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -17,7 +17,6 @@ from . import generators from .pq import ExecStatus, Format from .copy import Copy, AsyncCopy -from .rows import tuple_row from .proto import ConnectionType, Query, Params, PQGen from .proto import Row, RowFactory from ._column import Column @@ -42,7 +41,7 @@ else: execute = generators.execute -class BaseCursor(Generic[ConnectionType]): +class BaseCursor(Generic[ConnectionType, Row]): # Slots with __weakref__ and generic bases don't work on Py 3.6 # https://bugs.python.org/issue41451 if sys.version_info >= (3, 7): @@ -61,7 +60,7 @@ class BaseCursor(Generic[ConnectionType]): connection: ConnectionType, *, format: Format = Format.TEXT, - row_factory: RowFactory = tuple_row, + row_factory: RowFactory[Row], ): self._conn = connection self.format = format @@ -174,12 +173,12 @@ class BaseCursor(Generic[ConnectionType]): return None @property - def row_factory(self) -> RowFactory: + def row_factory(self) -> RowFactory[Row]: """Writable attribute to control how result rows are formed.""" return self._row_factory @row_factory.setter - def row_factory(self, row_factory: RowFactory) -> None: + def row_factory(self, row_factory: RowFactory[Row]) -> None: self._row_factory = row_factory if self.pgresult: self._make_row = row_factory(self) @@ -472,11 +471,11 @@ class BaseCursor(Generic[ConnectionType]): self._pgq = pgq -class Cursor(BaseCursor["Connection"]): +class Cursor(BaseCursor["Connection", Row]): __module__ = "psycopg3" __slots__ = () - def __enter__(self) -> "Cursor": + def __enter__(self) -> "Cursor[Row]": return self def __exit__( @@ -499,7 +498,7 @@ class Cursor(BaseCursor["Connection"]): params: Optional[Params] = None, *, prepare: Optional[bool] = None, - ) -> "Cursor": + ) -> "Cursor[Row]": """ Execute a query or command to the database. """ @@ -561,7 +560,7 @@ class Cursor(BaseCursor["Connection"]): if not size: size = self.arraysize - records: List[Row] = self._tx.load_rows( + records = self._tx.load_rows( self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row, @@ -577,7 +576,7 @@ class Cursor(BaseCursor["Connection"]): """ self._check_result() assert self.pgresult - records: List[Row] = self._tx.load_rows( + records = self._tx.load_rows( self._pos, self.pgresult.ntuples, self._make_row ) self._pos = self.pgresult.ntuples @@ -623,11 +622,11 @@ class Cursor(BaseCursor["Connection"]): yield copy -class AsyncCursor(BaseCursor["AsyncConnection"]): +class AsyncCursor(BaseCursor["AsyncConnection", Row]): __module__ = "psycopg3" __slots__ = () - async def __aenter__(self) -> "AsyncCursor": + async def __aenter__(self) -> "AsyncCursor[Row]": return self async def __aexit__( @@ -647,7 +646,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): params: Optional[Params] = None, *, prepare: Optional[bool] = None, - ) -> "AsyncCursor": + ) -> "AsyncCursor[Row]": try: async with self._conn.lock: await self._conn.wait( @@ -688,7 +687,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): if not size: size = self.arraysize - records: List[Row] = self._tx.load_rows( + records = self._tx.load_rows( self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row, @@ -699,7 +698,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): async def fetchall(self) -> List[Row]: self._check_result() assert self.pgresult - records: List[Row] = self._tx.load_rows( + records = self._tx.load_rows( self._pos, self.pgresult.ntuples, self._make_row ) self._pos = self.pgresult.ntuples diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index f06f9cd76..a00f20d7e 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -48,15 +48,16 @@ Wait states. # Row factories Row = TypeVar("Row") +Row_co = TypeVar("Row_co", covariant=True) -class RowMaker(Protocol): - def __call__(self, __values: Sequence[Any]) -> Any: +class RowMaker(Protocol[Row_co]): + def __call__(self, __values: Sequence[Any]) -> Row_co: ... -class RowFactory(Protocol): - def __call__(self, __cursor: "BaseCursor[Any]") -> RowMaker: +class RowFactory(Protocol[Row]): + def __call__(self, __cursor: "BaseCursor[Any, Row]") -> RowMaker[Row]: ... @@ -119,10 +120,12 @@ class Transformer(Protocol): def get_dumper(self, obj: Any, format: Format) -> "Dumper": ... - def load_rows(self, row0: int, row1: int, make_row: RowMaker) -> List[Row]: + def load_rows( + self, row0: int, row1: int, make_row: RowMaker[Row] + ) -> List[Row]: ... - def load_row(self, row: int, make_row: RowMaker) -> Optional[Row]: + def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ... def load_sequence( diff --git a/psycopg3/psycopg3/rows.py b/psycopg3/psycopg3/rows.py index a730c5961..7e4babf1f 100644 --- a/psycopg3/psycopg3/rows.py +++ b/psycopg3/psycopg3/rows.py @@ -16,9 +16,12 @@ if TYPE_CHECKING: from .cursor import BaseCursor +TupleRow = Tuple[Any, ...] + + def tuple_row( - cursor: "BaseCursor[Any]", -) -> Callable[[Sequence[Any]], Tuple[Any, ...]]: + cursor: "BaseCursor[Any, TupleRow]", +) -> Callable[[Sequence[Any]], TupleRow]: """Row factory to represent rows as simple tuples. This is the default factory. @@ -28,9 +31,12 @@ def tuple_row( return tuple +DictRow = Dict[str, Any] + + def dict_row( - cursor: "BaseCursor[Any]", -) -> Callable[[Sequence[Any]], Dict[str, Any]]: + cursor: "BaseCursor[Any, DictRow]", +) -> Callable[[Sequence[Any]], DictRow]: """Row factory to represent rows as dicts. Note that this is not compatible with the DBAPI, which expects the records @@ -48,7 +54,7 @@ def dict_row( def namedtuple_row( - cursor: "BaseCursor[Any]", + cursor: "BaseCursor[Any, NamedTuple]", ) -> Callable[[Sequence[Any]], NamedTuple]: """Row factory to represent rows as `~collections.namedtuple`.""" diff --git a/psycopg3/psycopg3/server_cursor.py b/psycopg3/psycopg3/server_cursor.py index 4aa30f77d..9b4665431 100644 --- a/psycopg3/psycopg3/server_cursor.py +++ b/psycopg3/psycopg3/server_cursor.py @@ -12,7 +12,6 @@ from typing import Sequence, Type, TYPE_CHECKING from . import pq from . import sql from . import errors as e -from .rows import tuple_row from .cursor import BaseCursor, execute from .proto import ConnectionType, Query, Params, PQGen, Row, RowFactory @@ -23,7 +22,7 @@ if TYPE_CHECKING: DEFAULT_ITERSIZE = 100 -class ServerCursorHelper(Generic[ConnectionType]): +class ServerCursorHelper(Generic[ConnectionType, Row]): __slots__ = ("name", "described") """Helper object for common ServerCursor code. @@ -35,7 +34,7 @@ class ServerCursorHelper(Generic[ConnectionType]): self.name = name self.described = False - def _repr(self, cur: BaseCursor[ConnectionType]) -> str: + def _repr(self, cur: BaseCursor[ConnectionType, Row]) -> str: cls = f"{cur.__class__.__module__}.{cur.__class__.__qualname__}" info = pq.misc.connection_summary(cur._conn.pgconn) if cur._closed: @@ -48,7 +47,7 @@ class ServerCursorHelper(Generic[ConnectionType]): def _declare_gen( self, - cur: BaseCursor[ConnectionType], + cur: BaseCursor[ConnectionType, Row], query: Query, params: Optional[Params] = None, ) -> PQGen[None]: @@ -70,7 +69,9 @@ class ServerCursorHelper(Generic[ConnectionType]): # The above result only returned COMMAND_OK. Get the cursor shape yield from self._describe_gen(cur) - def _describe_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]: + def _describe_gen( + self, cur: BaseCursor[ConnectionType, Row] + ) -> PQGen[None]: conn = cur._conn conn.pgconn.send_describe_portal( self.name.encode(conn.client_encoding) @@ -79,7 +80,7 @@ class ServerCursorHelper(Generic[ConnectionType]): cur._execute_results(results) self.described = True - def _close_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]: + def _close_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]: # if the connection is not in a sane state, don't even try if cur._conn.pgconn.transaction_status not in ( pq.TransactionStatus.IDLE, @@ -101,7 +102,7 @@ class ServerCursorHelper(Generic[ConnectionType]): yield from cur._conn._exec_command(query) def _fetch_gen( - self, cur: BaseCursor[ConnectionType], num: Optional[int] + self, cur: BaseCursor[ConnectionType, Row], num: Optional[int] ) -> PQGen[List[Row]]: # If we are stealing the cursor, make sure we know its shape if not self.described: @@ -123,7 +124,7 @@ class ServerCursorHelper(Generic[ConnectionType]): return cur._tx.load_rows(0, res.ntuples, cur._make_row) def _scroll_gen( - self, cur: BaseCursor[ConnectionType], value: int, mode: str + self, cur: BaseCursor[ConnectionType, Row], value: int, mode: str ) -> PQGen[None]: if mode not in ("relative", "absolute"): raise ValueError( @@ -138,7 +139,7 @@ class ServerCursorHelper(Generic[ConnectionType]): def _make_declare_statement( self, - cur: BaseCursor[ConnectionType], + cur: BaseCursor[ConnectionType, Row], query: Query, scrollable: Optional[bool], hold: bool, @@ -164,7 +165,7 @@ class ServerCursorHelper(Generic[ConnectionType]): return sql.SQL(" ").join(parts) -class ServerCursor(BaseCursor["Connection"]): +class ServerCursor(BaseCursor["Connection", Row]): __module__ = "psycopg3" __slots__ = ("_helper", "itersize") @@ -174,10 +175,10 @@ class ServerCursor(BaseCursor["Connection"]): name: str, *, format: pq.Format = pq.Format.TEXT, - row_factory: RowFactory = tuple_row, + row_factory: RowFactory[Row], ): super().__init__(connection, format=format, row_factory=row_factory) - self._helper: ServerCursorHelper["Connection"] + self._helper: ServerCursorHelper["Connection", Row] self._helper = ServerCursorHelper(name) self.itersize: int = DEFAULT_ITERSIZE @@ -192,7 +193,7 @@ class ServerCursor(BaseCursor["Connection"]): def __repr__(self) -> str: return self._helper._repr(self) - def __enter__(self) -> "ServerCursor": + def __enter__(self) -> "ServerCursor[Row]": return self def __exit__( @@ -223,7 +224,7 @@ class ServerCursor(BaseCursor["Connection"]): *, scrollable: Optional[bool] = None, hold: bool = False, - ) -> "ServerCursor": + ) -> "ServerCursor[Row]": """ Open a cursor to execute a query to the database. """ @@ -242,7 +243,7 @@ class ServerCursor(BaseCursor["Connection"]): def fetchone(self) -> Optional[Row]: with self._conn.lock: - recs: List[Row] = self._conn.wait(self._helper._fetch_gen(self, 1)) + recs = self._conn.wait(self._helper._fetch_gen(self, 1)) if recs: self._pos += 1 return recs[0] @@ -253,24 +254,20 @@ class ServerCursor(BaseCursor["Connection"]): if not size: size = self.arraysize with self._conn.lock: - recs: List[Row] = self._conn.wait( - self._helper._fetch_gen(self, size) - ) + recs = self._conn.wait(self._helper._fetch_gen(self, size)) self._pos += len(recs) return recs def fetchall(self) -> Sequence[Row]: with self._conn.lock: - recs: List[Row] = self._conn.wait( - self._helper._fetch_gen(self, None) - ) + recs = self._conn.wait(self._helper._fetch_gen(self, None)) self._pos += len(recs) return recs def __iter__(self) -> Iterator[Row]: while True: with self._conn.lock: - recs: List[Row] = self._conn.wait( + recs = self._conn.wait( self._helper._fetch_gen(self, self.itersize) ) for rec in recs: @@ -289,7 +286,7 @@ class ServerCursor(BaseCursor["Connection"]): self._pos = value -class AsyncServerCursor(BaseCursor["AsyncConnection"]): +class AsyncServerCursor(BaseCursor["AsyncConnection", Row]): __module__ = "psycopg3" __slots__ = ("_helper", "itersize") @@ -299,10 +296,10 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]): name: str, *, format: pq.Format = pq.Format.TEXT, - row_factory: RowFactory = tuple_row, + row_factory: RowFactory[Row], ): super().__init__(connection, format=format, row_factory=row_factory) - self._helper: ServerCursorHelper["AsyncConnection"] + self._helper: ServerCursorHelper["AsyncConnection", Row] self._helper = ServerCursorHelper(name) self.itersize: int = DEFAULT_ITERSIZE @@ -317,7 +314,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]): def __repr__(self) -> str: return self._helper._repr(self) - async def __aenter__(self) -> "AsyncServerCursor": + async def __aenter__(self) -> "AsyncServerCursor[Row]": return self async def __aexit__( @@ -344,7 +341,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]): *, scrollable: Optional[bool] = None, hold: bool = False, - ) -> "AsyncServerCursor": + ) -> "AsyncServerCursor[Row]": query = self._helper._make_declare_statement( self, query, scrollable=scrollable, hold=hold ) @@ -363,9 +360,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]): async def fetchone(self) -> Optional[Row]: async with self._conn.lock: - recs: List[Row] = await self._conn.wait( - self._helper._fetch_gen(self, 1) - ) + recs = await self._conn.wait(self._helper._fetch_gen(self, 1)) if recs: self._pos += 1 return recs[0] @@ -376,24 +371,20 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]): if not size: size = self.arraysize async with self._conn.lock: - recs: List[Row] = await self._conn.wait( - self._helper._fetch_gen(self, size) - ) + recs = await self._conn.wait(self._helper._fetch_gen(self, size)) self._pos += len(recs) return recs async def fetchall(self) -> Sequence[Row]: async with self._conn.lock: - recs: List[Row] = await self._conn.wait( - self._helper._fetch_gen(self, None) - ) + recs = await self._conn.wait(self._helper._fetch_gen(self, None)) self._pos += len(recs) return recs async def __aiter__(self) -> AsyncIterator[Row]: while True: async with self._conn.lock: - recs: List[Row] = await self._conn.wait( + recs = await self._conn.wait( self._helper._fetch_gen(self, self.itersize) ) for rec in recs: diff --git a/psycopg3_c/psycopg3_c/_psycopg3.pyi b/psycopg3_c/psycopg3_c/_psycopg3.pyi index caf380fe9..66e3b4a4c 100644 --- a/psycopg3_c/psycopg3_c/_psycopg3.pyi +++ b/psycopg3_c/psycopg3_c/_psycopg3.pyi @@ -34,10 +34,10 @@ class Transformer(proto.AdaptContext): ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: ... def get_dumper(self, obj: Any, format: Format) -> Dumper: ... def load_rows( - self, row0: int, row1: int, make_row: proto.RowMaker + self, row0: int, row1: int, make_row: proto.RowMaker[proto.Row] ) -> List[proto.Row]: ... def load_row( - self, row: int, make_row: proto.RowMaker + self, row: int, make_row: proto.RowMaker[proto.Row] ) -> Optional[proto.Row]: ... def load_sequence( self, record: Sequence[Optional[bytes]]