]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: define the Row TypeVar as defaulting to TupleRow
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 3 Jan 2024 00:44:22 +0000 (01:44 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 3 Jan 2024 02:51:11 +0000 (03:51 +0100)
This allows to return `Self` uniformly from the `Connection.connect()` class
method, which in turns allows to subclass the connection without the
need of redefining the complex signature.

Close #308

psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/crdb/connection.py
psycopg/psycopg/rows.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
tests/test_typing.py

index 4bafd5ea1ec29c6e951497ede6c3db9b1dd9ffef..d0e6391a3f6b885382b92d8e1d8450d20d966673 100644 (file)
@@ -20,7 +20,7 @@ from . import errors as e
 from . import waiting
 from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
 from ._tpc import Xid
-from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
+from .rows import Row, RowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from ._compat import Self
@@ -84,10 +84,7 @@ class Connection(BaseConnection[Row]):
         cursor_factory: Optional[Type[Cursor[Row]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
-    ) -> Connection[Row]:
-        # TODO: returned type should be Self. See #308.
-        # Unfortunately we cannot use Self[Row] as Self is not parametric.
-        # https://peps.python.org/pep-0673/#use-in-generic-classes
+    ) -> Self:
         ...
 
     @overload
@@ -101,7 +98,7 @@ class Connection(BaseConnection[Row]):
         cursor_factory: Optional[Type[Cursor[Any]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
-    ) -> Connection[TupleRow]:
+    ) -> Self:
         ...
 
     @classmethod  # type: ignore[misc] # https://github.com/python/mypy/issues/11004
index 9d8dad8734e0edfde7251b2ba73fdf4a82514a7f..ec487e483dc0bff13123a14ca5cd1e1cb464b896 100644 (file)
@@ -17,7 +17,7 @@ from . import errors as e
 from . import waiting
 from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
 from ._tpc import Xid
-from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
+from .rows import Row, AsyncRowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from ._compat import Self
@@ -90,10 +90,7 @@ class AsyncConnection(BaseConnection[Row]):
         cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
-    ) -> AsyncConnection[Row]:
-        # TODO: returned type should be Self. See #308.
-        # Unfortunately we cannot use Self[Row] as Self is not parametric.
-        # https://peps.python.org/pep-0673/#use-in-generic-classes
+    ) -> Self:
         ...
 
     @overload
@@ -107,7 +104,7 @@ class AsyncConnection(BaseConnection[Row]):
         cursor_factory: Optional[Type[AsyncCursor[Any]]] = None,
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
-    ) -> AsyncConnection[TupleRow]:
+    ) -> Self:
         ...
 
     @classmethod  # type: ignore[misc] # https://github.com/python/mypy/issues/11004
index 49b7d5ffa3a465feecac4f31f12efebcac9e009c..d88b63e46d7ff71b26542bc45e96d10f3b53bdc1 100644 (file)
@@ -5,12 +5,10 @@ CockroachDB-specific connections.
 # Copyright (C) 2022 The Psycopg Team
 
 import re
-from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
+from typing import Any, Optional, Union, TYPE_CHECKING
 
 from .. import errors as e
-from ..abc import AdaptContext
-from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow
-from .._compat import Self
+from ..rows import Row
 from ..conninfo import ConnectionInfo
 from ..connection import Connection
 from .._adapters_map import AdaptersMap
@@ -19,8 +17,6 @@ from ._types import adapters
 
 if TYPE_CHECKING:
     from ..pq.abc import PGconn
-    from ..cursor import Cursor
-    from ..cursor_async import AsyncCursor
 
 
 class _CrdbConnectionMixin:
@@ -63,45 +59,6 @@ class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
 
     __module__ = "psycopg.crdb"
 
-    # TODO: this method shouldn't require re-definition if the base class
-    # implements a generic self.
-    # https://github.com/psycopg/psycopg/issues/308
-    @overload
-    @classmethod
-    def connect(
-        cls,
-        conninfo: str = "",
-        *,
-        autocommit: bool = False,
-        row_factory: RowFactory[Row],
-        prepare_threshold: Optional[int] = 5,
-        cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
-        context: Optional[AdaptContext] = None,
-        **kwargs: Union[None, int, str],
-    ) -> "CrdbConnection[Row]":
-        ...
-
-    @overload
-    @classmethod
-    def connect(
-        cls,
-        conninfo: str = "",
-        *,
-        autocommit: bool = False,
-        prepare_threshold: Optional[int] = 5,
-        cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
-        context: Optional[AdaptContext] = None,
-        **kwargs: Union[None, int, str],
-    ) -> "CrdbConnection[TupleRow]":
-        ...
-
-    @classmethod
-    def connect(cls, conninfo: str = "", **kwargs: Any) -> Self:
-        """
-        Connect to a database server and return a new `CrdbConnection` instance.
-        """
-        return super().connect(conninfo, **kwargs)  # type: ignore[return-value]
-
 
 class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
     """
@@ -110,42 +67,6 @@ class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
 
     __module__ = "psycopg.crdb"
 
-    # TODO: this method shouldn't require re-definition if the base class
-    # implements a generic self.
-    # https://github.com/psycopg/psycopg/issues/308
-    @overload
-    @classmethod
-    async def connect(
-        cls,
-        conninfo: str = "",
-        *,
-        autocommit: bool = False,
-        prepare_threshold: Optional[int] = 5,
-        row_factory: AsyncRowFactory[Row],
-        cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
-        context: Optional[AdaptContext] = None,
-        **kwargs: Union[None, int, str],
-    ) -> "AsyncCrdbConnection[Row]":
-        ...
-
-    @overload
-    @classmethod
-    async def connect(
-        cls,
-        conninfo: str = "",
-        *,
-        autocommit: bool = False,
-        prepare_threshold: Optional[int] = 5,
-        cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
-        context: Optional[AdaptContext] = None,
-        **kwargs: Union[None, int, str],
-    ) -> "AsyncCrdbConnection[TupleRow]":
-        ...
-
-    @classmethod
-    async def connect(cls, conninfo: str = "", **kwargs: Any) -> Self:
-        return await super().connect(conninfo, **kwargs)  # type: ignore[no-any-return]
-
 
 class CrdbConnectionInfo(ConnectionInfo):
     """
index d0d834864face9bf06dde87bd297d8dc6143810c..4c2f7781b5eba777b4df4c852c1c196a948a3e07 100644 (file)
@@ -29,7 +29,7 @@ T = TypeVar("T", covariant=True)
 
 # Row factories
 
-Row = TypeVar("Row", covariant=True)
+Row = TypeVar("Row", covariant=True, default="TupleRow")
 
 
 class RowMaker(Protocol[Row]):
index 1e298518bbd4c873c719895dfaa5d0741cf9306e..45e2184474fc089f85f9a74c3b8fc5d69956b83c 100644 (file)
@@ -625,7 +625,7 @@ class ConnectionPool(Generic[CT], BasePool):
             kwargs["connect_timeout"] = max(round(timeout), 1)
         t0 = monotonic()
         try:
-            conn: CT = cast(CT, self.connection_class.connect(self.conninfo, **kwargs))
+            conn = self.connection_class.connect(self.conninfo, **kwargs)
         except Exception:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
index d53243ff91187080e0cf0c61fea78262ca6d15a4..e3f7e113cfb98155459428e555d3589ae206e4a7 100644 (file)
@@ -669,9 +669,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             kwargs["connect_timeout"] = max(round(timeout), 1)
         t0 = monotonic()
         try:
-            conn: ACT = cast(
-                ACT, await self.connection_class.connect(self.conninfo, **kwargs)
-            )
+            conn = await self.connection_class.connect(self.conninfo, **kwargs)
         except Exception:
             self._stats[self._CONNECTIONS_ERRORS] += 1
             raise
index fff9cec25b577bae8499b7d10c7bc79222d2dcfa..efafd0bb4ca6bfac241bcef21c991a2eb2ab62b6 100644 (file)
@@ -409,7 +409,6 @@ reveal_type(ref)
     assert got == want
 
 
-@pytest.mark.xfail(reason="https://github.com/psycopg/psycopg/issues/308")
 @pytest.mark.parametrize(
     "conn, type",
     [