From d7ee27f77b41338c07618f9bae22aba609aa7a71 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 29 Apr 2021 19:06:35 +0200 Subject: [PATCH] Fix typing of server-side cursors fetchmany/fetchall --- psycopg3/psycopg3/server_cursor.py | 8 +++---- tests/test_typing.py | 37 +++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/psycopg3/psycopg3/server_cursor.py b/psycopg3/psycopg3/server_cursor.py index 0723b991a..099563bac 100644 --- a/psycopg3/psycopg3/server_cursor.py +++ b/psycopg3/psycopg3/server_cursor.py @@ -251,7 +251,7 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]): else: return None - def fetchmany(self, size: int = 0) -> Sequence[Row]: + def fetchmany(self, size: int = 0) -> List[Row]: if not size: size = self.arraysize with self._conn.lock: @@ -259,7 +259,7 @@ class ServerCursor(BaseCursor["Connection[Any]", Row]): self._pos += len(recs) return recs - def fetchall(self) -> Sequence[Row]: + def fetchall(self) -> List[Row]: with self._conn.lock: recs = self._conn.wait(self._helper._fetch_gen(self, None)) self._pos += len(recs) @@ -368,7 +368,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]): else: return None - async def fetchmany(self, size: int = 0) -> Sequence[Row]: + async def fetchmany(self, size: int = 0) -> List[Row]: if not size: size = self.arraysize async with self._conn.lock: @@ -376,7 +376,7 @@ class AsyncServerCursor(BaseCursor["AsyncConnection[Any]", Row]): self._pos += len(recs) return recs - async def fetchall(self) -> Sequence[Row]: + async def fetchall(self) -> List[Row]: async with self._conn.lock: recs = await self._conn.wait(self._helper._fetch_gen(self, None)) self._pos += len(recs) diff --git a/tests/test_typing.py b/tests/test_typing.py index 70b700684..03c0f366a 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -175,6 +175,41 @@ obj = {await_} curs.fetchone() _test_reveal(stmts, type, mypy, tmpdir) +@pytest.mark.slow +@pytest.mark.parametrize("method", ["fetchmany", "fetchall"]) +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "List[Tuple[Any, ...]]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "List[Dict[str, Any]]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "List[Thing]", + ), + ], +) +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_fetchsome_type( + conn_class, server_side, curs, type, method, mypy, tmpdir +): + await_ = "await" if "Async" in conn_class else "" + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_} psycopg3.{conn_class}.connect() +curs = {curs} +obj = {await_} curs.{method}() +""" + _test_reveal(stmts, type, mypy, tmpdir) + + @pytest.fixture(scope="session") def mypy(tmp_path_factory): cache_dir = tmp_path_factory.mktemp(basename="mypy_cache") @@ -200,7 +235,7 @@ def _test_reveal(stmts, type, mypy, tmpdir): stmts = "\n".join(f" {line}" for line in stmts.splitlines()) src = f"""\ -from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple import psycopg3 from psycopg3 import rows -- 2.47.2