]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(test): better parametrization of fetch type tests
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 3 May 2025 01:31:27 +0000 (03:31 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 4 May 2025 17:34:42 +0000 (19:34 +0200)
tests/test_typing.py

index 2ad397189fdb988eb7884b9ed25df59bc5f9f744..0537eb82db5746e5d6bbeff0af9033fdf47f2cd5 100644 (file)
@@ -231,103 +231,52 @@ obj = {curs}
 
 
 @pytest.mark.parametrize(
-    "curs, type",
+    "factory, type",
     [
-        (
-            "conn.cursor()",
-            "Tuple[Any, ...] | None",
-        ),
-        (
-            "conn.cursor(row_factory=rows.dict_row)",
-            "Dict[str, Any] | None",
-        ),
-        (
-            "conn.cursor(row_factory=thing_row)",
-            "Thing | None",
-        ),
+        ("", "Tuple[Any, ...]"),
+        ("row_factory=rows.tuple_row", "Tuple[Any, ...]"),
+        ("row_factory=rows.dict_row", "Dict[str, Any]"),
+        ("row_factory=thing_row", "Thing"),
     ],
 )
 @pytest.mark.parametrize("server_side", [False, True])
 @pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_fetchone_type(conn_class, server_side, curs, type, mypy):
-    await_ = "await" if "Async" in conn_class else ""
-    if server_side:
-        curs = curs.replace("(", "(name='foo',", 1)
-    stmts = f"""\
-conn = {await_} psycopg.{conn_class}.connect()
-curs = {curs}
-obj = {await_} curs.fetchone()
-"""
-    _test_reveal(stmts, type, mypy)
-
-
 @pytest.mark.parametrize(
-    "curs, type",
+    "fetch, typemod",
     [
-        (
-            "conn.cursor()",
-            "Tuple[Any, ...]",
-        ),
-        (
-            "conn.cursor(row_factory=rows.dict_row)",
-            "Dict[str, Any]",
-        ),
-        (
-            "conn.cursor(row_factory=thing_row)",
-            "Thing",
-        ),
+        ("one", "{type} | None"),
+        ("many", "list[{type}]"),
+        ("all", "list[{type}]"),
+        ("iter", "{type}"),
     ],
 )
-@pytest.mark.parametrize("server_side", [False, True])
-@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_iter_type(conn_class, server_side, curs, type, mypy):
+def test_fetch_type(conn_class, server_side, factory, type, fetch, typemod, mypy):
     if "Async" in conn_class:
         async_ = "async "
         await_ = "await "
     else:
         async_ = await_ = ""
 
+    curs = f"conn.cursor({factory})"
     if server_side:
         curs = curs.replace("(", "(name='foo',", 1)
-    stmts = f"""\
-conn = {await_}psycopg.{conn_class}.connect()
-curs = {curs}
-{async_}for obj in curs:
-    pass
-"""
-    _test_reveal(stmts, type, mypy)
-
 
-@pytest.mark.parametrize("method", ["fetchmany", "fetchall"])
-@pytest.mark.parametrize(
-    "curs, type",
-    [
-        (
-            "conn.cursor()",
-            "list[Tuple[Any, ...]]",
-        ),
-        (
-            "conn.cursor(row_factory=rows.dict_row)",
-            "list[Dict[str, Any]]",
-        ),
-        (
-            "conn.cursor(row_factory=thing_row)",
-            "list[Thing]",
-        ),
-    ],
-)
-@pytest.mark.parametrize("server_side", [False, True])
-@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_fetchsome_type(conn_class, server_side, curs, type, method, mypy):
-    await_ = "await" if "Async" in conn_class else ""
-    if server_side:
-        curs = curs.replace("(", "(name='foo',", 1)
     stmts = f"""\
 conn = {await_} psycopg.{conn_class}.connect()
 curs = {curs}
-obj = {await_} curs.{method}()
 """
-    _test_reveal(stmts, type, mypy)
+    if fetch == "one":
+        stmts += f"obj = {await_} curs.fetchone()"
+    elif fetch == "many":
+        stmts += f"obj = {await_} curs.fetchmany(5)"
+    elif fetch == "all":
+        stmts += f"obj = {await_} curs.fetchall()"
+    elif fetch == "iter":
+        stmts += f"{async_}for obj in curs: pass"
+    else:
+        pytest.fail(f"unexpected fetch: {fetch}")
+
+    _test_reveal(stmts, typemod.format(type=type), mypy)
 
 
 @pytest.mark.parametrize("server_side", [False, True])