From: Daniele Varrazzo Date: Thu, 7 Oct 2021 18:29:45 +0000 (+0200) Subject: Improve typing definition for cursor execute/enter X-Git-Tag: 3.0~16 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f343897842188aa970d03c89d29790522f7ec013;p=thirdparty%2Fpsycopg.git Improve typing definition for cursor execute/enter This is more redundant than using AnyCursor but it respects better the interface returned. --- diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 86272d8fa..36d50fd0f 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -7,7 +7,7 @@ psycopg cursor objects import sys from types import TracebackType from typing import Any, Callable, Generic, Iterator, List -from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING, TypeVar +from typing import Optional, NoReturn, Sequence, Type, TypeVar, TYPE_CHECKING from contextlib import contextmanager from . import pq @@ -43,6 +43,8 @@ else: fetch = generators.fetch send = generators.send +_C = TypeVar("_C", bound="Cursor[Any]") + class BaseCursor(Generic[ConnectionType, Row]): # Slots with __weakref__ and generic bases don't work on Py 3.6 @@ -491,9 +493,6 @@ class BaseCursor(Generic[ConnectionType, Row]): self._closed = True -AnyCursor = TypeVar("AnyCursor", bound="BaseCursor[Any, Any]") - - class Cursor(BaseCursor["Connection[Any]", Row]): __module__ = "psycopg" __slots__ = () @@ -504,7 +503,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): super().__init__(connection) self._row_factory = row_factory - def __enter__(self: AnyCursor) -> AnyCursor: + def __enter__(self: _C) -> _C: return self def __exit__( @@ -536,13 +535,13 @@ class Cursor(BaseCursor["Connection[Any]", Row]): return self._row_factory(self) def execute( - self: AnyCursor, + self: _C, query: Query, params: Optional[Params] = None, *, prepare: Optional[bool] = None, binary: Optional[bool] = None, - ) -> AnyCursor: + ) -> _C: """ Execute a query or command to the database. """ diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index df263542e..db573c0c7 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -6,19 +6,21 @@ psycopg async cursor objects from types import TracebackType from typing import Any, AsyncIterator, List -from typing import Optional, Sequence, Type, TYPE_CHECKING +from typing import Optional, Sequence, Type, TypeVar, TYPE_CHECKING from . import errors as e from .abc import Query, Params from .copy import AsyncCopy from .rows import Row, RowMaker, AsyncRowFactory -from .cursor import BaseCursor, AnyCursor +from .cursor import BaseCursor from ._compat import asynccontextmanager if TYPE_CHECKING: from .connection_async import AsyncConnection +_C = TypeVar("_C", bound="AsyncCursor[Any]") + class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): __module__ = "psycopg" @@ -33,7 +35,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): super().__init__(connection) self._row_factory = row_factory - async def __aenter__(self: AnyCursor) -> AnyCursor: + async def __aenter__(self: _C) -> _C: return self async def __aexit__( @@ -61,13 +63,13 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): return self._row_factory(self) async def execute( - self: AnyCursor, + self: _C, query: Query, params: Optional[Params] = None, *, prepare: Optional[bool] = None, binary: Optional[bool] = None, - ) -> AnyCursor: + ) -> _C: try: async with self._conn.lock: await self._conn.wait( diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py index 1485b43ff..d936b7b6f 100644 --- a/psycopg/psycopg/server_cursor.py +++ b/psycopg/psycopg/server_cursor.py @@ -5,15 +5,15 @@ psycopg server-side cursor objects. # Copyright (C) 2020-2021 The Psycopg Team import warnings -from typing import Any, AsyncIterator, cast, Generic, List, Iterator, Optional -from typing import Sequence, TYPE_CHECKING +from typing import Any, AsyncIterator, Generic, List, Iterator, Optional +from typing import Sequence, TypeVar, TYPE_CHECKING from . import pq from . import sql from . import errors as e from .abc import ConnectionType, Query, Params, PQGen from .rows import Row, RowFactory, AsyncRowFactory -from .cursor import AnyCursor, BaseCursor, Cursor, execute +from .cursor import BaseCursor, Cursor, execute from .cursor_async import AsyncCursor from ._encodings import pgconn_encoding @@ -178,6 +178,10 @@ class ServerCursorHelper(Generic[ConnectionType, Row]): return sql.SQL(" ").join(parts) +_C = TypeVar("_C", bound="ServerCursor[Any]") +_AC = TypeVar("_AC", bound="AsyncServerCursor[Any]") + + class ServerCursor(Cursor[Row]): __module__ = "psycopg" __slots__ = ("_helper", "itersize") @@ -240,19 +244,19 @@ class ServerCursor(Cursor[Row]): super().close() def execute( - self: AnyCursor, + self: _C, query: Query, params: Optional[Params] = None, *, binary: Optional[bool] = None, **kwargs: Any, - ) -> AnyCursor: + ) -> _C: """ Open a cursor to execute a query to the database. """ if kwargs: raise TypeError(f"keyword not supported: {list(kwargs)[0]}") - helper = cast(ServerCursor[Row], self)._helper + helper = self._helper query = helper._make_declare_statement(self, query) if binary is None: @@ -364,16 +368,16 @@ class AsyncServerCursor(AsyncCursor[Row]): await super().close() async def execute( - self: AnyCursor, + self: _AC, query: Query, params: Optional[Params] = None, *, binary: Optional[bool] = None, **kwargs: Any, - ) -> AnyCursor: + ) -> _AC: if kwargs: raise TypeError(f"keyword not supported: {list(kwargs)[0]}") - helper = cast(AsyncServerCursor[Row], self)._helper + helper = self._helper query = helper._make_declare_statement(self, query) if binary is None: