]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: use typing.Self
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 30 Dec 2023 00:30:39 +0000 (01:30 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 3 Jan 2024 02:50:30 +0000 (03:50 +0100)
The object seems available for all the supported Python version and
should avoid problems with PyRight (see #708).

It is not a solution for #308 because we cannot use `Self[Row]`.

13 files changed:
psycopg/psycopg/_compat.py
psycopg/psycopg/_copy.py
psycopg/psycopg/_copy_async.py
psycopg/psycopg/_copy_base.py
psycopg/psycopg/_pipeline.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/crdb/connection.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/pq/_debug.py
psycopg/psycopg/server_cursor.py
psycopg/psycopg/transaction.py

index 6ac16505de69b31d1a1ad1ea6ab142389650f333..1e1130486599e6f60968462beec925aa998600c2 100644 (file)
@@ -23,9 +23,9 @@ else:
     from typing_extensions import TypeGuard
 
 if sys.version_info >= (3, 11):
-    from typing import LiteralString
+    from typing import LiteralString, Self
 else:
-    from typing_extensions import LiteralString
+    from typing_extensions import LiteralString, Self
 
 if sys.version_info >= (3, 13):
     from typing import TypeVar
@@ -36,6 +36,7 @@ __all__ = [
     "Counter",
     "Deque",
     "LiteralString",
+    "Self",
     "TypeGuard",
     "TypeVar",
     "ZoneInfo",
index 28d85eb35d1f7f54a395aac4f48dc591ad027afa..d7db77e21c2f339d79cf0c9f60521f99bca6170c 100644 (file)
@@ -15,6 +15,7 @@ from typing import Any, Iterator, Type, Tuple, Sequence, TYPE_CHECKING
 
 from . import pq
 from . import errors as e
+from ._compat import Self
 from ._copy_base import BaseCopy, MAX_BUFFER_SIZE, QUEUE_SIZE
 from .generators import copy_to, copy_end
 from ._encodings import pgconn_encoding
@@ -62,7 +63,7 @@ class Copy(BaseCopy["Connection[Any]"]):
         self.writer = writer
         self._write = writer.write
 
-    def __enter__(self: BaseCopy._Self) -> BaseCopy._Self:
+    def __enter__(self) -> Self:
         self._enter()
         return self
 
index 5f66dbf2986141a9befd16aefa53afa8a7a771be..7008fbcad8e6b7a6b652784da5b72a7fa6525aff 100644 (file)
@@ -12,6 +12,7 @@ from typing import Any, AsyncIterator, Type, Tuple, Sequence, TYPE_CHECKING
 
 from . import pq
 from . import errors as e
+from ._compat import Self
 from ._copy_base import BaseCopy, MAX_BUFFER_SIZE, QUEUE_SIZE
 from .generators import copy_to, copy_end
 from ._encodings import pgconn_encoding
@@ -59,7 +60,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         self.writer = writer
         self._write = writer.write
 
-    async def __aenter__(self: BaseCopy._Self) -> BaseCopy._Self:
+    async def __aenter__(self) -> Self:
         self._enter()
         return self
 
index 8f2a8f72b2b7f02670b4ff1e321eabd46e074db3..140744ff1c1e5e375d26c144dcec3f96576f914c 100644 (file)
@@ -10,7 +10,7 @@ import re
 import struct
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generic, List, Match
-from typing import Optional, Sequence, Tuple, TypeVar, Union, TYPE_CHECKING
+from typing import Optional, Sequence, Tuple, Union, TYPE_CHECKING
 
 from . import pq
 from . import adapt
@@ -71,8 +71,6 @@ class BaseCopy(Generic[ConnectionType]):
     a file for later use.
     """
 
-    _Self = TypeVar("_Self", bound="BaseCopy[Any]")
-
     formatter: Formatter
 
     def __init__(
index ff7228eeea3f814ee5f9a3f02f918aef5dd43941..72ac97ddd248d6ef0ee7f897223d6251628b630f 100644 (file)
@@ -6,13 +6,13 @@ commands pipeline management
 
 import logging
 from types import TracebackType
-from typing import Any, List, Optional, Union, Tuple, Type, TypeVar, TYPE_CHECKING
+from typing import Any, List, Optional, Union, Tuple, Type, TYPE_CHECKING
 from typing_extensions import TypeAlias
 
 from . import pq
 from . import errors as e
 from .abc import PipelineCommand, PQGen
-from ._compat import Deque
+from ._compat import Deque, Self
 from .pq.misc import connection_summary
 from ._encodings import pgconn_encoding
 from ._preparing import Key, Prepare
@@ -220,7 +220,6 @@ class Pipeline(BasePipeline):
 
     __module__ = "psycopg"
     _conn: "Connection[Any]"
-    _Self = TypeVar("_Self", bound="Pipeline")
 
     def __init__(self, conn: "Connection[Any]") -> None:
         super().__init__(conn)
@@ -235,7 +234,7 @@ class Pipeline(BasePipeline):
         except e._NO_TRACEBACK as ex:
             raise ex.with_traceback(None)
 
-    def __enter__(self: _Self) -> _Self:
+    def __enter__(self) -> Self:
         with self._conn.lock:
             self._conn.wait(self._enter_gen())
         return self
@@ -264,7 +263,6 @@ class AsyncPipeline(BasePipeline):
 
     __module__ = "psycopg"
     _conn: "AsyncConnection[Any]"
-    _Self = TypeVar("_Self", bound="AsyncPipeline")
 
     def __init__(self, conn: "AsyncConnection[Any]") -> None:
         super().__init__(conn)
@@ -276,7 +274,7 @@ class AsyncPipeline(BasePipeline):
         except e._NO_TRACEBACK as ex:
             raise ex.with_traceback(None)
 
-    async def __aenter__(self: _Self) -> _Self:
+    async def __aenter__(self) -> Self:
         async with self._conn.lock:
             await self._conn.wait(self._enter_gen())
         return self
index a5adf2a92b540816f7c3417ca5ceb2479a2146fe..4bafd5ea1ec29c6e951497ede6c3db9b1dd9ffef 100644 (file)
@@ -12,7 +12,7 @@ from __future__ import annotations
 import logging
 from types import TracebackType
 from typing import Any, Generator, Iterator, Dict, List, Optional
-from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
+from typing import Type, Union, cast, overload, TYPE_CHECKING
 from contextlib import contextmanager
 
 from . import pq
@@ -23,6 +23,7 @@ from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
+from ._compat import Self
 from .conninfo import make_conninfo, conninfo_to_dict
 from ._pipeline import Pipeline
 from ._encodings import pgconn_encoding
@@ -59,7 +60,6 @@ class Connection(BaseConnection[Row]):
     server_cursor_factory: Type[ServerCursor[Row]]
     row_factory: RowFactory[Row]
     _pipeline: Optional[Pipeline]
-    _Self = TypeVar("_Self", bound="Connection[Any]")
 
     def __init__(
         self,
@@ -85,7 +85,9 @@ class Connection(BaseConnection[Row]):
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> Connection[Row]:
-        # TODO: returned type should be _Self. See #308.
+        # TODO: returned type should be Self. See #308.
+        # Unfortunately we cannot use Self[Row] as Self is not parametric.
+        # https://peps.python.org/pep-0673/#use-in-generic-classes
         ...
 
     @overload
@@ -113,7 +115,7 @@ class Connection(BaseConnection[Row]):
         row_factory: Optional[RowFactory[Row]] = None,
         cursor_factory: Optional[Type[Cursor[Row]]] = None,
         **kwargs: Any,
-    ) -> Connection[Any]:
+    ) -> Self:
         """
         Connect to a database server and return a new `Connection` instance.
         """
@@ -138,7 +140,7 @@ class Connection(BaseConnection[Row]):
         rv.prepare_threshold = prepare_threshold
         return rv
 
-    def __enter__(self: _Self) -> _Self:
+    def __enter__(self) -> Self:
         return self
 
     def __exit__(
index 862a1ebe04008ea095b3b0a45a78607d8a9e2383..9d8dad8734e0edfde7251b2ba73fdf4a82514a7f 100644 (file)
@@ -9,7 +9,7 @@ from __future__ import annotations
 import logging
 from types import TracebackType
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional
-from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
+from typing import Type, Union, cast, overload, TYPE_CHECKING
 from contextlib import asynccontextmanager
 
 from . import pq
@@ -20,6 +20,7 @@ from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
+from ._compat import Self
 from .conninfo import make_conninfo, conninfo_to_dict
 from ._pipeline import AsyncPipeline
 from ._encodings import pgconn_encoding
@@ -65,7 +66,6 @@ class AsyncConnection(BaseConnection[Row]):
     server_cursor_factory: Type[AsyncServerCursor[Row]]
     row_factory: AsyncRowFactory[Row]
     _pipeline: Optional[AsyncPipeline]
-    _Self = TypeVar("_Self", bound="AsyncConnection[Any]")
 
     def __init__(
         self,
@@ -91,7 +91,9 @@ class AsyncConnection(BaseConnection[Row]):
         context: Optional[AdaptContext] = None,
         **kwargs: Union[None, int, str],
     ) -> AsyncConnection[Row]:
-        # TODO: returned type should be _Self. See #308.
+        # TODO: returned type should be Self. See #308.
+        # Unfortunately we cannot use Self[Row] as Self is not parametric.
+        # https://peps.python.org/pep-0673/#use-in-generic-classes
         ...
 
     @overload
@@ -119,7 +121,7 @@ class AsyncConnection(BaseConnection[Row]):
         row_factory: Optional[AsyncRowFactory[Row]] = None,
         cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
         **kwargs: Any,
-    ) -> AsyncConnection[Any]:
+    ) -> Self:
         """
         Connect to a database server and return a new `AsyncConnection` instance.
         """
@@ -154,7 +156,7 @@ class AsyncConnection(BaseConnection[Row]):
         rv.prepare_threshold = prepare_threshold
         return rv
 
-    async def __aenter__(self: _Self) -> _Self:
+    async def __aenter__(self) -> Self:
         return self
 
     async def __aexit__(
index 451474b77ecf9a07bf93174187a2db3789a55c0c..49b7d5ffa3a465feecac4f31f12efebcac9e009c 100644 (file)
@@ -10,6 +10,7 @@ from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
 from .. import errors as e
 from ..abc import AdaptContext
 from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow
+from .._compat import Self
 from ..conninfo import ConnectionInfo
 from ..connection import Connection
 from .._adapters_map import AdaptersMap
@@ -95,7 +96,7 @@ class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
         ...
 
     @classmethod
-    def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]":
+    def connect(cls, conninfo: str = "", **kwargs: Any) -> Self:
         """
         Connect to a database server and return a new `CrdbConnection` instance.
         """
@@ -142,10 +143,8 @@ class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
         ...
 
     @classmethod
-    async def connect(
-        cls, conninfo: str = "", **kwargs: Any
-    ) -> "AsyncCrdbConnection[Any]":
-        return await super().connect(conninfo, **kwargs)  # type: ignore [no-any-return]
+    async def connect(cls, conninfo: str = "", **kwargs: Any) -> Self:
+        return await super().connect(conninfo, **kwargs)  # type: ignore[no-any-return]
 
 
 class CrdbConnectionInfo(ConnectionInfo):
index c2544e3c8a5cf6cade5dbe1b8f23b63843c68de4..98500a64ba1167567694d4e3b8524b134ef91c17 100644 (file)
@@ -10,7 +10,7 @@ Psycopg Cursor object.
 from __future__ import annotations
 
 from types import TracebackType
-from typing import Any, Iterator, Iterable, List, Optional, Type, TypeVar
+from typing import Any, Iterator, Iterable, List, Optional, Type
 from typing import TYPE_CHECKING, overload
 from contextlib import contextmanager
 
@@ -19,6 +19,7 @@ from . import errors as e
 from .abc import Query, Params
 from .copy import Copy, Writer
 from .rows import Row, RowMaker, RowFactory
+from ._compat import Self
 from ._pipeline import Pipeline
 from ._cursor_base import BaseCursor
 
@@ -31,7 +32,6 @@ ACTIVE = pq.TransactionStatus.ACTIVE
 class Cursor(BaseCursor["Connection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
-    _Self = TypeVar("_Self", bound="Cursor[Any]")
 
     @overload
     def __init__(self: Cursor[Row], connection: Connection[Row]):
@@ -52,7 +52,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         super().__init__(connection)
         self._row_factory = row_factory or connection.row_factory
 
-    def __enter__(self: _Self) -> _Self:
+    def __enter__(self) -> Self:
         return self
 
     def __exit__(
@@ -84,13 +84,13 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         return self._row_factory(self)
 
     def execute(
-        self: _Self,
+        self,
         query: Query,
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
         binary: Optional[bool] = None,
-    ) -> _Self:
+    ) -> Self:
         """
         Execute a query or command to the database.
         """
index f02cfc54028e4d34bf33089c17b0fbd4c0f03e5c..6c6d3f814855a82bab22307238c1e47fa3cd9f57 100644 (file)
@@ -7,7 +7,7 @@ Psycopg AsyncCursor object.
 from __future__ import annotations
 
 from types import TracebackType
-from typing import Any, AsyncIterator, Iterable, List, Optional, Type, TypeVar
+from typing import Any, AsyncIterator, Iterable, List, Optional, Type
 from typing import TYPE_CHECKING, overload
 from contextlib import asynccontextmanager
 
@@ -16,6 +16,7 @@ from . import errors as e
 from .abc import Query, Params
 from .copy import AsyncCopy, AsyncWriter
 from .rows import Row, RowMaker, AsyncRowFactory
+from ._compat import Self
 from ._pipeline import Pipeline
 from ._cursor_base import BaseCursor
 
@@ -28,7 +29,6 @@ ACTIVE = pq.TransactionStatus.ACTIVE
 class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
-    _Self = TypeVar("_Self", bound="AsyncCursor[Any]")
 
     @overload
     def __init__(self: AsyncCursor[Row], connection: AsyncConnection[Row]):
@@ -52,7 +52,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         super().__init__(connection)
         self._row_factory = row_factory or connection.row_factory
 
-    async def __aenter__(self: _Self) -> _Self:
+    async def __aenter__(self) -> Self:
         return self
 
     async def __aexit__(
@@ -84,13 +84,13 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         return self._row_factory(self)
 
     async def execute(
-        self: _Self,
+        self,
         query: Query,
         params: Optional[Params] = None,
         *,
         prepare: Optional[bool] = None,
         binary: Optional[bool] = None,
-    ) -> _Self:
+    ) -> Self:
         """
         Execute a query or command to the database.
         """
index 50fc819e4995e32567f02beaa898a459240cf6c3..bc888648acf160503566a0a3195af0983390da71 100644 (file)
@@ -30,9 +30,9 @@ Suggested usage::
 
 import inspect
 import logging
-from typing import Any, Callable, Type, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
 from functools import wraps
-from .._compat import TypeVar
+from .._compat import Self, TypeVar
 
 from . import PGconn
 from .misc import connection_summary
@@ -48,7 +48,6 @@ logger = logging.getLogger("psycopg.debug")
 class PGconnDebug:
     """Wrapper for a PQconn logging all its access."""
 
-    _Self = TypeVar("_Self", bound="PGconnDebug")
     _pgconn: "abc.PGconn"
 
     def __init__(self, pgconn: "abc.PGconn"):
@@ -72,11 +71,11 @@ class PGconnDebug:
         logger.info("PGconn.%s <- %s", attr, value)
 
     @classmethod
-    def connect(cls: Type[_Self], conninfo: bytes) -> _Self:
+    def connect(cls, conninfo: bytes) -> Self:
         return cls(debugging(PGconn.connect)(conninfo))
 
     @classmethod
-    def connect_start(cls: Type[_Self], conninfo: bytes) -> _Self:
+    def connect_start(cls, conninfo: bytes) -> Self:
         return cls(debugging(PGconn.connect_start)(conninfo))
 
     @classmethod
index eada346f9eefa08d36bdadc25bdf2eac52f2dc34..7039d2950105cb7d706ee22c2be55e2c79d07622 100644 (file)
@@ -5,7 +5,7 @@ psycopg server-side cursor objects.
 # Copyright (C) 2020 The Psycopg Team
 
 from typing import Any, AsyncIterator, List, Iterable, Iterator
-from typing import Optional, TypeVar, TYPE_CHECKING, overload
+from typing import Optional, TYPE_CHECKING, overload
 from warnings import warn
 
 from . import pq
@@ -14,6 +14,7 @@ from . import errors as e
 from .abc import ConnectionType, Query, Params, PQGen
 from .rows import Row, RowFactory, AsyncRowFactory
 from .cursor import Cursor
+from ._compat import Self
 from .generators import execute
 from ._cursor_base import BaseCursor
 from .cursor_async import AsyncCursor
@@ -212,7 +213,6 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
 class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
     __module__ = "psycopg"
     __slots__ = ()
-    _Self = TypeVar("_Self", bound="ServerCursor[Any]")
 
     @overload
     def __init__(
@@ -271,13 +271,13 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
             super().close()
 
     def execute(
-        self: _Self,
+        self,
         query: Query,
         params: Optional[Params] = None,
         *,
         binary: Optional[bool] = None,
         **kwargs: Any,
-    ) -> _Self:
+    ) -> Self:
         """
         Open a cursor to execute a query to the database.
         """
@@ -354,7 +354,6 @@ class AsyncServerCursor(
 ):
     __module__ = "psycopg"
     __slots__ = ()
-    _Self = TypeVar("_Self", bound="AsyncServerCursor[Any]")
 
     @overload
     def __init__(
@@ -410,13 +409,13 @@ class AsyncServerCursor(
             await super().close()
 
     async def execute(
-        self: _Self,
+        self,
         query: Query,
         params: Optional[Params] = None,
         *,
         binary: Optional[bool] = None,
         **kwargs: Any,
-    ) -> _Self:
+    ) -> Self:
         if kwargs:
             raise TypeError(f"keyword not supported: {list(kwargs)[0]}")
         if self._pgconn.pipeline_status:
index fae3c2ab557b5d746be71432abd69a53b9b06a20..c6405aa438ea25e6aae9965240ae9713ea59be57 100644 (file)
@@ -7,12 +7,13 @@ Transaction context managers returned by Connection.transaction()
 import logging
 
 from types import TracebackType
-from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING
+from typing import Generic, Iterator, Optional, Type, Union, TYPE_CHECKING
 
 from . import pq
 from . import sql
 from . import errors as e
 from .abc import ConnectionType, PQGen
+from ._compat import Self
 from .pq.misc import connection_summary
 
 if TYPE_CHECKING:
@@ -235,14 +236,12 @@ class Transaction(BaseTransaction["Connection[Any]"]):
 
     __module__ = "psycopg"
 
-    _Self = TypeVar("_Self", bound="Transaction")
-
     @property
     def connection(self) -> "Connection[Any]":
         """The connection the object is managing."""
         return self._conn
 
-    def __enter__(self: _Self) -> _Self:
+    def __enter__(self) -> Self:
         with self._conn.lock:
             self._conn.wait(self._enter_gen())
         return self
@@ -267,13 +266,11 @@ class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
 
     __module__ = "psycopg"
 
-    _Self = TypeVar("_Self", bound="AsyncTransaction")
-
     @property
     def connection(self) -> "AsyncConnection[Any]":
         return self._conn
 
-    async def __aenter__(self: _Self) -> _Self:
+    async def __aenter__(self) -> Self:
         async with self._conn.lock:
             await self._conn.wait(self._enter_gen())
         return self