From: Daniele Varrazzo Date: Thu, 29 Apr 2021 16:11:14 +0000 (+0200) Subject: Fix `Connection.connect()` return type X-Git-Tag: 3.0.dev0~63^2~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d142943902a82f158969f5500e3f353adf2fd1f0;p=thirdparty%2Fpsycopg.git Fix `Connection.connect()` return type Now `connect()` returns a `Connection[Tuple]`, whereas `connect(row_factory=something)` return the type of what row factory produces. The implementation of this is somewhat brittle, but that's mypy for you: @dlax (thank you!) noticed that defining `**kwargs: Union[str, int]` helped to disambiguate the row_factory param. I guess we will make a best-effort to maintain this "interface". Everything is to be documented. Strangely, mypy cannot figure out the type of conn = await self.connection_class.connect( self.conninfo, **self.kwargs ) in the async pool, but it can for the sync one (without the `await`). Added explicit type to disambiguate, on both the classes, for symmetry. Added regression tests to verify that refactoring doesn't break type inference. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 7c3135fc0..5d77e69d0 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -22,7 +22,7 @@ from . import waiting from . import encodings from .pq import ConnStatus, ExecStatus, TransactionStatus, Format from .sql import Composable -from .rows import tuple_row +from .rows import tuple_row, TupleRow from .proto import AdaptContext, ConnectionType, Params, PQGen, PQGenConn from .proto import Query, Row, RowConn, RowFactory, RV from .cursor import Cursor, AsyncCursor @@ -446,7 +446,30 @@ class Connection(BaseConnection[RowConn]): super().__init__(pgconn, row_factory) self.lock = threading.Lock() + @overload @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: RowFactory[RowConn], + **kwargs: Union[None, int, str], + ) -> "Connection[RowConn]": + ... + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + **kwargs: Union[None, int, str], + ) -> "Connection[TupleRow]": + ... + + @classmethod # type: ignore[misc] def connect( cls, conninfo: str = "", @@ -454,7 +477,7 @@ class Connection(BaseConnection[RowConn]): autocommit: bool = False, row_factory: Optional[RowFactory[RowConn]] = None, **kwargs: Any, - ) -> "Connection[RowConn]": + ) -> "Connection[Any]": """ Connect to a database server and return a new `Connection` instance. @@ -639,7 +662,30 @@ class AsyncConnection(BaseConnection[RowConn]): super().__init__(pgconn, row_factory) self.lock = asyncio.Lock() + @overload @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: RowFactory[RowConn], + **kwargs: Union[None, int, str], + ) -> "AsyncConnection[RowConn]": + ... + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + **kwargs: Union[None, int, str], + ) -> "AsyncConnection[TupleRow]": + ... + + @classmethod # type: ignore[misc] async def connect( cls, conninfo: str = "", @@ -647,7 +693,7 @@ class AsyncConnection(BaseConnection[RowConn]): autocommit: bool = False, row_factory: Optional[RowFactory[RowConn]] = None, **kwargs: Any, - ) -> "AsyncConnection[RowConn]": + ) -> "AsyncConnection[Any]": return await cls._wait_conn( cls._connect_gen( conninfo, diff --git a/psycopg3/psycopg3/pool/async_pool.py b/psycopg3/psycopg3/pool/async_pool.py index 6c877b46d..3d07bb279 100644 --- a/psycopg3/psycopg3/pool/async_pool.py +++ b/psycopg3/psycopg3/pool/async_pool.py @@ -350,6 +350,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]): self._stats[self._CONNECTIONS_NUM] += 1 t0 = monotonic() try: + conn: AsyncConnection[Any] conn = await self.connection_class.connect( self.conninfo, **self.kwargs ) diff --git a/psycopg3/psycopg3/pool/pool.py b/psycopg3/psycopg3/pool/pool.py index 2f311a6de..269c3e73e 100644 --- a/psycopg3/psycopg3/pool/pool.py +++ b/psycopg3/psycopg3/pool/pool.py @@ -422,6 +422,7 @@ class ConnectionPool(BasePool[Connection[Any]]): self._stats[self._CONNECTIONS_NUM] += 1 t0 = monotonic() try: + conn: Connection[Any] conn = self.connection_class.connect(self.conninfo, **self.kwargs) except Exception: self._stats[self._CONNECTIONS_ERRORS] += 1 diff --git a/tests/test_typing.py b/tests/test_typing.py index 95505b426..70b700684 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -1,4 +1,4 @@ -import os +import re import sys import subprocess as sp @@ -7,15 +7,234 @@ import pytest @pytest.mark.slow @pytest.mark.skipif(sys.version_info < (3, 7), reason="no future annotations") -def test_typing_example(): - cmdline = f""" - mypy - --strict - --show-error-codes --no-color-output --no-error-summary - --config-file= --no-incremental --cache-dir={os.devnull} - tests/typing_example.py - """.split() - cp = sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT) +def test_typing_example(mypy): + cp = mypy.run("tests/typing_example.py") errors = cp.stdout.decode("utf8", "replace").splitlines() assert not errors assert cp.returncode == 0 + + +@pytest.mark.slow +@pytest.mark.parametrize( + "conn, type", + [ + ( + "psycopg3.connect()", + "psycopg3.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg3.connect(row_factory=rows.tuple_row)", + "psycopg3.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg3.connect(row_factory=rows.dict_row)", + "psycopg3.Connection[Dict[str, Any]]", + ), + ( + "psycopg3.connect(row_factory=rows.namedtuple_row)", + "psycopg3.Connection[NamedTuple]", + ), + ( + "psycopg3.connect(row_factory=thing_row)", + "psycopg3.Connection[Thing]", + ), + ( + "psycopg3.Connection.connect()", + "psycopg3.Connection[Tuple[Any, ...]]", + ), + ( + "psycopg3.Connection.connect(row_factory=rows.dict_row)", + "psycopg3.Connection[Dict[str, Any]]", + ), + ( + "await psycopg3.AsyncConnection.connect()", + "psycopg3.AsyncConnection[Tuple[Any, ...]]", + ), + ( + "await psycopg3.AsyncConnection.connect(row_factory=rows.dict_row)", + "psycopg3.AsyncConnection[Dict[str, Any]]", + ), + ], +) +def test_connection_type(conn, type, mypy, tmpdir): + stmts = f"obj = {conn}" + _test_reveal(stmts, type, mypy, tmpdir) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "conn, curs, type", + [ + ( + "psycopg3.connect()", + "conn.cursor()", + "psycopg3.Cursor[Tuple[Any, ...]]", + ), + ( + "psycopg3.connect(row_factory=rows.dict_row)", + "conn.cursor()", + "psycopg3.Cursor[Dict[str, Any]]", + ), + ( + "psycopg3.connect(row_factory=rows.dict_row)", + "conn.cursor(row_factory=rows.namedtuple_row)", + "psycopg3.Cursor[NamedTuple]", + ), + ( + "psycopg3.connect(row_factory=thing_row)", + "conn.cursor()", + "psycopg3.Cursor[Thing]", + ), + ( + "psycopg3.connect()", + "conn.cursor(row_factory=thing_row)", + "psycopg3.Cursor[Thing]", + ), + # Async cursors + ( + "await psycopg3.AsyncConnection.connect()", + "conn.cursor()", + "psycopg3.AsyncCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg3.AsyncConnection.connect()", + "conn.cursor(row_factory=thing_row)", + "psycopg3.AsyncCursor[Thing]", + ), + # Server-side cursors + ( + "psycopg3.connect()", + "conn.cursor(name='foo')", + "psycopg3.ServerCursor[Tuple[Any, ...]]", + ), + ( + "psycopg3.connect(row_factory=rows.dict_row)", + "conn.cursor(name='foo')", + "psycopg3.ServerCursor[Dict[str, Any]]", + ), + ( + "psycopg3.connect()", + "conn.cursor(name='foo', row_factory=rows.dict_row)", + "psycopg3.ServerCursor[Dict[str, Any]]", + ), + # Async server-side cursors + ( + "await psycopg3.AsyncConnection.connect()", + "conn.cursor(name='foo')", + "psycopg3.AsyncServerCursor[Tuple[Any, ...]]", + ), + ( + "await psycopg3.AsyncConnection.connect(row_factory=rows.dict_row)", + "conn.cursor(name='foo')", + "psycopg3.AsyncServerCursor[Dict[str, Any]]", + ), + ( + "psycopg3.connect()", + "conn.cursor(name='foo', row_factory=rows.dict_row)", + "psycopg3.ServerCursor[Dict[str, Any]]", + ), + ], +) +def test_cursor_type(conn, curs, type, mypy, tmpdir): + stmts = f"""\ +conn = {conn} +obj = {curs} +""" + _test_reveal(stmts, type, mypy, tmpdir) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "curs, type", + [ + ( + "conn.cursor()", + "Optional[Tuple[Any, ...]]", + ), + ( + "conn.cursor(row_factory=rows.dict_row)", + "Optional[Dict[str, Any]]", + ), + ( + "conn.cursor(row_factory=thing_row)", + "Optional[Thing]", + ), + ], +) +@pytest.mark.parametrize("server_side", [False, True]) +@pytest.mark.parametrize("conn_class", ["Connection", "AsyncConnection"]) +def test_fetchone_type(conn_class, server_side, curs, type, mypy, tmpdir): + await_ = "await" if "Async" in conn_class else "" + if server_side: + curs = curs.replace("(", "(name='foo',", 1) + stmts = f"""\ +conn = {await_} psycopg3.{conn_class}.connect() +curs = {curs} +obj = {await_} curs.fetchone() +""" + _test_reveal(stmts, type, mypy, tmpdir) + + +@pytest.fixture(scope="session") +def mypy(tmp_path_factory): + cache_dir = tmp_path_factory.mktemp(basename="mypy_cache") + + class MypyRunner: + def run(self, filename): + cmdline = f""" + mypy + --strict + --show-error-codes --no-color-output --no-error-summary + --config-file= --cache-dir={cache_dir} + """.split() + cmdline.append(filename) + return sp.run(cmdline, stdout=sp.PIPE, stderr=sp.STDOUT) + + return MypyRunner() + + +def _test_reveal(stmts, type, mypy, tmpdir): + ignore = ( + "" if type.startswith("Optional") else "# type: ignore[assignment]" + ) + stmts = "\n".join(f" {line}" for line in stmts.splitlines()) + + src = f"""\ +from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple +import psycopg3 +from psycopg3 import rows + +class Thing: + def __init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + +def thing_row( + cur: psycopg3.BaseCursor[Any, Thing], +) -> Callable[[Sequence[Any]], Thing]: + assert cur.description + names = [d.name for d in cur.description] + + def make_row(t: Sequence[Any]) -> Thing: + return Thing(**dict(zip(names, t))) + + return make_row + +async def tmp() -> None: +{stmts} + reveal_type(obj) + +ref: {type} = None {ignore} +reveal_type(ref) +""" + fn = tmpdir / "tmp.py" + with fn.open("w") as f: + f.write(src) + + cp = mypy.run(str(fn)) + out = cp.stdout.decode("utf8", "replace").splitlines() + assert len(out) == 2, "\n".join(out) + got, want = [ + re.sub(r".*Revealed type is '([^']+)'.*", r"\1", line).replace("*", "") + for line in out + ] + assert got == want diff --git a/tests/typing_example.py b/tests/typing_example.py index fa59a2152..f03116a95 100644 --- a/tests/typing_example.py +++ b/tests/typing_example.py @@ -32,7 +32,7 @@ class Person: def check_row_factory_cursor() -> None: """Type-check connection.cursor(..., row_factory=) case.""" - conn = connect() # type: ignore[var-annotated] # Connection[Any] + conn = connect() cur1: Cursor[Any] cur1 = conn.cursor() @@ -81,7 +81,7 @@ def check_row_factory_connection() -> None: cur3: Cursor[Tuple[Any, ...]] r3: Optional[Tuple[Any, ...]] - conn3 = connect() # type: ignore[var-annotated] + conn3 = connect() cur3 = conn3.execute("select 3") with conn3.cursor() as cur3: cur3.execute("select 42")