]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add typing test for cursor iteration
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 29 Apr 2021 17:22:27 +0000 (19:22 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 29 Apr 2021 17:22:27 +0000 (19:22 +0200)
tests/test_typing.py

index 03c0f366a9149a7d7d45f3c5ce7bc4c0d83247b7..c2df4c05075f53c9ca3123d9b9c559984611a8f5 100644 (file)
@@ -175,6 +175,44 @@ obj = {await_} curs.fetchone()
     _test_reveal(stmts, type, mypy, tmpdir)
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "curs, type",
+    [
+        (
+            "conn.cursor()",
+            "Tuple[Any, ...]",
+        ),
+        (
+            "conn.cursor(row_factory=rows.dict_row)",
+            "Dict[str, Any]",
+        ),
+        (
+            "conn.cursor(row_factory=thing_row)",
+            "Thing",
+        ),
+    ],
+)
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_iter_type(conn_class, server_side, curs, type, mypy, tmpdir):
+    if "Async" in conn_class:
+        async_ = "async "
+        await_ = "await "
+    else:
+        async_ = await_ = ""
+
+    if server_side:
+        curs = curs.replace("(", "(name='foo',", 1)
+    stmts = f"""\
+conn = {await_}psycopg3.{conn_class}.connect()
+curs = {curs}
+{async_}for obj in curs:
+    pass
+"""
+    _test_reveal(stmts, type, mypy, tmpdir)
+
+
 @pytest.mark.slow
 @pytest.mark.parametrize("method", ["fetchmany", "fetchall"])
 @pytest.mark.parametrize(