]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: use a generic return type to return self
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 May 2022 10:35:03 +0000 (12:35 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Jun 2022 08:48:47 +0000 (10:48 +0200)
Standardize the var name to _Self, waiting for PEP 0673.

psycopg/psycopg/_pipeline.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/copy.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/server_cursor.py
psycopg/psycopg/transaction.py

index 787184860c05f14b5c1965b57026f5fe412b0166..fbbe97c23a654e9a79cfda31cd689a11bf386e6a 100644 (file)
@@ -6,7 +6,7 @@ commands pipeline management
 
 import logging
 from types import TracebackType
-from typing import Any, List, Optional, Union, Tuple, Type, TYPE_CHECKING
+from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING
 
 from . import pq
 from . import errors as e
@@ -178,6 +178,7 @@ class Pipeline(BasePipeline):
 
     __module__ = "psycopg"
     _conn: "Connection[Any]"
+    _Self = TypeVar("_Self", bound="Pipeline")
 
     def __init__(self, conn: "Connection[Any]") -> None:
         super().__init__(conn)
@@ -192,7 +193,7 @@ class Pipeline(BasePipeline):
         except e.Error as ex:
             raise ex.with_traceback(None)
 
-    def __enter__(self) -> "Pipeline":
+    def __enter__(self: _Self) -> _Self:
         with self._conn.lock:
             self._conn.wait(self._enter_gen())
         return self
@@ -230,6 +231,7 @@ class AsyncPipeline(BasePipeline):
 
     __module__ = "psycopg"
     _conn: "AsyncConnection[Any]"
+    _Self = TypeVar("_Self", bound="AsyncPipeline")
 
     def __init__(self, conn: "AsyncConnection[Any]") -> None:
         super().__init__(conn)
@@ -241,7 +243,7 @@ class AsyncPipeline(BasePipeline):
         except e.Error as ex:
             raise ex.with_traceback(None)
 
-    async def __aenter__(self) -> "AsyncPipeline":
+    async def __aenter__(self: _Self) -> _Self:
         async with self._conn.lock:
             await self._conn.wait(self._enter_gen())
         return self
index 58cb63c6a7958a70a9518dd566e849ce05a8d6ec..daeae3e0f642eb28dc8e06f5623da76c2687b7ef 100644 (file)
@@ -656,8 +656,8 @@ class Connection(BaseConnection[Row]):
     cursor_factory: Type[Cursor[Row]]
     server_cursor_factory: Type[ServerCursor[Row]]
     row_factory: RowFactory[Row]
-
     _pipeline: Optional[Pipeline]
+    _Self = TypeVar("_Self", bound="Connection[Row]")
 
     def __init__(
         self,
@@ -683,6 +683,7 @@ class Connection(BaseConnection[Row]):
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> "Connection[Row]":
+        # TODO: returned type should be _Self. See #308.
         ...
 
     @overload
@@ -734,7 +735,7 @@ class Connection(BaseConnection[Row]):
         rv.prepare_threshold = prepare_threshold
         return rv
 
-    def __enter__(self) -> "Connection[Row]":
+    def __enter__(self: _Self) -> _Self:
         return self
 
     def __exit__(
index 3a2bc91afac624c1c383c364c18361ab9b3843ae..301ffac9ca76911d1a4000845a9cde9ca719d529 100644 (file)
@@ -9,7 +9,7 @@ import asyncio
 import logging
 from types import TracebackType
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
-from typing import Type, Union, cast, overload, TYPE_CHECKING
+from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
 from contextlib import asynccontextmanager
 
 from . import pq
@@ -52,6 +52,7 @@ class AsyncConnection(BaseConnection[Row]):
     server_cursor_factory: Type[AsyncServerCursor[Row]]
     row_factory: AsyncRowFactory[Row]
     _pipeline: Optional[AsyncPipeline]
+    _Self = TypeVar("_Self", bound="AsyncConnection[Row]")
 
     def __init__(
         self,
@@ -77,6 +78,7 @@ class AsyncConnection(BaseConnection[Row]):
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> "AsyncConnection[Row]":
+        # TODO: returned type should be _Self. See #308.
         ...
 
     @overload
@@ -136,7 +138,7 @@ class AsyncConnection(BaseConnection[Row]):
         rv.prepare_threshold = prepare_threshold
         return rv
 
-    async def __aenter__(self) -> "AsyncConnection[Row]":
+    async def __aenter__(self: _Self) -> _Self:
         return self
 
     async def __aexit__(
index b1fe6dc7d603446671055e4112166b7ffc5315d4..266df21f99002a8478670312d4eb05733d181c1c 100644 (file)
@@ -11,8 +11,8 @@ import asyncio
 import threading
 from abc import ABC, abstractmethod
 from types import TracebackType
-from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
-from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
+from typing import Any, AsyncIterator, Dict, Generic, Iterator, List, Match
+from typing import Optional, Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
 
 from . import pq
 from . import errors as e
@@ -56,6 +56,8 @@ class BaseCopy(Generic[ConnectionType]):
     formatting the data in copy format and adding it to the queue.
     """
 
+    _Self = TypeVar("_Self", bound="BaseCopy[ConnectionType]")
+
     # Max size of the write queue of buffers. More than that copy will block
     # Each buffer should be around BUFFER_SIZE size.
     QUEUE_SIZE = 1024
@@ -204,7 +206,7 @@ class Copy(BaseCopy["Connection[Any]"]):
         self._worker: Optional[threading.Thread] = None
         self._worker_error: Optional[BaseException] = None
 
-    def __enter__(self) -> "Copy":
+    def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
         self._enter()
         return self
 
@@ -354,7 +356,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.QUEUE_SIZE)
         self._worker: Optional[asyncio.Future[None]] = None
 
-    async def __aenter__(self) -> "AsyncCopy":
+    async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
         self._enter()
         return self
 
index 30c41d8ea31668e79a0389ca3b6ac60b8b29a42a..501cd8c60262abfd58e05974792fff902784d35c 100644 (file)
@@ -41,8 +41,6 @@ else:
     fetch = generators.fetch
     send = generators.send
 
-_C = TypeVar("_C", bound="Cursor[Any]")
-
 TEXT = pq.Format.TEXT
 BINARY = pq.Format.BINARY
 
@@ -657,6 +655,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
 class Cursor(BaseCursor["Connection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
+    _Self = TypeVar("_Self", bound="Cursor[Row]")
 
     @overload
     def __init__(self: "Cursor[Row]", connection: "Connection[Row]"):
@@ -680,7 +679,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         super().__init__(connection)
         self._row_factory = row_factory or connection.row_factory
 
-    def __enter__(self: _C) -> _C:
+    def __enter__(self: _Self) -> _Self:
         return self
 
     def __exit__(
@@ -712,13 +711,13 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         return self._row_factory(self)
 
     def execute(
-        self: _C,
+        self: _Self,
         query: Query,
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
         binary: Optional[bool] = None,
-    ) -> _C:
+    ) -> _Self:
         """
         Execute a query or command to the database.
         """
index 5044732f8219cb3fabfc25c3c41ab71e2d114a9f..0be29a04feaffc389dcd54cb9284777aa1c4af97 100644 (file)
@@ -20,12 +20,11 @@ from ._pipeline import Pipeline
 if TYPE_CHECKING:
     from .connection_async import AsyncConnection
 
-_C = TypeVar("_C", bound="AsyncCursor[Any]")
-
 
 class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
+    _Self = TypeVar("_Self", bound="AsyncCursor[Row]")
 
     @overload
     def __init__(self: "AsyncCursor[Row]", connection: "AsyncConnection[Row]"):
@@ -49,7 +48,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         super().__init__(connection)
         self._row_factory = row_factory or connection.row_factory
 
-    async def __aenter__(self: _C) -> _C:
+    async def __aenter__(self: _Self) -> _Self:
         return self
 
     async def __aexit__(
@@ -77,13 +76,13 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         return self._row_factory(self)
 
     async def execute(
-        self: _C,
+        self: _Self,
         query: Query,
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
         binary: Optional[bool] = None,
-    ) -> _C:
+    ) -> _Self:
         try:
             async with self._conn.lock:
                 await self._conn.wait(
index 65db17212b9fe40b14b7dbaca5d70d85566ab39b..53585b49b409fe6289cd592253b3bf338579e751 100644 (file)
@@ -194,13 +194,10 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
         return sql.SQL(" ").join(parts)
 
 
-_C = TypeVar("_C", bound="ServerCursor[Any]")
-_AC = TypeVar("_AC", bound="AsyncServerCursor[Any]")
-
-
 class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
     __module__ = "psycopg"
     __slots__ = ()
+    _Self = TypeVar("_Self", bound="ServerCursor[Row]")
 
     @overload
     def __init__(
@@ -259,13 +256,13 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
             super().close()
 
     def execute(
-        self: _C,
+        self: _Self,
         query: Query,
         params: Optional[Params] = None,
         *,
         binary: Optional[bool] = None,
         **kwargs: Any,
-    ) -> _C:
+    ) -> _Self:
         """
         Open a cursor to execute a query to the database.
         """
@@ -342,6 +339,7 @@ class AsyncServerCursor(
 ):
     __module__ = "psycopg"
     __slots__ = ()
+    _Self = TypeVar("_Self", bound="AsyncServerCursor[Row]")
 
     @overload
     def __init__(
@@ -397,13 +395,13 @@ class AsyncServerCursor(
             await super().close()
 
     async def execute(
-        self: _AC,
+        self: _Self,
         query: Query,
         params: Optional[Params] = None,
         *,
         binary: Optional[bool] = None,
         **kwargs: Any,
-    ) -> _AC:
+    ) -> _Self:
         if kwargs:
             raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
         if self._pgconn.pipeline_status:
index e5c1514b56175780505629ae6f1ba69b4a21b323..e13486e984c91beccaf8bc73fcdb6e74e31d3d51 100644 (file)
@@ -7,7 +7,7 @@ Transaction context managers returned by Connection.transaction()
 import logging
 
 from types import TracebackType
-from typing import Generic, Iterator, Optional, Type, Union, TYPE_CHECKING
+from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING
 
 from . import pq
 from . import sql
@@ -234,12 +234,14 @@ class Transaction(BaseTransaction["Connection[Any]"]):
 
     __module__ = "psycopg"
 
+    _Self = TypeVar("_Self", bound="Transaction")
+
     @property
     def connection(self) -> "Connection[Any]":
         """The connection the object is managing."""
         return self._conn
 
-    def __enter__(self) -> "Transaction":
+    def __enter__(self: _Self) -> _Self:
         with self._conn.lock:
             self._conn.wait(self._enter_gen())
         return self
@@ -264,11 +266,13 @@ class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
 
     __module__ = "psycopg"
 
+    _Self = TypeVar("_Self", bound="AsyncTransaction")
+
     @property
     def connection(self) -> "AsyncConnection[Any]":
         return self._conn
 
-    async def __aenter__(self) -> "AsyncTransaction":
+    async def __aenter__(self: _Self) -> _Self:
         async with self._conn.lock:
             await self._conn.wait(self._enter_gen())
         return self