]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Improve typing definition for cursor execute/enter
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 7 Oct 2021 18:29:45 +0000 (20:29 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 7 Oct 2021 20:45:33 +0000 (22:45 +0200)
This is more redundant than using AnyCursor but it respects better the
interface returned.

psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/server_cursor.py

index 86272d8fa677182e6502964d99bd25d69aa0e2f3..36d50fd0ff1243b04f5c90171c25e1f17e5c631f 100644 (file)
@@ -7,7 +7,7 @@ psycopg cursor objects
 import sys
 from types import TracebackType
 from typing import Any, Callable, Generic, Iterator, List
-from typing import Optional, NoReturn, Sequence, Type, TYPE_CHECKING, TypeVar
+from typing import Optional, NoReturn, Sequence, Type, TypeVar, TYPE_CHECKING
 from contextlib import contextmanager
 
 from . import pq
@@ -43,6 +43,8 @@ else:
     fetch = generators.fetch
     send = generators.send
 
+_C = TypeVar("_C", bound="Cursor[Any]")
+
 
 class BaseCursor(Generic[ConnectionType, Row]):
     # Slots with __weakref__ and generic bases don't work on Py 3.6
@@ -491,9 +493,6 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._closed = True
 
 
-AnyCursor = TypeVar("AnyCursor", bound="BaseCursor[Any, Any]")
-
-
 class Cursor(BaseCursor["Connection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
@@ -504,7 +503,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         super().__init__(connection)
         self._row_factory = row_factory
 
-    def __enter__(self: AnyCursor) -> AnyCursor:
+    def __enter__(self: _C) -> _C:
         return self
 
     def __exit__(
@@ -536,13 +535,13 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         return self._row_factory(self)
 
     def execute(
-        self: AnyCursor,
+        self: _C,
         query: Query,
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
         binary: Optional[bool] = None,
-    ) -> AnyCursor:
+    ) -> _C:
         """
         Execute a query or command to the database.
         """
index df263542ee158f0d53a91df215932b490665b6ed..db573c0c7841447000f9869acc9c4bb82bee0a24 100644 (file)
@@ -6,19 +6,21 @@ psycopg async cursor objects
 
 from types import TracebackType
 from typing import Any, AsyncIterator, List
-from typing import Optional, Sequence, Type, TYPE_CHECKING
+from typing import Optional, Sequence, Type, TypeVar, TYPE_CHECKING
 
 from . import errors as e
 
 from .abc import Query, Params
 from .copy import AsyncCopy
 from .rows import Row, RowMaker, AsyncRowFactory
-from .cursor import BaseCursor, AnyCursor
+from .cursor import BaseCursor
 from ._compat import asynccontextmanager
 
 if TYPE_CHECKING:
     from .connection_async import AsyncConnection
 
+_C = TypeVar("_C", bound="AsyncCursor[Any]")
+
 
 class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     __module__ = "psycopg"
@@ -33,7 +35,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         super().__init__(connection)
         self._row_factory = row_factory
 
-    async def __aenter__(self: AnyCursor) -> AnyCursor:
+    async def __aenter__(self: _C) -> _C:
         return self
 
     async def __aexit__(
@@ -61,13 +63,13 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         return self._row_factory(self)
 
     async def execute(
-        self: AnyCursor,
+        self: _C,
         query: Query,
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
         binary: Optional[bool] = None,
-    ) -> AnyCursor:
+    ) -> _C:
         try:
             async with self._conn.lock:
                 await self._conn.wait(
index 1485b43ffd18c828e55751ac2f713677089a54a6..d936b7b6f8a447e7fb63424b2376f2db146ba16d 100644 (file)
@@ -5,15 +5,15 @@ psycopg server-side cursor objects.
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import warnings
-from typing import Any, AsyncIterator, cast, Generic, List, Iterator, Optional
-from typing import Sequence, TYPE_CHECKING
+from typing import Any, AsyncIterator, Generic, List, Iterator, Optional
+from typing import Sequence, TypeVar, TYPE_CHECKING
 
 from . import pq
 from . import sql
 from . import errors as e
 from .abc import ConnectionType, Query, Params, PQGen
 from .rows import Row, RowFactory, AsyncRowFactory
-from .cursor import AnyCursor, BaseCursor, Cursor, execute
+from .cursor import BaseCursor, Cursor, execute
 from .cursor_async import AsyncCursor
 from ._encodings import pgconn_encoding
 
@@ -178,6 +178,10 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         return sql.SQL(" ").join(parts)
 
 
+_C = TypeVar("_C", bound="ServerCursor[Any]")
+_AC = TypeVar("_AC", bound="AsyncServerCursor[Any]")
+
+
 class ServerCursor(Cursor[Row]):
     __module__ = "psycopg"
     __slots__ = ("_helper", "itersize")
@@ -240,19 +244,19 @@ class ServerCursor(Cursor[Row]):
             super().close()
 
     def execute(
-        self: AnyCursor,
+        self: _C,
         query: Query,
         params: Optional[Params] = None,
         *,
         binary: Optional[bool] = None,
         **kwargs: Any,
-    ) -> AnyCursor:
+    ) -> _C:
         """
         Open a cursor to execute a query to the database.
         """
         if kwargs:
             raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
-        helper = cast(ServerCursor[Row], self)._helper
+        helper = self._helper
         query = helper._make_declare_statement(self, query)
 
         if binary is None:
@@ -364,16 +368,16 @@ class AsyncServerCursor(AsyncCursor[Row]):
             await super().close()
 
     async def execute(
-        self: AnyCursor,
+        self: _AC,
         query: Query,
         params: Optional[Params] = None,
         *,
         binary: Optional[bool] = None,
         **kwargs: Any,
-    ) -> AnyCursor:
+    ) -> _AC:
         if kwargs:
             raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
-        helper = cast(AsyncServerCursor[Row], self)._helper
+        helper = self._helper
         query = helper._make_declare_statement(self, query)
 
         if binary is None: