]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(crdb): fix type info of CrdbConnection.connect()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 May 2022 08:50:57 +0000 (10:50 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:34 +0000 (12:58 +0100)
It requires specifying entirely the connect() signature because of by #308.

psycopg/psycopg/crdb.py
tests/crdb/test_connection.py
tests/crdb/test_connection_async.py [new file with mode: 0644]
tests/crdb/test_typing.py [new file with mode: 0644]

index 5fa3492be60348581da3b1ff638962ab6718ffa3..22c190e01ede077a2c492e0ca469df3c0040b46e 100644 (file)
@@ -6,12 +6,12 @@ Types configuration specific for CockroachDB.
 
 import re
 from enum import Enum
-from typing import Any, Optional, Union, TYPE_CHECKING
+from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
 from ._typeinfo import TypeInfo, TypesRegistry
 
 from . import errors as e
 from .abc import AdaptContext
-from .rows import Row
+from .rows import Row, RowFactory, AsyncRowFactory, TupleRow
 from .postgres import TEXT_OID
 from .conninfo import ConnectionInfo
 from .connection import Connection
@@ -21,6 +21,8 @@ from .types.enum import EnumDumper, EnumBinaryDumper
 
 if TYPE_CHECKING:
     from .pq.abc import PGconn
+    from .cursor import Cursor
+    from .cursor_async import AsyncCursor
 
 types = TypesRegistry()
 
@@ -59,11 +61,81 @@ class _CrdbConnectionMixin:
 
 
 class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
-    pass
+    # TODO: this method shouldn't require re-definition if the base class
+    # implements a generic self.
+    # https://github.com/psycopg/psycopg/issues/308
+    @overload
+    @classmethod
+    def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        row_factory: RowFactory[Row],
+        prepare_threshold: Optional[int] = 5,
+        cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
+        context: Optional[AdaptContext] = None,
+        **kwargs: Union[None, int, str],
+    ) -> "CrdbConnection[Row]":
+        ...
+
+    @overload
+    @classmethod
+    def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        prepare_threshold: Optional[int] = 5,
+        cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
+        context: Optional[AdaptContext] = None,
+        **kwargs: Union[None, int, str],
+    ) -> "CrdbConnection[TupleRow]":
+        ...
+
+    @classmethod
+    def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]":
+        return super().connect(conninfo, **kwargs)  # type: ignore[return-value]
 
 
 class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
-    pass
+    # TODO: this method shouldn't require re-definition if the base class
+    # implements a generic self.
+    # https://github.com/psycopg/psycopg/issues/308
+    @overload
+    @classmethod
+    async def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        prepare_threshold: Optional[int] = 5,
+        row_factory: AsyncRowFactory[Row],
+        cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
+        context: Optional[AdaptContext] = None,
+        **kwargs: Union[None, int, str],
+    ) -> "AsyncCrdbConnection[Row]":
+        ...
+
+    @overload
+    @classmethod
+    async def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        prepare_threshold: Optional[int] = 5,
+        cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
+        context: Optional[AdaptContext] = None,
+        **kwargs: Union[None, int, str],
+    ) -> "AsyncCrdbConnection[TupleRow]":
+        ...
+
+    @classmethod
+    async def connect(
+        cls, conninfo: str = "", **kwargs: Any
+    ) -> "AsyncCrdbConnection[Any]":
+        return await super().connect(conninfo, **kwargs)  # type: ignore [no-any-return]
 
 
 connect = CrdbConnection.connect
index 1f1f64d287a08a84a4b0d4c6268e7b333152d2cf..fb57c6e6a09142bab25cc82dfa687356703d54f4 100644 (file)
@@ -12,5 +12,8 @@ def test_is_crdb(conn):
 
 
 def test_connect(dsn):
+    with psycopg.crdb.CrdbConnection.connect(dsn) as conn:
+        assert isinstance(conn, CrdbConnection)
+
     with psycopg.crdb.connect(dsn) as conn:
         assert isinstance(conn, CrdbConnection)
diff --git a/tests/crdb/test_connection_async.py b/tests/crdb/test_connection_async.py
new file mode 100644 (file)
index 0000000..3d6da1b
--- /dev/null
@@ -0,0 +1,16 @@
+import psycopg.crdb
+from psycopg.crdb import AsyncCrdbConnection
+
+import pytest
+
+pytestmark = [pytest.mark.crdb, pytest.mark.asyncio]
+
+
+async def test_is_crdb(aconn):
+    assert AsyncCrdbConnection.is_crdb(aconn)
+    assert AsyncCrdbConnection.is_crdb(aconn.pgconn)
+
+
+async def test_connect(dsn):
+    async with await psycopg.crdb.AsyncCrdbConnection.connect(dsn) as conn:
+        assert isinstance(conn, psycopg.crdb.AsyncCrdbConnection)
diff --git a/tests/crdb/test_typing.py b/tests/crdb/test_typing.py
new file mode 100644 (file)
index 0000000..2cff0a7
--- /dev/null
@@ -0,0 +1,49 @@
+import pytest
+
+from ..test_typing import _test_reveal
+
+
+@pytest.mark.parametrize(
+    "conn, type",
+    [
+        (
+            "psycopg.crdb.connect()",
+            "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+        ),
+        (
+            "psycopg.crdb.connect(row_factory=rows.dict_row)",
+            "psycopg.crdb.CrdbConnection[Dict[str, Any]]",
+        ),
+        (
+            "psycopg.crdb.CrdbConnection.connect()",
+            "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+        ),
+        (
+            "psycopg.crdb.CrdbConnection.connect(row_factory=rows.tuple_row)",
+            "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+        ),
+        (
+            "psycopg.crdb.CrdbConnection.connect(row_factory=rows.dict_row)",
+            "psycopg.crdb.CrdbConnection[Dict[str, Any]]",
+        ),
+        (
+            "await psycopg.crdb.AsyncCrdbConnection.connect()",
+            "psycopg.crdb.AsyncCrdbConnection[Tuple[Any, ...]]",
+        ),
+        (
+            "await psycopg.crdb.AsyncCrdbConnection.connect(row_factory=rows.dict_row)",
+            "psycopg.crdb.AsyncCrdbConnection[Dict[str, Any]]",
+        ),
+    ],
+)
+def test_connection_type(conn, type, mypy):
+    stmts = f"obj = {conn}"
+    _test_reveal_crdb(stmts, type, mypy)
+
+
+def _test_reveal_crdb(stmts, type, mypy):
+    stmts = f"""\
+import psycopg.crdb
+{stmts}
+"""
+    _test_reveal(stmts, type, mypy)