import re
from enum import Enum
-from typing import Any, Optional, Union, TYPE_CHECKING
+from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
from ._typeinfo import TypeInfo, TypesRegistry
from . import errors as e
from .abc import AdaptContext
-from .rows import Row
+from .rows import Row, RowFactory, AsyncRowFactory, TupleRow
from .postgres import TEXT_OID
from .conninfo import ConnectionInfo
from .connection import Connection
if TYPE_CHECKING:
from .pq.abc import PGconn
+ from .cursor import Cursor
+ from .cursor_async import AsyncCursor
types = TypesRegistry()
class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
- pass
+ # TODO: this method shouldn't require re-definition if the base class
+ # implements a generic self.
+ # https://github.com/psycopg/psycopg/issues/308
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ row_factory: RowFactory[Row],
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "CrdbConnection[Row]":
+ ...
+
+ @overload
+ @classmethod
+ def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "CrdbConnection[TupleRow]":
+ ...
+
+ @classmethod
+ def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]":
+ return super().connect(conninfo, **kwargs) # type: ignore[return-value]
class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
- pass
+ # TODO: this method shouldn't require re-definition if the base class
+ # implements a generic self.
+ # https://github.com/psycopg/psycopg/issues/308
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ row_factory: AsyncRowFactory[Row],
+ cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncCrdbConnection[Row]":
+ ...
+
+ @overload
+ @classmethod
+ async def connect(
+ cls,
+ conninfo: str = "",
+ *,
+ autocommit: bool = False,
+ prepare_threshold: Optional[int] = 5,
+ cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
+ context: Optional[AdaptContext] = None,
+ **kwargs: Union[None, int, str],
+ ) -> "AsyncCrdbConnection[TupleRow]":
+ ...
+
+ @classmethod
+ async def connect(
+ cls, conninfo: str = "", **kwargs: Any
+ ) -> "AsyncCrdbConnection[Any]":
+ return await super().connect(conninfo, **kwargs) # type: ignore [no-any-return]
connect = CrdbConnection.connect
--- /dev/null
+import pytest
+
+from ..test_typing import _test_reveal
+
+
+@pytest.mark.parametrize(
+ "conn, type",
+ [
+ (
+ "psycopg.crdb.connect()",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.CrdbConnection[Dict[str, Any]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect()",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect(row_factory=rows.tuple_row)",
+ "psycopg.crdb.CrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "psycopg.crdb.CrdbConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.CrdbConnection[Dict[str, Any]]",
+ ),
+ (
+ "await psycopg.crdb.AsyncCrdbConnection.connect()",
+ "psycopg.crdb.AsyncCrdbConnection[Tuple[Any, ...]]",
+ ),
+ (
+ "await psycopg.crdb.AsyncCrdbConnection.connect(row_factory=rows.dict_row)",
+ "psycopg.crdb.AsyncCrdbConnection[Dict[str, Any]]",
+ ),
+ ],
+)
+def test_connection_type(conn, type, mypy):
+ stmts = f"obj = {conn}"
+ _test_reveal_crdb(stmts, type, mypy)
+
+
+def _test_reveal_crdb(stmts, type, mypy):
+ stmts = f"""\
+import psycopg.crdb
+{stmts}
+"""
+ _test_reveal(stmts, type, mypy)