]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make Cursor generic on Row
authorDenis Laxalde <denis.laxalde@dalibo.com>
Thu, 15 Apr 2021 15:49:04 +0000 (17:49 +0200)
committerDenis Laxalde <denis.laxalde@dalibo.com>
Mon, 26 Apr 2021 13:56:10 +0000 (15:56 +0200)
We make RowMaker, RowFactory and Cursor types generic on a Row type
variable thus making type inference work on cursor's fetch*() methods.

For example:

    R = TypeVar("R")
    def my_row_factory(cursor: BaseCursor[Any, R]) -> Callable[[Sequence[Any]], R]:
        ...

    with conn.cursor(row_factory=my_row_factory) as cur:
        cur.execute("select 1")
        reveal_type(cur)
        # Revealed type is 'psycopg3.cursor.Cursor[R`-1]'
        r = cur.fetchone()
        reveal_type(r)
        # Revealed type is 'Union[R`-1, None]'

The definition of RowMaker and RowFactory protocols needs two distinct
type variable because the former is covariant on Row (using 'Row_co'
type variable) and the latter is invariant on Row.

In Cursor.__init__(), row_factory argument is now required as we remove
its default value 'tuple_row'; this is helpful in order to keep Cursor
definition generic on Row, which would be more difficult when specifying
a concrete RowFactory by default binding Row to Tuple.

The Connection is not (yet) generic on Row, so we use RowFactory[Any].
Still, in cursor() methods, we get a fully typed Cursor value when a
row_factory argument is passed. We add two overloaded variants of these
cursor() methods depending on whether row_factory is passed or not (in
the former case, we return a Cursor[Row], in the latter case, a
Cursor[Any]).

A noticeable improvement is that we no longer need to explicitly declare
or ignore types in Transformer's load_row() and load_rows() as this is
not correctly inferred. Similarly, type annotations are not needed
anymore in callers of these methods (Cursor's fetch*() methods).
In TypeInfo's fetch*() method, we can drop superfluous type annotations.

psycopg3/psycopg3/_column.py
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/_typeinfo.py
psycopg3/psycopg3/connection.py
psycopg3/psycopg3/copy.py
psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/proto.py
psycopg3/psycopg3/rows.py
psycopg3/psycopg3/server_cursor.py
psycopg3_c/psycopg3_c/_psycopg3.pyi

index 260c399136465c5a5f3c5d2134b64ef4ea868dea..caeddc5636ac5adb1fa39161d1698f8f0ea6f054 100644 (file)
@@ -23,7 +23,7 @@ class Column(Sequence[Any]):
 
     __module__ = "psycopg3"
 
-    def __init__(self, cursor: "BaseCursor[Any]", index: int):
+    def __init__(self, cursor: "BaseCursor[Any, Any]", index: int):
         res = cursor.pgresult
         assert res
 
index 6f15f6d557d6f485c84580ddd723efa25a635076..ec0119790f673b616acad4d154f2ee3bf58857c8 100644 (file)
@@ -161,7 +161,9 @@ class Transformer(AdaptContext):
             dumper = cache[key1] = dumper.upgrade(obj, format)
             return dumper
 
-    def load_rows(self, row0: int, row1: int, make_row: RowMaker) -> List[Row]:
+    def load_rows(
+        self, row0: int, row1: int, make_row: RowMaker[Row]
+    ) -> List[Row]:
         res = self._pgresult
         if not res:
             raise e.InterfaceError("result not set")
@@ -171,7 +173,7 @@ class Transformer(AdaptContext):
                 f"rows must be included between 0 and {self._ntuples}"
             )
 
-        records: List[Row] = []
+        records = []
         for row in range(row0, row1):
             record: List[Any] = [None] * self._nfields
             for col in range(self._nfields):
@@ -182,7 +184,7 @@ class Transformer(AdaptContext):
 
         return records
 
-    def load_row(self, row: int, make_row: RowMaker) -> Optional[Row]:
+    def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]:
         res = self._pgresult
         if not res:
             return None
@@ -196,7 +198,7 @@ class Transformer(AdaptContext):
             if val is not None:
                 record[col] = self._row_loaders[col](val)
 
-        return make_row(record)  # type: ignore[no-any-return]
+        return make_row(record)
 
     def load_sequence(
         self, record: Sequence[Optional[bytes]]
index d9948e5b2fcf186ac7f9c40d23df7423020b31b3..338b9d57df94aa84317e7cc3f5938e3e13859a6d 100644 (file)
@@ -72,7 +72,7 @@ class TypeInfo:
             name = name.as_string(conn)
         cur = conn.cursor(binary=True, row_factory=dict_row)
         cur.execute(cls._info_query, {"name": name})
-        recs: Sequence[Dict[str, Any]] = cur.fetchall()
+        recs = cur.fetchall()
         return cls._fetch(name, recs)
 
     @classmethod
@@ -91,7 +91,7 @@ class TypeInfo:
 
         cur = conn.cursor(binary=True, row_factory=dict_row)
         await cur.execute(cls._info_query, {"name": name})
-        recs: Sequence[Dict[str, Any]] = await cur.fetchall()
+        recs = await cur.fetchall()
         return cls._fetch(name, recs)
 
     @classmethod
index 7f2b5b17acf27d165851948c11b1261b98632789..204c460b927ff5f4929e8be7190a20d97b6b7548 100644 (file)
@@ -23,7 +23,7 @@ from . import encodings
 from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
 from .sql import Composable
 from .rows import tuple_row
-from .proto import PQGen, PQGenConn, RV, RowFactory, Query, Params
+from .proto import PQGen, PQGenConn, RV, Row, RowFactory, Query, Params
 from .proto import AdaptContext, ConnectionType
 from .cursor import Cursor, AsyncCursor
 from .conninfo import make_conninfo, ConnectionInfo
@@ -98,7 +98,7 @@ class BaseConnection(AdaptContext):
     ConnStatus = pq.ConnStatus
     TransactionStatus = pq.TransactionStatus
 
-    row_factory: RowFactory = tuple_row
+    row_factory: RowFactory[Any] = tuple_row
 
     def __init__(self, pgconn: "PGconn"):
         self.pgconn = pgconn  # TODO: document this
@@ -344,7 +344,7 @@ class BaseConnection(AdaptContext):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: RowFactory,
+        row_factory: RowFactory[Any],
         **kwargs: Any,
     ) -> PQGenConn[ConnectionType]:
         """Generator to connect to the database and create a new instance."""
@@ -443,7 +443,7 @@ class Connection(BaseConnection):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: RowFactory = tuple_row,
+        row_factory: RowFactory[Any] = tuple_row,
         **kwargs: Any,
     ) -> "Connection":
         """
@@ -496,20 +496,24 @@ class Connection(BaseConnection):
         self._closed = True
         self.pgconn.finish()
 
+    @overload
+    def cursor(self, *, binary: bool = False) -> Cursor[Any]:
+        ...
+
     @overload
     def cursor(
-        self, *, binary: bool = False, row_factory: Optional[RowFactory] = None
-    ) -> Cursor:
+        self, *, binary: bool = False, row_factory: RowFactory[Row]
+    ) -> Cursor[Row]:
+        ...
+
+    @overload
+    def cursor(self, name: str, *, binary: bool = False) -> ServerCursor[Any]:
         ...
 
     @overload
     def cursor(
-        self,
-        name: str,
-        *,
-        binary: bool = False,
-        row_factory: Optional[RowFactory] = None,
-    ) -> ServerCursor:
+        self, name: str, *, binary: bool = False, row_factory: RowFactory[Row]
+    ) -> ServerCursor[Row]:
         ...
 
     def cursor(
@@ -517,8 +521,8 @@ class Connection(BaseConnection):
         name: str = "",
         *,
         binary: bool = False,
-        row_factory: Optional[RowFactory] = None,
-    ) -> Union[Cursor, ServerCursor]:
+        row_factory: Optional[RowFactory[Any]] = None,
+    ) -> Union[Cursor[Any], ServerCursor[Any]]:
         """
         Return a new cursor to send commands and queries to the connection.
         """
@@ -538,9 +542,9 @@ class Connection(BaseConnection):
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-    ) -> Cursor:
+    ) -> Cursor[Any]:
         """Execute a query and return a cursor to read its results."""
-        cur = self.cursor()
+        cur: Cursor[Any] = self.cursor()
         try:
             return cur.execute(query, params, prepare=prepare)
         except e.Error as ex:
@@ -630,7 +634,7 @@ class AsyncConnection(BaseConnection):
         conninfo: str = "",
         *,
         autocommit: bool = False,
-        row_factory: RowFactory = tuple_row,
+        row_factory: RowFactory[Any] = tuple_row,
         **kwargs: Any,
     ) -> "AsyncConnection":
         return await cls._wait_conn(
@@ -677,20 +681,26 @@ class AsyncConnection(BaseConnection):
         self._closed = True
         self.pgconn.finish()
 
+    @overload
+    def cursor(self, *, binary: bool = False) -> AsyncCursor[Any]:
+        ...
+
     @overload
     def cursor(
-        self, *, binary: bool = False, row_factory: Optional[RowFactory] = None
-    ) -> AsyncCursor:
+        self, *, binary: bool = False, row_factory: RowFactory[Row]
+    ) -> AsyncCursor[Row]:
         ...
 
     @overload
     def cursor(
-        self,
-        name: str,
-        *,
-        binary: bool = False,
-        row_factory: Optional[RowFactory] = None,
-    ) -> AsyncServerCursor:
+        self, name: str, *, binary: bool = False
+    ) -> AsyncServerCursor[Any]:
+        ...
+
+    @overload
+    def cursor(
+        self, name: str, *, binary: bool = False, row_factory: RowFactory[Row]
+    ) -> AsyncServerCursor[Row]:
         ...
 
     def cursor(
@@ -698,8 +708,8 @@ class AsyncConnection(BaseConnection):
         name: str = "",
         *,
         binary: bool = False,
-        row_factory: Optional[RowFactory] = None,
-    ) -> Union[AsyncCursor, AsyncServerCursor]:
+        row_factory: Optional[RowFactory[Any]] = None,
+    ) -> Union[AsyncCursor[Any], AsyncServerCursor[Any]]:
         """
         Return a new `AsyncCursor` to send commands and queries to the connection.
         """
@@ -719,8 +729,8 @@ class AsyncConnection(BaseConnection):
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-    ) -> AsyncCursor:
-        cur = self.cursor()
+    ) -> AsyncCursor[Any]:
+        cur: AsyncCursor[Any] = self.cursor()
         try:
             return await cur.execute(query, params, prepare=prepare)
         except e.Error as ex:
index 28e4be484af1f0a964bc51901f8c7d63eb98328c..ed9fede9eab8dbe66670a8cb53d2b58593ee1e05 100644 (file)
@@ -52,7 +52,7 @@ class BaseCopy(Generic[ConnectionType]):
 
     formatter: "Formatter"
 
-    def __init__(self, cursor: "BaseCursor[ConnectionType]"):
+    def __init__(self, cursor: "BaseCursor[ConnectionType, Any]"):
         self.cursor = cursor
         self.connection = cursor.connection
         self._pgconn = self.connection.pgconn
@@ -153,7 +153,7 @@ class Copy(BaseCopy["Connection"]):
 
     __module__ = "psycopg3"
 
-    def __init__(self, cursor: "Cursor"):
+    def __init__(self, cursor: "Cursor[Any]"):
         super().__init__(cursor)
         self._queue: queue.Queue[Optional[bytes]] = queue.Queue(
             maxsize=self.QUEUE_SIZE
@@ -285,7 +285,7 @@ class AsyncCopy(BaseCopy["AsyncConnection"]):
 
     __module__ = "psycopg3"
 
-    def __init__(self, cursor: "AsyncCursor"):
+    def __init__(self, cursor: "AsyncCursor[Any]"):
         super().__init__(cursor)
         self._queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(
             maxsize=self.QUEUE_SIZE
index e6e981b06d7b4aa018b6bb6559ac6a4cc8a9b8c6..1e200ce68ef742cf5a6707969340f152258b44fa 100644 (file)
@@ -17,7 +17,6 @@ from . import generators
 
 from .pq import ExecStatus, Format
 from .copy import Copy, AsyncCopy
-from .rows import tuple_row
 from .proto import ConnectionType, Query, Params, PQGen
 from .proto import Row, RowFactory
 from ._column import Column
@@ -42,7 +41,7 @@ else:
     execute = generators.execute
 
 
-class BaseCursor(Generic[ConnectionType]):
+class BaseCursor(Generic[ConnectionType, Row]):
     # Slots with __weakref__ and generic bases don't work on Py 3.6
     # https://bugs.python.org/issue41451
     if sys.version_info >= (3, 7):
@@ -61,7 +60,7 @@ class BaseCursor(Generic[ConnectionType]):
         connection: ConnectionType,
         *,
         format: Format = Format.TEXT,
-        row_factory: RowFactory = tuple_row,
+        row_factory: RowFactory[Row],
     ):
         self._conn = connection
         self.format = format
@@ -174,12 +173,12 @@ class BaseCursor(Generic[ConnectionType]):
             return None
 
     @property
-    def row_factory(self) -> RowFactory:
+    def row_factory(self) -> RowFactory[Row]:
         """Writable attribute to control how result rows are formed."""
         return self._row_factory
 
     @row_factory.setter
-    def row_factory(self, row_factory: RowFactory) -> None:
+    def row_factory(self, row_factory: RowFactory[Row]) -> None:
         self._row_factory = row_factory
         if self.pgresult:
             self._make_row = row_factory(self)
@@ -472,11 +471,11 @@ class BaseCursor(Generic[ConnectionType]):
         self._pgq = pgq
 
 
-class Cursor(BaseCursor["Connection"]):
+class Cursor(BaseCursor["Connection", Row]):
     __module__ = "psycopg3"
     __slots__ = ()
 
-    def __enter__(self) -> "Cursor":
+    def __enter__(self) -> "Cursor[Row]":
         return self
 
     def __exit__(
@@ -499,7 +498,7 @@ class Cursor(BaseCursor["Connection"]):
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-    ) -> "Cursor":
+    ) -> "Cursor[Row]":
         """
         Execute a query or command to the database.
         """
@@ -561,7 +560,7 @@ class Cursor(BaseCursor["Connection"]):
 
         if not size:
             size = self.arraysize
-        records: List[Row] = self._tx.load_rows(
+        records = self._tx.load_rows(
             self._pos,
             min(self._pos + size, self.pgresult.ntuples),
             self._make_row,
@@ -577,7 +576,7 @@ class Cursor(BaseCursor["Connection"]):
         """
         self._check_result()
         assert self.pgresult
-        records: List[Row] = self._tx.load_rows(
+        records = self._tx.load_rows(
             self._pos, self.pgresult.ntuples, self._make_row
         )
         self._pos = self.pgresult.ntuples
@@ -623,11 +622,11 @@ class Cursor(BaseCursor["Connection"]):
             yield copy
 
 
-class AsyncCursor(BaseCursor["AsyncConnection"]):
+class AsyncCursor(BaseCursor["AsyncConnection", Row]):
     __module__ = "psycopg3"
     __slots__ = ()
 
-    async def __aenter__(self) -> "AsyncCursor":
+    async def __aenter__(self) -> "AsyncCursor[Row]":
         return self
 
     async def __aexit__(
@@ -647,7 +646,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
-    ) -> "AsyncCursor":
+    ) -> "AsyncCursor[Row]":
         try:
             async with self._conn.lock:
                 await self._conn.wait(
@@ -688,7 +687,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
 
         if not size:
             size = self.arraysize
-        records: List[Row] = self._tx.load_rows(
+        records = self._tx.load_rows(
             self._pos,
             min(self._pos + size, self.pgresult.ntuples),
             self._make_row,
@@ -699,7 +698,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
     async def fetchall(self) -> List[Row]:
         self._check_result()
         assert self.pgresult
-        records: List[Row] = self._tx.load_rows(
+        records = self._tx.load_rows(
             self._pos, self.pgresult.ntuples, self._make_row
         )
         self._pos = self.pgresult.ntuples
index f06f9cd7690f3ea5d37c3e319fa31e10db890e59..a00f20d7ea58cf68159a73e0551ceedd7ae2a2df 100644 (file)
@@ -48,15 +48,16 @@ Wait states.
 # Row factories
 
 Row = TypeVar("Row")
+Row_co = TypeVar("Row_co", covariant=True)
 
 
-class RowMaker(Protocol):
-    def __call__(self, __values: Sequence[Any]) -> Any:
+class RowMaker(Protocol[Row_co]):
+    def __call__(self, __values: Sequence[Any]) -> Row_co:
         ...
 
 
-class RowFactory(Protocol):
-    def __call__(self, __cursor: "BaseCursor[Any]") -> RowMaker:
+class RowFactory(Protocol[Row]):
+    def __call__(self, __cursor: "BaseCursor[Any, Row]") -> RowMaker[Row]:
         ...
 
 
@@ -119,10 +120,12 @@ class Transformer(Protocol):
     def get_dumper(self, obj: Any, format: Format) -> "Dumper":
         ...
 
-    def load_rows(self, row0: int, row1: int, make_row: RowMaker) -> List[Row]:
+    def load_rows(
+        self, row0: int, row1: int, make_row: RowMaker[Row]
+    ) -> List[Row]:
         ...
 
-    def load_row(self, row: int, make_row: RowMaker) -> Optional[Row]:
+    def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]:
         ...
 
     def load_sequence(
index a730c5961aa847c0b9bec81c9f40a50735f3ecc7..7e4babf1fc6b79d2d5f9d5a37aa18dd7194ed5ea 100644 (file)
@@ -16,9 +16,12 @@ if TYPE_CHECKING:
     from .cursor import BaseCursor
 
 
+TupleRow = Tuple[Any, ...]
+
+
 def tuple_row(
-    cursor: "BaseCursor[Any]",
-) -> Callable[[Sequence[Any]], Tuple[Any, ...]]:
+    cursor: "BaseCursor[Any, TupleRow]",
+) -> Callable[[Sequence[Any]], TupleRow]:
     """Row factory to represent rows as simple tuples.
 
     This is the default factory.
@@ -28,9 +31,12 @@ def tuple_row(
     return tuple
 
 
+DictRow = Dict[str, Any]
+
+
 def dict_row(
-    cursor: "BaseCursor[Any]",
-) -> Callable[[Sequence[Any]], Dict[str, Any]]:
+    cursor: "BaseCursor[Any, DictRow]",
+) -> Callable[[Sequence[Any]], DictRow]:
     """Row factory to represent rows as dicts.
 
     Note that this is not compatible with the DBAPI, which expects the records
@@ -48,7 +54,7 @@ def dict_row(
 
 
 def namedtuple_row(
-    cursor: "BaseCursor[Any]",
+    cursor: "BaseCursor[Any, NamedTuple]",
 ) -> Callable[[Sequence[Any]], NamedTuple]:
     """Row factory to represent rows as `~collections.namedtuple`."""
 
index 4aa30f77d485aa4516928daaa40e06f2a908b3b4..9b46654312c67173270e19e75c41af48bcc4505b 100644 (file)
@@ -12,7 +12,6 @@ from typing import Sequence, Type, TYPE_CHECKING
 from . import pq
 from . import sql
 from . import errors as e
-from .rows import tuple_row
 from .cursor import BaseCursor, execute
 from .proto import ConnectionType, Query, Params, PQGen, Row, RowFactory
 
@@ -23,7 +22,7 @@ if TYPE_CHECKING:
 DEFAULT_ITERSIZE = 100
 
 
-class ServerCursorHelper(Generic[ConnectionType]):
+class ServerCursorHelper(Generic[ConnectionType, Row]):
     __slots__ = ("name", "described")
     """Helper object for common ServerCursor code.
 
@@ -35,7 +34,7 @@ class ServerCursorHelper(Generic[ConnectionType]):
         self.name = name
         self.described = False
 
-    def _repr(self, cur: BaseCursor[ConnectionType]) -> str:
+    def _repr(self, cur: BaseCursor[ConnectionType, Row]) -> str:
         cls = f"{cur.__class__.__module__}.{cur.__class__.__qualname__}"
         info = pq.misc.connection_summary(cur._conn.pgconn)
         if cur._closed:
@@ -48,7 +47,7 @@ class ServerCursorHelper(Generic[ConnectionType]):
 
     def _declare_gen(
         self,
-        cur: BaseCursor[ConnectionType],
+        cur: BaseCursor[ConnectionType, Row],
         query: Query,
         params: Optional[Params] = None,
     ) -> PQGen[None]:
@@ -70,7 +69,9 @@ class ServerCursorHelper(Generic[ConnectionType]):
         # The above result only returned COMMAND_OK. Get the cursor shape
         yield from self._describe_gen(cur)
 
-    def _describe_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
+    def _describe_gen(
+        self, cur: BaseCursor[ConnectionType, Row]
+    ) -> PQGen[None]:
         conn = cur._conn
         conn.pgconn.send_describe_portal(
             self.name.encode(conn.client_encoding)
@@ -79,7 +80,7 @@ class ServerCursorHelper(Generic[ConnectionType]):
         cur._execute_results(results)
         self.described = True
 
-    def _close_gen(self, cur: BaseCursor[ConnectionType]) -> PQGen[None]:
+    def _close_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]:
         # if the connection is not in a sane state, don't even try
         if cur._conn.pgconn.transaction_status not in (
             pq.TransactionStatus.IDLE,
@@ -101,7 +102,7 @@ class ServerCursorHelper(Generic[ConnectionType]):
         yield from cur._conn._exec_command(query)
 
     def _fetch_gen(
-        self, cur: BaseCursor[ConnectionType], num: Optional[int]
+        self, cur: BaseCursor[ConnectionType, Row], num: Optional[int]
     ) -> PQGen[List[Row]]:
         # If we are stealing the cursor, make sure we know its shape
         if not self.described:
@@ -123,7 +124,7 @@ class ServerCursorHelper(Generic[ConnectionType]):
         return cur._tx.load_rows(0, res.ntuples, cur._make_row)
 
     def _scroll_gen(
-        self, cur: BaseCursor[ConnectionType], value: int, mode: str
+        self, cur: BaseCursor[ConnectionType, Row], value: int, mode: str
     ) -> PQGen[None]:
         if mode not in ("relative", "absolute"):
             raise ValueError(
@@ -138,7 +139,7 @@ class ServerCursorHelper(Generic[ConnectionType]):
 
     def _make_declare_statement(
         self,
-        cur: BaseCursor[ConnectionType],
+        cur: BaseCursor[ConnectionType, Row],
         query: Query,
         scrollable: Optional[bool],
         hold: bool,
@@ -164,7 +165,7 @@ class ServerCursorHelper(Generic[ConnectionType]):
         return sql.SQL(" ").join(parts)
 
 
-class ServerCursor(BaseCursor["Connection"]):
+class ServerCursor(BaseCursor["Connection", Row]):
     __module__ = "psycopg3"
     __slots__ = ("_helper", "itersize")
 
@@ -174,10 +175,10 @@ class ServerCursor(BaseCursor["Connection"]):
         name: str,
         *,
         format: pq.Format = pq.Format.TEXT,
-        row_factory: RowFactory = tuple_row,
+        row_factory: RowFactory[Row],
     ):
         super().__init__(connection, format=format, row_factory=row_factory)
-        self._helper: ServerCursorHelper["Connection"]
+        self._helper: ServerCursorHelper["Connection", Row]
         self._helper = ServerCursorHelper(name)
         self.itersize: int = DEFAULT_ITERSIZE
 
@@ -192,7 +193,7 @@ class ServerCursor(BaseCursor["Connection"]):
     def __repr__(self) -> str:
         return self._helper._repr(self)
 
-    def __enter__(self) -> "ServerCursor":
+    def __enter__(self) -> "ServerCursor[Row]":
         return self
 
     def __exit__(
@@ -223,7 +224,7 @@ class ServerCursor(BaseCursor["Connection"]):
         *,
         scrollable: Optional[bool] = None,
         hold: bool = False,
-    ) -> "ServerCursor":
+    ) -> "ServerCursor[Row]":
         """
         Open a cursor to execute a query to the database.
         """
@@ -242,7 +243,7 @@ class ServerCursor(BaseCursor["Connection"]):
 
     def fetchone(self) -> Optional[Row]:
         with self._conn.lock:
-            recs: List[Row] = self._conn.wait(self._helper._fetch_gen(self, 1))
+            recs = self._conn.wait(self._helper._fetch_gen(self, 1))
         if recs:
             self._pos += 1
             return recs[0]
@@ -253,24 +254,20 @@ class ServerCursor(BaseCursor["Connection"]):
         if not size:
             size = self.arraysize
         with self._conn.lock:
-            recs: List[Row] = self._conn.wait(
-                self._helper._fetch_gen(self, size)
-            )
+            recs = self._conn.wait(self._helper._fetch_gen(self, size))
         self._pos += len(recs)
         return recs
 
     def fetchall(self) -> Sequence[Row]:
         with self._conn.lock:
-            recs: List[Row] = self._conn.wait(
-                self._helper._fetch_gen(self, None)
-            )
+            recs = self._conn.wait(self._helper._fetch_gen(self, None))
         self._pos += len(recs)
         return recs
 
     def __iter__(self) -> Iterator[Row]:
         while True:
             with self._conn.lock:
-                recs: List[Row] = self._conn.wait(
+                recs = self._conn.wait(
                     self._helper._fetch_gen(self, self.itersize)
                 )
             for rec in recs:
@@ -289,7 +286,7 @@ class ServerCursor(BaseCursor["Connection"]):
             self._pos = value
 
 
-class AsyncServerCursor(BaseCursor["AsyncConnection"]):
+class AsyncServerCursor(BaseCursor["AsyncConnection", Row]):
     __module__ = "psycopg3"
     __slots__ = ("_helper", "itersize")
 
@@ -299,10 +296,10 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]):
         name: str,
         *,
         format: pq.Format = pq.Format.TEXT,
-        row_factory: RowFactory = tuple_row,
+        row_factory: RowFactory[Row],
     ):
         super().__init__(connection, format=format, row_factory=row_factory)
-        self._helper: ServerCursorHelper["AsyncConnection"]
+        self._helper: ServerCursorHelper["AsyncConnection", Row]
         self._helper = ServerCursorHelper(name)
         self.itersize: int = DEFAULT_ITERSIZE
 
@@ -317,7 +314,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]):
     def __repr__(self) -> str:
         return self._helper._repr(self)
 
-    async def __aenter__(self) -> "AsyncServerCursor":
+    async def __aenter__(self) -> "AsyncServerCursor[Row]":
         return self
 
     async def __aexit__(
@@ -344,7 +341,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]):
         *,
         scrollable: Optional[bool] = None,
         hold: bool = False,
-    ) -> "AsyncServerCursor":
+    ) -> "AsyncServerCursor[Row]":
         query = self._helper._make_declare_statement(
             self, query, scrollable=scrollable, hold=hold
         )
@@ -363,9 +360,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]):
 
     async def fetchone(self) -> Optional[Row]:
         async with self._conn.lock:
-            recs: List[Row] = await self._conn.wait(
-                self._helper._fetch_gen(self, 1)
-            )
+            recs = await self._conn.wait(self._helper._fetch_gen(self, 1))
         if recs:
             self._pos += 1
             return recs[0]
@@ -376,24 +371,20 @@ class AsyncServerCursor(BaseCursor["AsyncConnection"]):
         if not size:
             size = self.arraysize
         async with self._conn.lock:
-            recs: List[Row] = await self._conn.wait(
-                self._helper._fetch_gen(self, size)
-            )
+            recs = await self._conn.wait(self._helper._fetch_gen(self, size))
         self._pos += len(recs)
         return recs
 
     async def fetchall(self) -> Sequence[Row]:
         async with self._conn.lock:
-            recs: List[Row] = await self._conn.wait(
-                self._helper._fetch_gen(self, None)
-            )
+            recs = await self._conn.wait(self._helper._fetch_gen(self, None))
         self._pos += len(recs)
         return recs
 
     async def __aiter__(self) -> AsyncIterator[Row]:
         while True:
             async with self._conn.lock:
-                recs: List[Row] = await self._conn.wait(
+                recs = await self._conn.wait(
                     self._helper._fetch_gen(self, self.itersize)
                 )
             for rec in recs:
index caf380fe9d5d3cbd59446e1f76a64d19d50c281a..66e3b4a4c70354ca7e3a25f95080052ec8b2cd33 100644 (file)
@@ -34,10 +34,10 @@ class Transformer(proto.AdaptContext):
     ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: ...
     def get_dumper(self, obj: Any, format: Format) -> Dumper: ...
     def load_rows(
-        self, row0: int, row1: int, make_row: proto.RowMaker
+        self, row0: int, row1: int, make_row: proto.RowMaker[proto.Row]
     ) -> List[proto.Row]: ...
     def load_row(
-        self, row: int, make_row: proto.RowMaker
+        self, row: int, make_row: proto.RowMaker[proto.Row]
     ) -> Optional[proto.Row]: ...
     def load_sequence(
         self, record: Sequence[Optional[bytes]]