From: Denis Laxalde Date: Thu, 11 Feb 2021 11:22:32 +0000 (+0100) Subject: Add constraints to the Row type variable X-Git-Tag: 3.0.dev0~106^2~25 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=39287db88368b3ae0370a9049781d3e8ec5d9f41;p=thirdparty%2Fpsycopg.git Add constraints to the Row type variable We constrain Row to be either a Tuple[Any, ...] (for the case no row factory is in use) or Any (otherwise). This avoids the artificial cast() in cursor code and will make changes in followup commits much easier. --- diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 171318936..e4cc116f3 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -7,7 +7,7 @@ psycopg3 cursor objects import sys from types import TracebackType from typing import Any, AsyncIterator, Callable, Generic, Iterator, List -from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING, cast +from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING from contextlib import contextmanager from . import pq @@ -495,7 +495,7 @@ class Cursor(BaseCursor["Connection"]): while self._conn.wait(self._stream_fetchone_gen()): rec = self._tx.load_row(0) assert rec is not None - yield self._make_row(rec) if self._make_row else cast(Row, rec) + yield self._make_row(rec) if self._make_row else rec def fetchone(self) -> Optional[Row]: """ @@ -507,9 +507,7 @@ class Cursor(BaseCursor["Connection"]): record = self._tx.load_row(self._pos) if record is not None: self._pos += 1 - return ( - self._make_row(record) if self._make_row else cast(Row, record) - ) + return self._make_row(record) if self._make_row else record return record def fetchmany(self, size: int = 0) -> Sequence[Row]: @@ -529,7 +527,7 @@ class Cursor(BaseCursor["Connection"]): self._pos += len(records) if self._make_row: return list(map(self._make_row, records)) - return cast(Sequence[Row], records) + return records def fetchall(self) -> Sequence[Row]: """ @@ -541,7 +539,7 @@ class Cursor(BaseCursor["Connection"]): self._pos += self.pgresult.ntuples if self._make_row: return list(map(self._make_row, records)) - return cast(Sequence[Row], records) + return records def __iter__(self) -> Iterator[Row]: self._check_result() @@ -553,7 +551,7 @@ class Cursor(BaseCursor["Connection"]): if row is None: break self._pos += 1 - yield self._make_row(row) if self._make_row else cast(Row, row) + yield self._make_row(row) if self._make_row else row @contextmanager def copy(self, statement: Query) -> Iterator[Copy]: @@ -612,14 +610,14 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): while await self._conn.wait(self._stream_fetchone_gen()): rec = self._tx.load_row(0) assert rec is not None - yield self._make_row(rec) if self._make_row else cast(Row, rec) + yield self._make_row(rec) if self._make_row else rec async def fetchone(self) -> Optional[Row]: self._check_result() rv = self._tx.load_row(self._pos) if rv is not None: self._pos += 1 - return self._make_row(rv) if self._make_row else cast(Row, rv) + return self._make_row(rv) if self._make_row else rv return rv async def fetchmany(self, size: int = 0) -> List[Row]: @@ -634,7 +632,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self._pos += len(records) if self._make_row: return list(map(self._make_row, records)) - return cast(List[Row], records) + return records async def fetchall(self) -> List[Row]: self._check_result() @@ -643,7 +641,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): self._pos += self.pgresult.ntuples if self._make_row: return list(map(self._make_row, records)) - return cast(List[Row], records) + return records async def __aiter__(self) -> AsyncIterator[Row]: self._check_result() @@ -655,7 +653,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]): if row is None: break self._pos += 1 - yield self._make_row(row) if self._make_row else cast(Row, row) + yield self._make_row(row) if self._make_row else row @asynccontextmanager async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index 7edaf9e5c..517c782ad 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -120,7 +120,7 @@ class Transformer(Protocol): # Row factories -Row = TypeVar("Row") +Row = TypeVar("Row", Tuple[Any, ...], Any) class RowMaker(Protocol):