@pytest.mark.parametrize(
- "curs, type",
+ "factory, type",
[
- (
- "conn.cursor()",
- "Tuple[Any, ...] | None",
- ),
- (
- "conn.cursor(row_factory=rows.dict_row)",
- "Dict[str, Any] | None",
- ),
- (
- "conn.cursor(row_factory=thing_row)",
- "Thing | None",
- ),
+ ("", "Tuple[Any, ...]"),
+ ("row_factory=rows.tuple_row", "Tuple[Any, ...]"),
+ ("row_factory=rows.dict_row", "Dict[str, Any]"),
+ ("row_factory=thing_row", "Thing"),
],
)
@pytest.mark.parametrize("server_side", [False, True])
@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_fetchone_type(conn_class, server_side, curs, type, mypy):
- await_ = "await" if "Async" in conn_class else ""
- if server_side:
- curs = curs.replace("(", "(name='foo',", 1)
- stmts = f"""\
-conn = {await_} psycopg.{conn_class}.connect()
-curs = {curs}
-obj = {await_} curs.fetchone()
-"""
- _test_reveal(stmts, type, mypy)
-
-
@pytest.mark.parametrize(
- "curs, type",
+ "fetch, typemod",
[
- (
- "conn.cursor()",
- "Tuple[Any, ...]",
- ),
- (
- "conn.cursor(row_factory=rows.dict_row)",
- "Dict[str, Any]",
- ),
- (
- "conn.cursor(row_factory=thing_row)",
- "Thing",
- ),
+ ("one", "{type} | None"),
+ ("many", "list[{type}]"),
+ ("all", "list[{type}]"),
+ ("iter", "{type}"),
],
)
-@pytest.mark.parametrize("server_side", [False, True])
-@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_iter_type(conn_class, server_side, curs, type, mypy):
+def test_fetch_type(conn_class, server_side, factory, type, fetch, typemod, mypy):
if "Async" in conn_class:
async_ = "async "
await_ = "await "
else:
async_ = await_ = ""
+ curs = f"conn.cursor({factory})"
if server_side:
curs = curs.replace("(", "(name='foo',", 1)
- stmts = f"""\
-conn = {await_}psycopg.{conn_class}.connect()
-curs = {curs}
-{async_}for obj in curs:
- pass
-"""
- _test_reveal(stmts, type, mypy)
-
-@pytest.mark.parametrize("method", ["fetchmany", "fetchall"])
-@pytest.mark.parametrize(
- "curs, type",
- [
- (
- "conn.cursor()",
- "list[Tuple[Any, ...]]",
- ),
- (
- "conn.cursor(row_factory=rows.dict_row)",
- "list[Dict[str, Any]]",
- ),
- (
- "conn.cursor(row_factory=thing_row)",
- "list[Thing]",
- ),
- ],
-)
-@pytest.mark.parametrize("server_side", [False, True])
-@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_fetchsome_type(conn_class, server_side, curs, type, method, mypy):
- await_ = "await" if "Async" in conn_class else ""
- if server_side:
- curs = curs.replace("(", "(name='foo',", 1)
stmts = f"""\
conn = {await_} psycopg.{conn_class}.connect()
curs = {curs}
-obj = {await_} curs.{method}()
"""
- _test_reveal(stmts, type, mypy)
+ if fetch == "one":
+ stmts += f"obj = {await_} curs.fetchone()"
+ elif fetch == "many":
+ stmts += f"obj = {await_} curs.fetchmany(5)"
+ elif fetch == "all":
+ stmts += f"obj = {await_} curs.fetchall()"
+ elif fetch == "iter":
+ stmts += f"{async_}for obj in curs: pass"
+ else:
+ pytest.fail(f"unexpected fetch: {fetch}")
+
+ _test_reveal(stmts, typemod.format(type=type), mypy)
@pytest.mark.parametrize("server_side", [False, True])