]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add tests to verify that execute and enter respect cursor subclasses
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 7 Oct 2021 19:47:02 +0000 (21:47 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 7 Oct 2021 20:45:33 +0000 (22:45 +0200)
tests/test_typing.py

index d7df5d5c125d1d5d58e9d0f5dc5abc15903ba3b8..7eebabe0654bc53a3f61178dc6a0317444ca44fa 100644 (file)
@@ -19,7 +19,7 @@ import pytest
     ],
 )
 def test_typing_example(mypy, filename):
-    cp = mypy.run(filename)
+    cp = mypy.run_on_file(filename)
     errors = cp.stdout.decode("utf8", "replace").splitlines()
     assert not errors
     assert cp.returncode == 0
@@ -71,9 +71,9 @@ def test_typing_example(mypy, filename):
         ),
     ],
 )
-def test_connection_type(conn, type, mypy, tmpdir):
+def test_connection_type(conn, type, mypy):
     stmts = f"obj = {conn}"
-    _test_reveal(stmts, type, mypy, tmpdir)
+    _test_reveal(stmts, type, mypy)
 
 
 @pytest.mark.slow
@@ -155,12 +155,12 @@ def test_connection_type(conn, type, mypy, tmpdir):
         ),
     ],
 )
-def test_cursor_type(conn, curs, type, mypy, tmpdir):
+def test_cursor_type(conn, curs, type, mypy):
     stmts = f"""\
 conn = {conn}
 obj = {curs}
 """
-    _test_reveal(stmts, type, mypy, tmpdir)
+    _test_reveal(stmts, type, mypy)
 
 
 @pytest.mark.slow
@@ -183,7 +183,7 @@ obj = {curs}
 )
 @pytest.mark.parametrize("server_side", [False, True])
 @pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_fetchone_type(conn_class, server_side, curs, type, mypy, tmpdir):
+def test_fetchone_type(conn_class, server_side, curs, type, mypy):
     await_ = "await" if "Async" in conn_class else ""
     if server_side:
         curs = curs.replace("(", "(name='foo',", 1)
@@ -192,7 +192,7 @@ conn = {await_} psycopg.{conn_class}.connect()
 curs = {curs}
 obj = {await_} curs.fetchone()
 """
-    _test_reveal(stmts, type, mypy, tmpdir)
+    _test_reveal(stmts, type, mypy)
 
 
 @pytest.mark.slow
@@ -215,7 +215,7 @@ obj = {await_} curs.fetchone()
 )
 @pytest.mark.parametrize("server_side", [False, True])
 @pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_iter_type(conn_class, server_side, curs, type, mypy, tmpdir):
+def test_iter_type(conn_class, server_side, curs, type, mypy):
     if "Async" in conn_class:
         async_ = "async "
         await_ = "await "
@@ -230,7 +230,7 @@ curs = {curs}
 {async_}for obj in curs:
     pass
 """
-    _test_reveal(stmts, type, mypy, tmpdir)
+    _test_reveal(stmts, type, mypy)
 
 
 @pytest.mark.slow
@@ -254,9 +254,7 @@ curs = {curs}
 )
 @pytest.mark.parametrize("server_side", [False, True])
 @pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
-def test_fetchsome_type(
-    conn_class, server_side, curs, type, method, mypy, tmpdir
-):
+def test_fetchsome_type(conn_class, server_side, curs, type, method, mypy):
     await_ = "await" if "Async" in conn_class else ""
     if server_side:
         curs = curs.replace("(", "(name='foo',", 1)
@@ -265,15 +263,58 @@ conn = {await_} psycopg.{conn_class}.connect()
 curs = {curs}
 obj = {await_} curs.{method}()
 """
-    _test_reveal(stmts, type, mypy, tmpdir)
+    _test_reveal(stmts, type, mypy)
+
+
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_cur_subclass_execute(mypy, conn_class, server_side):
+    async_ = "async " if "Async" in conn_class else ""
+    await_ = "await" if "Async" in conn_class else ""
+    cur_base_class = "".join(
+        [
+            "Async" if "Async" in conn_class else "",
+            "Server" if server_side else "",
+            "Cursor",
+        ]
+    )
+    cur_name = "'foo'" if server_side else ""
+
+    src = f"""\
+from typing import Any, cast
+import psycopg
+from psycopg.rows import Row, TupleRow
+
+class MyCursor(psycopg.{cur_base_class}[Row]):
+    pass
+
+{async_}def test() -> None:
+    conn = {await_} psycopg.{conn_class}.connect()
+
+    cur: MyCursor[TupleRow]
+    reveal_type(cur)
+
+    cur = cast(MyCursor[TupleRow], conn.cursor({cur_name}))
+    {async_}with cur as cur2:
+        reveal_type(cur2)
+        cur3 = {await_} cur2.execute("")
+        reveal_type(cur3)
+"""
+    cp = mypy.run_on_source(src)
+    out = cp.stdout.decode("utf8", "replace").splitlines()
+    assert len(out) == 3
+    types = [mypy.get_revealed(line) for line in out]
+    assert types[0] == types[1]
+    assert types[0] == types[2]
 
 
 @pytest.fixture(scope="session")
 def mypy(tmp_path_factory):
     cache_dir = tmp_path_factory.mktemp(basename="mypy_cache")
+    src_dir = tmp_path_factory.mktemp("source")
 
     class MypyRunner:
-        def run(self, filename):
+        def run_on_file(self, filename):
             cmdline = f"""
                 mypy
                 --strict
@@ -283,10 +324,23 @@ def mypy(tmp_path_factory):
             cmdline.append(filename)
             return sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT)
 
+        def run_on_source(self, source):
+            fn = src_dir / "tmp.py"
+            with fn.open("w") as f:
+                f.write(source)
+
+            return self.run_on_file(str(fn))
+
+        def get_revealed(self, line):
+            """return the type from an output of reveal_type"""
+            return re.sub(
+                r".*Revealed type is (['\"])([^']+)\1.*", r"\2", line
+            ).replace("*", "")
+
     return MypyRunner()
 
 
-def _test_reveal(stmts, type, mypy, tmpdir):
+def _test_reveal(stmts, type, mypy):
     ignore = (
         "" if type.startswith("Optional") else "# type: ignore[assignment]"
     )
@@ -320,17 +374,8 @@ async def tmp() -> None:
 ref: {type} = None  {ignore}
 reveal_type(ref)
 """
-    fn = tmpdir / "tmp.py"
-    with fn.open("w") as f:
-        f.write(src)
-
-    cp = mypy.run(str(fn))
+    cp = mypy.run_on_source(src)
     out = cp.stdout.decode("utf8", "replace").splitlines()
     assert len(out) == 2, "\n".join(out)
-    got, want = [
-        re.sub(r".*Revealed type is (['\"])([^']+)\1.*", r"\2", line).replace(
-            "*", ""
-        )
-        for line in out
-    ]
+    got, want = [mypy.get_revealed(line) for line in out]
     assert got == want