],
)
def test_typing_example(mypy, filename):
- cp = mypy.run(filename)
+ cp = mypy.run_on_file(filename)
errors = cp.stdout.decode("utf8", "replace").splitlines()
assert not errors
assert cp.returncode == 0
),
],
)
-def test_connection_type(conn, type, mypy, tmpdir):
+def test_connection_type(conn, type, mypy):
stmts = f"obj = {conn}"
- _test_reveal(stmts, type, mypy, tmpdir)
+ _test_reveal(stmts, type, mypy)
@pytest.mark.slow
),
],
)
-def test_cursor_type(conn, curs, type, mypy, tmpdir):
+def test_cursor_type(conn, curs, type, mypy):
stmts = f"""\
conn = {conn}
obj = {curs}
"""
- _test_reveal(stmts, type, mypy, tmpdir)
+ _test_reveal(stmts, type, mypy)
@pytest.mark.slow
)
@pytest.mark.parametrize("server_side", [False, True])
@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_fetchone_type(conn_class, server_side, curs, type, mypy, tmpdir):
+def test_fetchone_type(conn_class, server_side, curs, type, mypy):
await_ = "await" if "Async" in conn_class else ""
if server_side:
curs = curs.replace("(", "(name='foo',", 1)
curs = {curs}
obj = {await_} curs.fetchone()
"""
- _test_reveal(stmts, type, mypy, tmpdir)
+ _test_reveal(stmts, type, mypy)
@pytest.mark.slow
)
@pytest.mark.parametrize("server_side", [False, True])
@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_iter_type(conn_class, server_side, curs, type, mypy, tmpdir):
+def test_iter_type(conn_class, server_side, curs, type, mypy):
if "Async" in conn_class:
async_ = "async "
await_ = "await "
{async_}for obj in curs:
pass
"""
- _test_reveal(stmts, type, mypy, tmpdir)
+ _test_reveal(stmts, type, mypy)
@pytest.mark.slow
)
@pytest.mark.parametrize("server_side", [False, True])
@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_fetchsome_type(
- conn_class, server_side, curs, type, method, mypy, tmpdir
-):
+def test_fetchsome_type(conn_class, server_side, curs, type, method, mypy):
await_ = "await" if "Async" in conn_class else ""
if server_side:
curs = curs.replace("(", "(name='foo',", 1)
curs = {curs}
obj = {await_} curs.{method}()
"""
- _test_reveal(stmts, type, mypy, tmpdir)
+ _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_cur_subclass_execute(mypy, conn_class, server_side):
+ async_ = "async " if "Async" in conn_class else ""
+ await_ = "await" if "Async" in conn_class else ""
+ cur_base_class = "".join(
+ [
+ "Async" if "Async" in conn_class else "",
+ "Server" if server_side else "",
+ "Cursor",
+ ]
+ )
+ cur_name = "'foo'" if server_side else ""
+
+ src = f"""\
+from typing import Any, cast
+import psycopg
+from psycopg.rows import Row, TupleRow
+
+class MyCursor(psycopg.{cur_base_class}[Row]):
+ pass
+
+{async_}def test() -> None:
+ conn = {await_} psycopg.{conn_class}.connect()
+
+ cur: MyCursor[TupleRow]
+ reveal_type(cur)
+
+ cur = cast(MyCursor[TupleRow], conn.cursor({cur_name}))
+ {async_}with cur as cur2:
+ reveal_type(cur2)
+ cur3 = {await_} cur2.execute("")
+ reveal_type(cur3)
+"""
+ cp = mypy.run_on_source(src)
+ out = cp.stdout.decode("utf8", "replace").splitlines()
+ assert len(out) == 3
+ types = [mypy.get_revealed(line) for line in out]
+ assert types[0] == types[1]
+ assert types[0] == types[2]
@pytest.fixture(scope="session")
def mypy(tmp_path_factory):
cache_dir = tmp_path_factory.mktemp(basename="mypy_cache")
+ src_dir = tmp_path_factory.mktemp("source")
class MypyRunner:
- def run(self, filename):
+ def run_on_file(self, filename):
cmdline = f"""
mypy
--strict
cmdline.append(filename)
return sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT)
+ def run_on_source(self, source):
+ fn = src_dir / "tmp.py"
+ with fn.open("w") as f:
+ f.write(source)
+
+ return self.run_on_file(str(fn))
+
+ def get_revealed(self, line):
+ """return the type from an output of reveal_type"""
+ return re.sub(
+ r".*Revealed type is (['\"])([^']+)\1.*", r"\2", line
+ ).replace("*", "")
+
return MypyRunner()
-def _test_reveal(stmts, type, mypy, tmpdir):
+def _test_reveal(stmts, type, mypy):
ignore = (
"" if type.startswith("Optional") else "# type: ignore[assignment]"
)
ref: {type} = None {ignore}
reveal_type(ref)
"""
- fn = tmpdir / "tmp.py"
- with fn.open("w") as f:
- f.write(src)
-
- cp = mypy.run(str(fn))
+ cp = mypy.run_on_source(src)
out = cp.stdout.decode("utf8", "replace").splitlines()
assert len(out) == 2, "\n".join(out)
- got, want = [
- re.sub(r".*Revealed type is (['\"])([^']+)\1.*", r"\2", line).replace(
- "*", ""
- )
- for line in out
- ]
+ got, want = [mypy.get_revealed(line) for line in out]
assert got == want