]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add constraints to the Row type variable
authorDenis Laxalde <denis.laxalde@dalibo.com>
Thu, 11 Feb 2021 11:22:32 +0000 (12:22 +0100)
committerDenis Laxalde <denis.laxalde@dalibo.com>
Thu, 11 Feb 2021 15:55:55 +0000 (16:55 +0100)
We constrain Row to be either a Tuple[Any, ...] (for the case no row
factory is in use) or Any (otherwise).

This avoids the artificial cast() in cursor code and will make changes
in followup commits much easier.

psycopg3/psycopg3/cursor.py
psycopg3/psycopg3/proto.py

index 171318936b4a04c99e835992084c19263b490bab..e4cc116f322c71a0013395ed0e30519e86f90bdf 100644 (file)
@@ -7,7 +7,7 @@ psycopg3 cursor objects
 import sys
 from types import TracebackType
 from typing import Any, AsyncIterator, Callable, Generic, Iterator, List
-from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING, cast
+from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING
 from contextlib import contextmanager
 
 from . import pq
@@ -495,7 +495,7 @@ class Cursor(BaseCursor["Connection"]):
             while self._conn.wait(self._stream_fetchone_gen()):
                 rec = self._tx.load_row(0)
                 assert rec is not None
-                yield self._make_row(rec) if self._make_row else cast(Row, rec)
+                yield self._make_row(rec) if self._make_row else rec
 
     def fetchone(self) -> Optional[Row]:
         """
@@ -507,9 +507,7 @@ class Cursor(BaseCursor["Connection"]):
         record = self._tx.load_row(self._pos)
         if record is not None:
             self._pos += 1
-            return (
-                self._make_row(record) if self._make_row else cast(Row, record)
-            )
+            return self._make_row(record) if self._make_row else record
         return record
 
     def fetchmany(self, size: int = 0) -> Sequence[Row]:
@@ -529,7 +527,7 @@ class Cursor(BaseCursor["Connection"]):
         self._pos += len(records)
         if self._make_row:
             return list(map(self._make_row, records))
-        return cast(Sequence[Row], records)
+        return records
 
     def fetchall(self) -> Sequence[Row]:
         """
@@ -541,7 +539,7 @@ class Cursor(BaseCursor["Connection"]):
         self._pos += self.pgresult.ntuples
         if self._make_row:
             return list(map(self._make_row, records))
-        return cast(Sequence[Row], records)
+        return records
 
     def __iter__(self) -> Iterator[Row]:
         self._check_result()
@@ -553,7 +551,7 @@ class Cursor(BaseCursor["Connection"]):
             if row is None:
                 break
             self._pos += 1
-            yield self._make_row(row) if self._make_row else cast(Row, row)
+            yield self._make_row(row) if self._make_row else row
 
     @contextmanager
     def copy(self, statement: Query) -> Iterator[Copy]:
@@ -612,14 +610,14 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             while await self._conn.wait(self._stream_fetchone_gen()):
                 rec = self._tx.load_row(0)
                 assert rec is not None
-                yield self._make_row(rec) if self._make_row else cast(Row, rec)
+                yield self._make_row(rec) if self._make_row else rec
 
     async def fetchone(self) -> Optional[Row]:
         self._check_result()
         rv = self._tx.load_row(self._pos)
         if rv is not None:
             self._pos += 1
-            return self._make_row(rv) if self._make_row else cast(Row, rv)
+            return self._make_row(rv) if self._make_row else rv
         return rv
 
     async def fetchmany(self, size: int = 0) -> List[Row]:
@@ -634,7 +632,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
         self._pos += len(records)
         if self._make_row:
             return list(map(self._make_row, records))
-        return cast(List[Row], records)
+        return records
 
     async def fetchall(self) -> List[Row]:
         self._check_result()
@@ -643,7 +641,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
         self._pos += self.pgresult.ntuples
         if self._make_row:
             return list(map(self._make_row, records))
-        return cast(List[Row], records)
+        return records
 
     async def __aiter__(self) -> AsyncIterator[Row]:
         self._check_result()
@@ -655,7 +653,7 @@ class AsyncCursor(BaseCursor["AsyncConnection"]):
             if row is None:
                 break
             self._pos += 1
-            yield self._make_row(row) if self._make_row else cast(Row, row)
+            yield self._make_row(row) if self._make_row else row
 
     @asynccontextmanager
     async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]:
index 7edaf9e5c7b4523d4eb4a4ca01b5f9938d73c96b..517c782ad2e84ef2549c441bc38789ad0360ac63 100644 (file)
@@ -120,7 +120,7 @@ class Transformer(Protocol):
 
 # Row factories
 
-Row = TypeVar("Row")
+Row = TypeVar("Row", Tuple[Any, ...], Any)
 
 
 class RowMaker(Protocol):