From db51b86bc34e217f244c0d4815e5873ef5c24d9a Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 7 Oct 2021 21:47:02 +0200 Subject: [PATCH] Add tests to verify that execute and enter respect cursor subclasses --- tests/test_typing.py | 97 ++++++++++++++++++++++++++++++++------------ 1 file changed, 71 insertions(+), 26 deletions(-) diff --git a/tests/test_typing.py b/tests/test_typing.py index d7df5d5c1..7eebabe06 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -19,7 +19,7 @@ import pytest ], ) def test_typing_example(mypy, filename): - cp = mypy.run(filename) + cp = mypy.run_on_file(filename) errors = cp.stdout.decode("utf8", "replace").splitlines() assert not errors assert cp.returncode == 0 @@ -71,9 +71,9 @@ def test_typing_example(mypy, filename): ), ], ) -def test_connection_type(conn, type, mypy, tmpdir): +def test_connection_type(conn, type, mypy): stmts = f"obj = {conn}" - _test_reveal(stmts, type, mypy, tmpdir) + _test_reveal(stmts, type, mypy) @pytest.mark.slow @@ -155,12 +155,12 @@ def test_connection_type(conn, type, mypy, tmpdir): ), ], ) -def test_cursor_type(conn, curs, type, mypy, tmpdir): +def test_cursor_type(conn, curs, type, mypy): stmts = f"""\ conn = {conn} obj = {curs} """ - _test_reveal(stmts, type, mypy, tmpdir) + _test_reveal(stmts, type, mypy) @pytest.mark.slow @@ -183,7 +183,7 @@ obj = {curs} ) @pytest.mark.parametrize("server_side", [False, True]) @pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) -def test_fetchone_type(conn_class, server_side, curs, type, mypy, tmpdir): +def test_fetchone_type(conn_class, server_side, curs, type, mypy): await_ = "await" if "Async" in conn_class else "" if server_side: curs = curs.replace("(", "(name='foo',", 1) @@ -192,7 +192,7 @@ conn = {await_} psycopg.{conn_class}.connect() curs = {curs} obj = {await_} curs.fetchone() """ - _test_reveal(stmts, type, mypy, tmpdir) + _test_reveal(stmts, type, mypy) @pytest.mark.slow @@ -215,7 +215,7 @@ obj = {await_} curs.fetchone() ) @pytest.mark.parametrize("server_side", [False, True]) @pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) -def test_iter_type(conn_class, server_side, curs, type, mypy, tmpdir): +def test_iter_type(conn_class, server_side, curs, type, mypy): if "Async" in conn_class: async_ = "async " await_ = "await " @@ -230,7 +230,7 @@ curs = {curs} {async_}for obj in curs: pass """ - _test_reveal(stmts, type, mypy, tmpdir) + _test_reveal(stmts, type, mypy) @pytest.mark.slow @@ -254,9 +254,7 @@ curs = {curs} ) @pytest.mark.parametrize("server_side", [False, True]) @pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) -def test_fetchsome_type( - conn_class, server_side, curs, type, method, mypy, tmpdir -): +def test_fetchsome_type(conn_class, server_side, curs, type, method, mypy): await_ = "await" if "Async" in conn_class else "" if server_side: curs = curs.replace("(", "(name='foo',", 1) @@ -265,15 +263,58 @@ conn = {await_} psycopg.{conn_class}.connect() curs = {curs} obj = {await_} curs.{method}() """ - _test_reveal(stmts, type, mypy, tmpdir) + _test_reveal(stmts, type, mypy) + + +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_cur_subclass_execute(mypy, conn_class, server_side): + async_ = "async " if "Async" in conn_class else "" + await_ = "await" if "Async" in conn_class else "" + cur_base_class = "".join( + [ + "Async" if "Async" in conn_class else "", + "Server" if server_side else "", + "Cursor", + ] + ) + cur_name = "'foo'" if server_side else "" + + src = f"""\ +from typing import Any, cast +import psycopg +from psycopg.rows import Row, TupleRow + +class MyCursor(psycopg.{cur_base_class}[Row]): + pass + +{async_}def test() -> None: + conn = {await_} psycopg.{conn_class}.connect() + + cur: MyCursor[TupleRow] + reveal_type(cur) + + cur = cast(MyCursor[TupleRow], conn.cursor({cur_name})) + {async_}with cur as cur2: + reveal_type(cur2) + cur3 = {await_} cur2.execute("") + reveal_type(cur3) +""" + cp = mypy.run_on_source(src) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 3 + types = [mypy.get_revealed(line) for line in out] + assert types[0] == types[1] + assert types[0] == types[2] @pytest.fixture(scope="session") def mypy(tmp_path_factory): cache_dir = tmp_path_factory.mktemp(basename="mypy_cache") + src_dir = tmp_path_factory.mktemp("source") class MypyRunner: - def run(self, filename): + def run_on_file(self, filename): cmdline = f""" mypy --strict @@ -283,10 +324,23 @@ def mypy(tmp_path_factory): cmdline.append(filename) return sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT) + def run_on_source(self, source): + fn = src_dir / "tmp.py" + with fn.open("w") as f: + f.write(source) + + return self.run_on_file(str(fn)) + + def get_revealed(self, line): + """return the type from an output of reveal_type""" + return re.sub( + r".*Revealed type is (['\"])([^']+)\1.*", r"\2", line + ).replace("*", "") + return MypyRunner() -def _test_reveal(stmts, type, mypy, tmpdir): +def _test_reveal(stmts, type, mypy): ignore = ( "" if type.startswith("Optional") else "# type: ignore[assignment]" ) @@ -320,17 +374,8 @@ async def tmp() -> None: ref: {type} = None {ignore} reveal_type(ref) """ - fn = tmpdir / "tmp.py" - with fn.open("w") as f: - f.write(src) - - cp = mypy.run(str(fn)) + cp = mypy.run_on_source(src) out = cp.stdout.decode("utf8", "replace").splitlines() assert len(out) == 2, "\n".join(out) - got, want = [ - re.sub(r".*Revealed type is (['\"])([^']+)\1.*", r"\2", line).replace( - "*", "" - ) - for line in out - ] + got, want = [mypy.get_revealed(line) for line in out] assert got == want -- 2.47.2