From: Daniele Varrazzo Date: Tue, 24 May 2022 10:35:03 +0000 (+0200) Subject: fix: use a generic return type to return self X-Git-Tag: 3.1~65 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c3dbd462bd0d024e978041eaf7ae04f17b1d21ae;p=thirdparty%2Fpsycopg.git fix: use a generic return type to return self Standardize the var name to _Self, waiting for PEP 0673. --- diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py index 787184860..fbbe97c23 100644 --- a/psycopg/psycopg/_pipeline.py +++ b/psycopg/psycopg/_pipeline.py @@ -6,7 +6,7 @@ commands pipeline management import logging from types import TracebackType -from typing import Any, List, Optional, Union, Tuple, Type, TYPE_CHECKING +from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING from . import pq from . import errors as e @@ -178,6 +178,7 @@ class Pipeline(BasePipeline): __module__ = "psycopg" _conn: "Connection[Any]" + _Self = TypeVar("_Self", bound="Pipeline") def __init__(self, conn: "Connection[Any]") -> None: super().__init__(conn) @@ -192,7 +193,7 @@ class Pipeline(BasePipeline): except e.Error as ex: raise ex.with_traceback(None) - def __enter__(self) -> "Pipeline": + def __enter__(self: _Self) -> _Self: with self._conn.lock: self._conn.wait(self._enter_gen()) return self @@ -230,6 +231,7 @@ class AsyncPipeline(BasePipeline): __module__ = "psycopg" _conn: "AsyncConnection[Any]" + _Self = TypeVar("_Self", bound="AsyncPipeline") def __init__(self, conn: "AsyncConnection[Any]") -> None: super().__init__(conn) @@ -241,7 +243,7 @@ class AsyncPipeline(BasePipeline): except e.Error as ex: raise ex.with_traceback(None) - async def __aenter__(self) -> "AsyncPipeline": + async def __aenter__(self: _Self) -> _Self: async with self._conn.lock: await self._conn.wait(self._enter_gen()) return self diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 58cb63c6a..daeae3e0f 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -656,8 +656,8 @@ class Connection(BaseConnection[Row]): cursor_factory: Type[Cursor[Row]] server_cursor_factory: Type[ServerCursor[Row]] row_factory: RowFactory[Row] - _pipeline: Optional[Pipeline] + _Self = TypeVar("_Self", bound="Connection[Row]") def __init__( self, @@ -683,6 +683,7 @@ class Connection(BaseConnection[Row]): context: Optional[AdaptContext] = None, **kwargs: Union[None, int, str], ) -> "Connection[Row]": + # TODO: returned type should be _Self. See #308. ... @overload @@ -734,7 +735,7 @@ class Connection(BaseConnection[Row]): rv.prepare_threshold = prepare_threshold return rv - def __enter__(self) -> "Connection[Row]": + def __enter__(self: _Self) -> _Self: return self def __exit__( diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 3a2bc91af..301ffac9c 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -9,7 +9,7 @@ import asyncio import logging from types import TracebackType from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional -from typing import Type, Union, cast, overload, TYPE_CHECKING +from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING from contextlib import asynccontextmanager from . import pq @@ -52,6 +52,7 @@ class AsyncConnection(BaseConnection[Row]): server_cursor_factory: Type[AsyncServerCursor[Row]] row_factory: AsyncRowFactory[Row] _pipeline: Optional[AsyncPipeline] + _Self = TypeVar("_Self", bound="AsyncConnection[Row]") def __init__( self, @@ -77,6 +78,7 @@ class AsyncConnection(BaseConnection[Row]): context: Optional[AdaptContext] = None, **kwargs: Union[None, int, str], ) -> "AsyncConnection[Row]": + # TODO: returned type should be _Self. See #308. ... @overload @@ -136,7 +138,7 @@ class AsyncConnection(BaseConnection[Row]): rv.prepare_threshold = prepare_threshold return rv - async def __aenter__(self) -> "AsyncConnection[Row]": + async def __aenter__(self: _Self) -> _Self: return self async def __aexit__( diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index b1fe6dc7d..266df21f9 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -11,8 +11,8 @@ import asyncio import threading from abc import ABC, abstractmethod from types import TracebackType -from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union -from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple +from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match +from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING from . import pq from . import errors as e @@ -56,6 +56,8 @@ class BaseCopy(Generic[ConnectionType]): formatting the data in copy format and adding it to the queue. """ + _Self = TypeVar("_Self", bound="BaseCopy[ConnectionType]") + # Max size of the write queue of buffers. More than that copy will block # Each buffer should be around BUFFER_SIZE size. QUEUE_SIZE = 1024 @@ -204,7 +206,7 @@ class Copy(BaseCopy["Connection[Any]"]): self._worker: Optional[threading.Thread] = None self._worker_error: Optional[BaseException] = None - def __enter__(self) -> "Copy": + def __enter__(self: BaseCopy._Self) -> BaseCopy._Self: self._enter() return self @@ -354,7 +356,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.QUEUE_SIZE) self._worker: Optional[asyncio.Future[None]] = None - async def __aenter__(self) -> "AsyncCopy": + async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self: self._enter() return self diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 30c41d8ea..501cd8c60 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -41,8 +41,6 @@ else: fetch = generators.fetch send = generators.send -_C = TypeVar("_C", bound="Cursor[Any]") - TEXT = pq.Format.TEXT BINARY = pq.Format.BINARY @@ -657,6 +655,7 @@ class BaseCursor(Generic[ConnectionType, Row]): class Cursor(BaseCursor["Connection[Any]", Row]): __module__ = "psycopg" __slots__ = () + _Self = TypeVar("_Self", bound="Cursor[Row]") @overload def __init__(self: "Cursor[Row]", connection: "Connection[Row]"): @@ -680,7 +679,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): super().__init__(connection) self._row_factory = row_factory or connection.row_factory - def __enter__(self: _C) -> _C: + def __enter__(self: _Self) -> _Self: return self def __exit__( @@ -712,13 +711,13 @@ class Cursor(BaseCursor["Connection[Any]", Row]): return self._row_factory(self) def execute( - self: _C, + self: _Self, query: Query, params: Optional[Params] = None, *, prepare: Optional[bool] = None, binary: Optional[bool] = None, - ) -> _C: + ) -> _Self: """ Execute a query or command to the database. """ diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 5044732f8..0be29a04f 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -20,12 +20,11 @@ from ._pipeline import Pipeline if TYPE_CHECKING: from .connection_async import AsyncConnection -_C = TypeVar("_C", bound="AsyncCursor[Any]") - class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): __module__ = "psycopg" __slots__ = () + _Self = TypeVar("_Self", bound="AsyncCursor[Row]") @overload def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"): @@ -49,7 +48,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): super().__init__(connection) self._row_factory = row_factory or connection.row_factory - async def __aenter__(self: _C) -> _C: + async def __aenter__(self: _Self) -> _Self: return self async def __aexit__( @@ -77,13 +76,13 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): return self._row_factory(self) async def execute( - self: _C, + self: _Self, query: Query, params: Optional[Params] = None, *, prepare: Optional[bool] = None, binary: Optional[bool] = None, - ) -> _C: + ) -> _Self: try: async with self._conn.lock: await self._conn.wait( diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py index 65db17212..53585b49b 100644 --- a/psycopg/psycopg/server_cursor.py +++ b/psycopg/psycopg/server_cursor.py @@ -194,13 +194,10 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]): return sql.SQL(" ").join(parts) -_C = TypeVar("_C", bound="ServerCursor[Any]") -_AC = TypeVar("_AC", bound="AsyncServerCursor[Any]") - - class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]): __module__ = "psycopg" __slots__ = () + _Self = TypeVar("_Self", bound="ServerCursor[Row]") @overload def __init__( @@ -259,13 +256,13 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]): super().close() def execute( - self: _C, + self: _Self, query: Query, params: Optional[Params] = None, *, binary: Optional[bool] = None, **kwargs: Any, - ) -> _C: + ) -> _Self: """ Open a cursor to execute a query to the database. """ @@ -342,6 +339,7 @@ class AsyncServerCursor( ): __module__ = "psycopg" __slots__ = () + _Self = TypeVar("_Self", bound="AsyncServerCursor[Row]") @overload def __init__( @@ -397,13 +395,13 @@ class AsyncServerCursor( await super().close() async def execute( - self: _AC, + self: _Self, query: Query, params: Optional[Params] = None, *, binary: Optional[bool] = None, **kwargs: Any, - ) -> _AC: + ) -> _Self: if kwargs: raise TypeError(f"keyword not supported: {list(kwargs)[0]}") if self._pgconn.pipeline_status: diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py index e5c1514b5..e13486e98 100644 --- a/psycopg/psycopg/transaction.py +++ b/psycopg/psycopg/transaction.py @@ -7,7 +7,7 @@ Transaction context managers returned by Connection.transaction() import logging from types import TracebackType -from typing import Generic, Iterator, Optional, Type, Union, TYPE_CHECKING +from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING from . import pq from . import sql @@ -234,12 +234,14 @@ class Transaction(BaseTransaction["Connection[Any]"]): __module__ = "psycopg" + _Self = TypeVar("_Self", bound="Transaction") + @property def connection(self) -> "Connection[Any]": """The connection the object is managing.""" return self._conn - def __enter__(self) -> "Transaction": + def __enter__(self: _Self) -> _Self: with self._conn.lock: self._conn.wait(self._enter_gen()) return self @@ -264,11 +266,13 @@ class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]): __module__ = "psycopg" + _Self = TypeVar("_Self", bound="AsyncTransaction") + @property def connection(self) -> "AsyncConnection[Any]": return self._conn - async def __aenter__(self) -> "AsyncTransaction": + async def __aenter__(self: _Self) -> _Self: async with self._conn.lock: await self._conn.wait(self._enter_gen()) return self