From: Daniele Varrazzo Date: Tue, 24 May 2022 08:50:57 +0000 (+0200) Subject: fix(crdb): fix type info of CrdbConnection.connect() X-Git-Tag: 3.1~49^2~51 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3a32ed72678c5819e08d1b2e28a6d4dac436ee7d;p=thirdparty%2Fpsycopg.git fix(crdb): fix type info of CrdbConnection.connect() It requires specifying entirely the connect() signature because of by #308. --- diff --git a/psycopg/psycopg/crdb.py b/psycopg/psycopg/crdb.py index 5fa3492be..22c190e01 100644 --- a/psycopg/psycopg/crdb.py +++ b/psycopg/psycopg/crdb.py @@ -6,12 +6,12 @@ Types configuration specific for CockroachDB. import re from enum import Enum -from typing import Any, Optional, Union, TYPE_CHECKING +from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING from ._typeinfo import TypeInfo, TypesRegistry from . import errors as e from .abc import AdaptContext -from .rows import Row +from .rows import Row, RowFactory, AsyncRowFactory, TupleRow from .postgres import TEXT_OID from .conninfo import ConnectionInfo from .connection import Connection @@ -21,6 +21,8 @@ from .types.enum import EnumDumper, EnumBinaryDumper if TYPE_CHECKING: from .pq.abc import PGconn + from .cursor import Cursor + from .cursor_async import AsyncCursor types = TypesRegistry() @@ -59,11 +61,81 @@ class _CrdbConnectionMixin: class CrdbConnection(_CrdbConnectionMixin, Connection[Row]): - pass + # TODO: this method shouldn't require re-definition if the base class + # implements a generic self. + # https://github.com/psycopg/psycopg/issues/308 + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + row_factory: RowFactory[Row], + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[Cursor[Row]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "CrdbConnection[Row]": + ... + + @overload + @classmethod + def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[Cursor[Any]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "CrdbConnection[TupleRow]": + ... + + @classmethod + def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]": + return super().connect(conninfo, **kwargs) # type: ignore[return-value] class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]): - pass + # TODO: this method shouldn't require re-definition if the base class + # implements a generic self. + # https://github.com/psycopg/psycopg/issues/308 + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + row_factory: AsyncRowFactory[Row], + cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncCrdbConnection[Row]": + ... + + @overload + @classmethod + async def connect( + cls, + conninfo: str = "", + *, + autocommit: bool = False, + prepare_threshold: Optional[int] = 5, + cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None, + context: Optional[AdaptContext] = None, + **kwargs: Union[None, int, str], + ) -> "AsyncCrdbConnection[TupleRow]": + ... + + @classmethod + async def connect( + cls, conninfo: str = "", **kwargs: Any + ) -> "AsyncCrdbConnection[Any]": + return await super().connect(conninfo, **kwargs) # type: ignore [no-any-return] connect = CrdbConnection.connect diff --git a/tests/crdb/test_connection.py b/tests/crdb/test_connection.py index 1f1f64d28..fb57c6e6a 100644 --- a/tests/crdb/test_connection.py +++ b/tests/crdb/test_connection.py @@ -12,5 +12,8 @@ def test_is_crdb(conn): def test_connect(dsn): + with psycopg.crdb.CrdbConnection.connect(dsn) as conn: + assert isinstance(conn, CrdbConnection) + with psycopg.crdb.connect(dsn) as conn: assert isinstance(conn, CrdbConnection) diff --git a/tests/crdb/test_connection_async.py b/tests/crdb/test_connection_async.py new file mode 100644 index 000000000..3d6da1b40 --- /dev/null +++ b/tests/crdb/test_connection_async.py @@ -0,0 +1,16 @@ +import psycopg.crdb +from psycopg.crdb import AsyncCrdbConnection + +import pytest + +pytestmark = [pytest.mark.crdb, pytest.mark.asyncio] + + +async def test_is_crdb(aconn): + assert AsyncCrdbConnection.is_crdb(aconn) + assert AsyncCrdbConnection.is_crdb(aconn.pgconn) + + +async def test_connect(dsn): + async with await psycopg.crdb.AsyncCrdbConnection.connect(dsn) as conn: + assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection) diff --git a/tests/crdb/test_typing.py b/tests/crdb/test_typing.py new file mode 100644 index 000000000..2cff0a735 --- /dev/null +++ b/tests/crdb/test_typing.py @@ -0,0 +1,49 @@ +import pytest + +from ..test_typing import _test_reveal + + +@pytest.mark.parametrize( + "conn, type", + [ + ( + "psycopg.crdb.connect()", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.connect(row_factory=rows.dict_row)", + "psycopg.crdb.CrdbConnection[Dict[str, Any]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect()", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect(row_factory=rows.tuple_row)", + "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]", + ), + ( + "psycopg.crdb.CrdbConnection.connect(row_factory=rows.dict_row)", + "psycopg.crdb.CrdbConnection[Dict[str, Any]]", + ), + ( + "await psycopg.crdb.AsyncCrdbConnection.connect()", + "psycopg.crdb.AsyncCrdbConnection[Tuple[Any, ...]]", + ), + ( + "await psycopg.crdb.AsyncCrdbConnection.connect(row_factory=rows.dict_row)", + "psycopg.crdb.AsyncCrdbConnection[Dict[str, Any]]", + ), + ], +) +def test_connection_type(conn, type, mypy): + stmts = f"obj = {conn}" + _test_reveal_crdb(stmts, type, mypy) + + +def _test_reveal_crdb(stmts, type, mypy): + stmts = f"""\ +import psycopg.crdb +{stmts} +""" + _test_reveal(stmts, type, mypy)