]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix typing of server-side cursors fetchmany/fetchall
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 29 Apr 2021 17:06:35 +0000 (19:06 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 29 Apr 2021 17:14:47 +0000 (19:14 +0200)
psycopg3/psycopg3/server_cursor.py
tests/test_typing.py

index 0723b991a8b62d05f1e9beb0e2d25c4dd2513d52..099563bac64612c89999ea601c0c3327993c84d0 100644 (file)
@@ -251,7 +251,7 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]):
         else:
             return None
 
-    def fetchmany(self, size: int = 0) -> Sequence[Row]:
+    def fetchmany(self, size: int = 0) -> List[Row]:
         if not size:
             size = self.arraysize
         with self._conn.lock:
@@ -259,7 +259,7 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]):
         self._pos += len(recs)
         return recs
 
-    def fetchall(self) -> Sequence[Row]:
+    def fetchall(self) -> List[Row]:
         with self._conn.lock:
             recs = self._conn.wait(self._helper._fetch_gen(self, None))
         self._pos += len(recs)
@@ -368,7 +368,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]):
         else:
             return None
 
-    async def fetchmany(self, size: int = 0) -> Sequence[Row]:
+    async def fetchmany(self, size: int = 0) -> List[Row]:
         if not size:
             size = self.arraysize
         async with self._conn.lock:
@@ -376,7 +376,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]):
         self._pos += len(recs)
         return recs
 
-    async def fetchall(self) -> Sequence[Row]:
+    async def fetchall(self) -> List[Row]:
         async with self._conn.lock:
             recs = await self._conn.wait(self._helper._fetch_gen(self, None))
         self._pos += len(recs)
index 70b7006845cbde774fdca8cdc3502519375c1e56..03c0f366a9149a7d7d45f3c5ce7bc4c0d83247b7 100644 (file)
@@ -175,6 +175,41 @@ obj = {await_} curs.fetchone()
     _test_reveal(stmts, type, mypy, tmpdir)
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize("method", ["fetchmany", "fetchall"])
+@pytest.mark.parametrize(
+    "curs, type",
+    [
+        (
+            "conn.cursor()",
+            "List[Tuple[Any, ...]]",
+        ),
+        (
+            "conn.cursor(row_factory=rows.dict_row)",
+            "List[Dict[str, Any]]",
+        ),
+        (
+            "conn.cursor(row_factory=thing_row)",
+            "List[Thing]",
+        ),
+    ],
+)
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_fetchsome_type(
+    conn_class, server_side, curs, type, method, mypy, tmpdir
+):
+    await_ = "await" if "Async" in conn_class else ""
+    if server_side:
+        curs = curs.replace("(", "(name='foo',", 1)
+    stmts = f"""\
+conn = {await_} psycopg3.{conn_class}.connect()
+curs = {curs}
+obj = {await_} curs.{method}()
+"""
+    _test_reveal(stmts, type, mypy, tmpdir)
+
+
 @pytest.fixture(scope="session")
 def mypy(tmp_path_factory):
     cache_dir = tmp_path_factory.mktemp(basename="mypy_cache")
@@ -200,7 +235,7 @@ def _test_reveal(stmts, type, mypy, tmpdir):
     stmts = "\n".join(f"    {line}" for line in stmts.splitlines())
 
     src = f"""\
-from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple
 import psycopg3
 from psycopg3 import rows