]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix `Connection.connect()` return type
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 29 Apr 2021 16:11:14 +0000 (18:11 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 29 Apr 2021 17:13:48 +0000 (19:13 +0200)
Now `connect()` returns a `Connection[Tuple]`, whereas
`connect(row_factory=something)` return the type of what row factory
produces. The implementation of this is somewhat brittle, but that's
mypy for you: @dlax (thank you!) noticed that defining `**kwargs:
Union[str, int]` helped to disambiguate the row_factory param. I guess
we will make a best-effort to maintain this "interface". Everything is
to be documented.

Strangely, mypy cannot figure out the type of

    conn = await self.connection_class.connect(
        self.conninfo, **self.kwargs
    )

in the async pool, but it can for the sync one (without the `await`).
Added explicit type to disambiguate, on both the classes, for symmetry.

Added regression tests to verify that refactoring doesn't break type
inference.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/pool/async_pool.py
psycopg3/psycopg3/pool/pool.py
tests/test_typing.py
tests/typing_example.py

index 7c3135fc065da9c198ddb3833a676929fb82b870..5d77e69d090581c025b2ccfe98bdca41cd1a40fd 100644 (file)
@@ -22,7 +22,7 @@ from . import waiting
 from . import encodings
 from .pq import ConnStatus, ExecStatus, TransactionStatus, Format
 from .sql import Composable
-from .rows import tuple_row
+from .rows import tuple_row, TupleRow
 from .proto import AdaptContext, ConnectionType, Params, PQGen, PQGenConn
 from .proto import Query, Row, RowConn, RowFactory, RV
 from .cursor import Cursor, AsyncCursor
@@ -446,7 +446,30 @@ class Connection(BaseConnection[RowConn]):
         super().__init__(pgconn, row_factory)
         self.lock = threading.Lock()
 
+    @overload
     @classmethod
+    def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        row_factory: RowFactory[RowConn],
+        **kwargs: Union[None, int, str],
+    ) -> "Connection[RowConn]":
+        ...
+
+    @overload
+    @classmethod
+    def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        **kwargs: Union[None, int, str],
+    ) -> "Connection[TupleRow]":
+        ...
+
+    @classmethod  # type: ignore[misc]
     def connect(
         cls,
         conninfo: str = "",
@@ -454,7 +477,7 @@ class Connection(BaseConnection[RowConn]):
         autocommit: bool = False,
         row_factory: Optional[RowFactory[RowConn]] = None,
         **kwargs: Any,
-    ) -> "Connection[RowConn]":
+    ) -> "Connection[Any]":
         """
         Connect to a database server and return a new `Connection` instance.
 
@@ -639,7 +662,30 @@ class AsyncConnection(BaseConnection[RowConn]):
         super().__init__(pgconn, row_factory)
         self.lock = asyncio.Lock()
 
+    @overload
     @classmethod
+    async def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        row_factory: RowFactory[RowConn],
+        **kwargs: Union[None, int, str],
+    ) -> "AsyncConnection[RowConn]":
+        ...
+
+    @overload
+    @classmethod
+    async def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        **kwargs: Union[None, int, str],
+    ) -> "AsyncConnection[TupleRow]":
+        ...
+
+    @classmethod  # type: ignore[misc]
     async def connect(
         cls,
         conninfo: str = "",
@@ -647,7 +693,7 @@ class AsyncConnection(BaseConnection[RowConn]):
         autocommit: bool = False,
         row_factory: Optional[RowFactory[RowConn]] = None,
         **kwargs: Any,
-    ) -> "AsyncConnection[RowConn]":
+    ) -> "AsyncConnection[Any]":
         return await cls._wait_conn(
             cls._connect_gen(
                 conninfo,
index 6c877b46d381ff9574f808b19d7ffa4a79d3d3df..3d07bb27989267aae85e5f9a0efffde2b525dbbd 100644 (file)
@@ -350,6 +350,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         self._stats[self._CONNECTIONS_NUM] += 1
         t0 = monotonic()
         try:
+            conn: AsyncConnection[Any]
             conn = await self.connection_class.connect(
                 self.conninfo, **self.kwargs
             )
index 2f311a6de865a2f450feddfd3d80e387713c53c4..269c3e73e9f65e129063a0f9afce40b9f56f45fa 100644 (file)
@@ -422,6 +422,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         self._stats[self._CONNECTIONS_NUM] += 1
         t0 = monotonic()
         try:
+            conn: Connection[Any]
             conn = self.connection_class.connect(self.conninfo, **self.kwargs)
         except Exception:
             self._stats[self._CONNECTIONS_ERRORS] += 1
index 95505b4267db36592b67f79023200cd7dfd3a871..70b7006845cbde774fdca8cdc3502519375c1e56 100644 (file)
@@ -1,4 +1,4 @@
-import os
+import re
 import sys
 import subprocess as sp
 
@@ -7,15 +7,234 @@ import pytest
 
 @pytest.mark.slow
 @pytest.mark.skipif(sys.version_info < (3, 7), reason="no future annotations")
-def test_typing_example():
-    cmdline = f"""
-        mypy
-        --strict
-        --show-error-codes --no-color-output --no-error-summary
-        --config-file= --no-incremental --cache-dir={os.devnull}
-        tests/typing_example.py
-        """.split()
-    cp = sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT)
+def test_typing_example(mypy):
+    cp = mypy.run("tests/typing_example.py")
     errors = cp.stdout.decode("utf8", "replace").splitlines()
     assert not errors
     assert cp.returncode == 0
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "conn, type",
+    [
+        (
+            "psycopg3.connect()",
+            "psycopg3.Connection[Tuple[Any, ...]]",
+        ),
+        (
+            "psycopg3.connect(row_factory=rows.tuple_row)",
+            "psycopg3.Connection[Tuple[Any, ...]]",
+        ),
+        (
+            "psycopg3.connect(row_factory=rows.dict_row)",
+            "psycopg3.Connection[Dict[str, Any]]",
+        ),
+        (
+            "psycopg3.connect(row_factory=rows.namedtuple_row)",
+            "psycopg3.Connection[NamedTuple]",
+        ),
+        (
+            "psycopg3.connect(row_factory=thing_row)",
+            "psycopg3.Connection[Thing]",
+        ),
+        (
+            "psycopg3.Connection.connect()",
+            "psycopg3.Connection[Tuple[Any, ...]]",
+        ),
+        (
+            "psycopg3.Connection.connect(row_factory=rows.dict_row)",
+            "psycopg3.Connection[Dict[str, Any]]",
+        ),
+        (
+            "await psycopg3.AsyncConnection.connect()",
+            "psycopg3.AsyncConnection[Tuple[Any, ...]]",
+        ),
+        (
+            "await psycopg3.AsyncConnection.connect(row_factory=rows.dict_row)",
+            "psycopg3.AsyncConnection[Dict[str, Any]]",
+        ),
+    ],
+)
+def test_connection_type(conn, type, mypy, tmpdir):
+    stmts = f"obj = {conn}"
+    _test_reveal(stmts, type, mypy, tmpdir)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "conn, curs, type",
+    [
+        (
+            "psycopg3.connect()",
+            "conn.cursor()",
+            "psycopg3.Cursor[Tuple[Any, ...]]",
+        ),
+        (
+            "psycopg3.connect(row_factory=rows.dict_row)",
+            "conn.cursor()",
+            "psycopg3.Cursor[Dict[str, Any]]",
+        ),
+        (
+            "psycopg3.connect(row_factory=rows.dict_row)",
+            "conn.cursor(row_factory=rows.namedtuple_row)",
+            "psycopg3.Cursor[NamedTuple]",
+        ),
+        (
+            "psycopg3.connect(row_factory=thing_row)",
+            "conn.cursor()",
+            "psycopg3.Cursor[Thing]",
+        ),
+        (
+            "psycopg3.connect()",
+            "conn.cursor(row_factory=thing_row)",
+            "psycopg3.Cursor[Thing]",
+        ),
+        # Async cursors
+        (
+            "await psycopg3.AsyncConnection.connect()",
+            "conn.cursor()",
+            "psycopg3.AsyncCursor[Tuple[Any, ...]]",
+        ),
+        (
+            "await psycopg3.AsyncConnection.connect()",
+            "conn.cursor(row_factory=thing_row)",
+            "psycopg3.AsyncCursor[Thing]",
+        ),
+        # Server-side cursors
+        (
+            "psycopg3.connect()",
+            "conn.cursor(name='foo')",
+            "psycopg3.ServerCursor[Tuple[Any, ...]]",
+        ),
+        (
+            "psycopg3.connect(row_factory=rows.dict_row)",
+            "conn.cursor(name='foo')",
+            "psycopg3.ServerCursor[Dict[str, Any]]",
+        ),
+        (
+            "psycopg3.connect()",
+            "conn.cursor(name='foo', row_factory=rows.dict_row)",
+            "psycopg3.ServerCursor[Dict[str, Any]]",
+        ),
+        # Async server-side cursors
+        (
+            "await psycopg3.AsyncConnection.connect()",
+            "conn.cursor(name='foo')",
+            "psycopg3.AsyncServerCursor[Tuple[Any, ...]]",
+        ),
+        (
+            "await psycopg3.AsyncConnection.connect(row_factory=rows.dict_row)",
+            "conn.cursor(name='foo')",
+            "psycopg3.AsyncServerCursor[Dict[str, Any]]",
+        ),
+        (
+            "psycopg3.connect()",
+            "conn.cursor(name='foo', row_factory=rows.dict_row)",
+            "psycopg3.ServerCursor[Dict[str, Any]]",
+        ),
+    ],
+)
+def test_cursor_type(conn, curs, type, mypy, tmpdir):
+    stmts = f"""\
+conn = {conn}
+obj = {curs}
+"""
+    _test_reveal(stmts, type, mypy, tmpdir)
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize(
+    "curs, type",
+    [
+        (
+            "conn.cursor()",
+            "Optional[Tuple[Any, ...]]",
+        ),
+        (
+            "conn.cursor(row_factory=rows.dict_row)",
+            "Optional[Dict[str, Any]]",
+        ),
+        (
+            "conn.cursor(row_factory=thing_row)",
+            "Optional[Thing]",
+        ),
+    ],
+)
+@pytest.mark.parametrize("server_side", [False, True])
+@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"])
+def test_fetchone_type(conn_class, server_side, curs, type, mypy, tmpdir):
+    await_ = "await" if "Async" in conn_class else ""
+    if server_side:
+        curs = curs.replace("(", "(name='foo',", 1)
+    stmts = f"""\
+conn = {await_} psycopg3.{conn_class}.connect()
+curs = {curs}
+obj = {await_} curs.fetchone()
+"""
+    _test_reveal(stmts, type, mypy, tmpdir)
+
+
+@pytest.fixture(scope="session")
+def mypy(tmp_path_factory):
+    cache_dir = tmp_path_factory.mktemp(basename="mypy_cache")
+
+    class MypyRunner:
+        def run(self, filename):
+            cmdline = f"""
+                mypy
+                --strict
+                --show-error-codes --no-color-output --no-error-summary
+                --config-file= --cache-dir={cache_dir}
+                """.split()
+            cmdline.append(filename)
+            return sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT)
+
+    return MypyRunner()
+
+
+def _test_reveal(stmts, type, mypy, tmpdir):
+    ignore = (
+        "" if type.startswith("Optional") else "# type: ignore[assignment]"
+    )
+    stmts = "\n".join(f"    {line}" for line in stmts.splitlines())
+
+    src = f"""\
+from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple
+import psycopg3
+from psycopg3 import rows
+
+class Thing:
+    def __init__(self, **kwargs: Any) -> None:
+        self.kwargs = kwargs
+
+def thing_row(
+    cur: psycopg3.BaseCursor[Any, Thing],
+) -> Callable[[Sequence[Any]], Thing]:
+    assert cur.description
+    names = [d.name for d in cur.description]
+
+    def make_row(t: Sequence[Any]) -> Thing:
+        return Thing(**dict(zip(names, t)))
+
+    return make_row
+
+async def tmp() -> None:
+{stmts}
+    reveal_type(obj)
+
+ref: {type} = None  {ignore}
+reveal_type(ref)
+"""
+    fn = tmpdir / "tmp.py"
+    with fn.open("w") as f:
+        f.write(src)
+
+    cp = mypy.run(str(fn))
+    out = cp.stdout.decode("utf8", "replace").splitlines()
+    assert len(out) == 2, "\n".join(out)
+    got, want = [
+        re.sub(r".*Revealed type is '([^']+)'.*", r"\1", line).replace("*", "")
+        for line in out
+    ]
+    assert got == want
index fa59a2152577ffaabf0030937821ed2f5d3e6c3e..f03116a95cb5555837e30797c26f3b66ca58c129 100644 (file)
@@ -32,7 +32,7 @@ class Person:
 
 def check_row_factory_cursor() -> None:
     """Type-check connection.cursor(..., row_factory=<MyRowFactory>) case."""
-    conn = connect()  # type: ignore[var-annotated] # Connection[Any]
+    conn = connect()
 
     cur1: Cursor[Any]
     cur1 = conn.cursor()
@@ -81,7 +81,7 @@ def check_row_factory_connection() -> None:
 
     cur3: Cursor[Tuple[Any, ...]]
     r3: Optional[Tuple[Any, ...]]
-    conn3 = connect()  # type: ignore[var-annotated]
+    conn3 = connect()
     cur3 = conn3.execute("select 3")
     with conn3.cursor() as cur3:
         cur3.execute("select 42")