git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: remove more types as strings
author: Denis Laxalde <denis.laxalde@dalibo.com>
Tue, 4 Jun 2024 07:39:14 +0000 (09:39 +0200)
committer: Daniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 4 Jun 2024 19:03:55 +0000 (21:03 +0200)
I.e. all those suggested by pyupgrade --py38-plus.

20 files changed:
psycopg/psycopg/_connection_base.py
psycopg/psycopg/_dns.py
psycopg/psycopg/_py_transformer.py
psycopg/psycopg/_tpc.py
psycopg/psycopg/adapt.py
psycopg/psycopg/crdb/connection.py
psycopg/psycopg/errors.py
psycopg/psycopg/sql.py
psycopg/psycopg/transaction.py
psycopg/psycopg/types/array.py
psycopg/psycopg/types/datetime.py
psycopg/psycopg/types/multirange.py
psycopg/psycopg/types/net.py
psycopg/psycopg/types/numeric.py
psycopg/psycopg/types/range.py
psycopg/psycopg/types/shapely.py
psycopg_pool/psycopg_pool/base.py
tests/adapters_example.py
tests/fix_pq.py
tests/utils.py

index 8b0277dadf6f4e70de41deac6c59b0935b1712a3..8a9a6e50def49266f296ba4048334784ded3cb17 100644 (file)
@@ -278,7 +278,7 @@ class BaseConnection(Generic[Row]):
         return self._adapters
 
     @property
-    def connection(self) -> "BaseConnection[Row]":
+    def connection(self) -> BaseConnection[Row]:
         # implement the AdaptContext protocol
         return self
 
@@ -337,7 +337,7 @@ class BaseConnection(Generic[Row]):
 
     @staticmethod
     def _notice_handler(
-        wself: "ReferenceType[BaseConnection[Row]]", res: PGresult
+        wself: ReferenceType[BaseConnection[Row]], res: PGresult
     ) -> None:
         self = wself()
         if not (self and self._notice_handlers):
@@ -370,7 +370,7 @@ class BaseConnection(Generic[Row]):
 
     @staticmethod
     def _notify_handler(
-        wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify
+        wself: ReferenceType[BaseConnection[Row]], pgn: pq.PGnotify
     ) -> None:
         self = wself()
         if not (self and self._notify_handlers):
index 1d74aa4704f036bc8b1ccb7366633e0934bdd218..9dd67f0b54e0dcefb8cae02870cbcabf5aef17ec 100644 (file)
@@ -185,7 +185,7 @@ class Rfc2782Resolver:
         return self._get_solved_entries(hp, ans)
 
     def _get_solved_entries(
-        self, hp: HostPort, entries: "Sequence[SRV]"
+        self, hp: HostPort, entries: Sequence[SRV]
     ) -> list[HostPort]:
         if not entries:
             # No SRV entry found. Delegate the libpq a QNAME=target lookup
@@ -216,13 +216,13 @@ class Rfc2782Resolver:
         out["port"] = ",".join(str(hp.port) for hp in hps)
         return out
 
-    def sort_rfc2782(self, ans: "Sequence[SRV]") -> "list[SRV]":
+    def sort_rfc2782(self, ans: Sequence[SRV]) -> list[SRV]:
         """
         Implement the priority/weight ordering defined in RFC 2782.
         """
         # Divide the entries by priority:
-        priorities: DefaultDict[int, "list[SRV]"] = defaultdict(list)
-        out: "list[SRV]" = []
+        priorities: DefaultDict[int, list[SRV]] = defaultdict(list)
+        out: list[SRV] = []
         for entry in ans:
             priorities[entry.priority].append(entry)
 
index 0620113eecc5e4366fd74682bd65980d44cb660c..53008949d7432d03109868dc631dacbce9f90700 100644 (file)
@@ -100,7 +100,7 @@ class Transformer(AdaptContext):
         self._encoding = ""
 
     @classmethod
-    def from_context(cls, context: AdaptContext | None) -> "Transformer":
+    def from_context(cls, context: AdaptContext | None) -> Transformer:
         """
         Return a Transformer from an AdaptContext.
 
index c809d5e84086945b2958bcca41a6867503b86a76..880feff3c138c1df063c596fdc418d8c96a82303 100644 (file)
@@ -31,7 +31,7 @@ class Xid:
     database: str | None = None
 
     @classmethod
-    def from_string(cls, s: str) -> "Xid":
+    def from_string(cls, s: str) -> Xid:
         """Try to parse an XA triple from the string.
 
         This may fail for several reasons. In such case return an unparsed Xid.
@@ -51,7 +51,7 @@ class Xid:
         return (self.format_id, self.gtrid, self.bqual)[index]
 
     @classmethod
-    def _parse_string(cls, s: str) -> "Xid":
+    def _parse_string(cls, s: str) -> Xid:
         m = _re_xid.match(s)
         if not m:
             raise ValueError("bad Xid format")
@@ -62,7 +62,7 @@ class Xid:
         return cls.from_parts(format_id, gtrid, bqual)
 
     @classmethod
-    def from_parts(cls, format_id: int | None, gtrid: str, bqual: str | None) -> "Xid":
+    def from_parts(cls, format_id: int | None, gtrid: str, bqual: str | None) -> Xid:
         if format_id is not None:
             if bqual is None:
                 raise TypeError("if format_id is specified, bqual must be too")
@@ -107,7 +107,7 @@ class Xid:
     @classmethod
     def _from_record(
         cls, gid: str, prepared: dt.datetime, owner: str, database: str
-    ) -> "Xid":
+    ) -> Xid:
         xid = Xid.from_string(gid)
         return replace(xid, prepared=prepared, owner=owner, database=database)
 
index 889801de9b1fdcd7be8ef0d6890330ed592fa16e..428042fe1d346b105b22d10f8be71033c92faa2d 100644 (file)
@@ -107,7 +107,7 @@ class Dumper(abc.Dumper, ABC):
         """
         return self.cls
 
-    def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
+    def upgrade(self, obj: Any, format: PyFormat) -> Dumper:
         """
         Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the
         `~psycopg.abc.Dumper` protocol. Look at its definition for details.
index 74b34a7db7073a227b2b7f922ed802e20ed53f34..9ca1c640dc818fe31e1def0cdb1fb44290e5a0ae 100644 (file)
@@ -44,7 +44,7 @@ class _CrdbConnectionMixin:
         return self._adapters
 
     @property
-    def info(self) -> "CrdbConnectionInfo":
+    def info(self) -> CrdbConnectionInfo:
         return CrdbConnectionInfo(self.pgconn)
 
     def _check_tpc(self) -> None:
index 0d162feaa9af08ba0f6382bc316e4bad4d1abc6d..00cb24c9c9455c28ff3b35c5ed3a425bfa838061 100644 (file)
@@ -292,7 +292,7 @@ class Error(Exception):
         return self._info if _is_pgresult(self._info) else None
 
     @property
-    def diag(self) -> "Diagnostic":
+    def diag(self) -> Diagnostic:
         """
         A `Diagnostic` object to inspect details of the errors from the database.
         """
index 96d518d3768bf972572918115bd29de75165c5b6..282c138fcabe1436c5559df205e9e1bd55432313 100644 (file)
@@ -88,7 +88,7 @@ class Composable(ABC):
             # buffer object
             return codecs.lookup(enc).decode(b)[0]
 
-    def __add__(self, other: "Composable") -> "Composed":
+    def __add__(self, other: Composable) -> Composed:
         if isinstance(other, Composed):
             return Composed([self]) + other
         if isinstance(other, Composable):
@@ -96,7 +96,7 @@ class Composable(ABC):
         else:
             return NotImplemented
 
-    def __mul__(self, n: int) -> "Composed":
+    def __mul__(self, n: int) -> Composed:
         return Composed([self] * n)
 
     def __eq__(self, other: Any) -> bool:
@@ -138,7 +138,7 @@ class Composed(Composable):
     def __iter__(self) -> Iterator[Composable]:
         return iter(self._obj)
 
-    def __add__(self, other: Composable) -> "Composed":
+    def __add__(self, other: Composable) -> Composed:
         if isinstance(other, Composed):
             return Composed(self._obj + other._obj)
         if isinstance(other, Composable):
@@ -146,7 +146,7 @@ class Composed(Composable):
         else:
             return NotImplemented
 
-    def join(self, joiner: "SQL" | LiteralString) -> "Composed":
+    def join(self, joiner: SQL | LiteralString) -> Composed:
         """
         Return a new `!Composed` interposing the `!joiner` with the `!Composed` items.
 
index 9a60a5fbc6bce3b9700a9a8c71b8439c0e90161b..d0a3aa349ad9aa0dfa48427363e1531463128c92 100644 (file)
@@ -40,7 +40,7 @@ class Rollback(Exception):
 
     __module__ = "psycopg"
 
-    def __init__(self, transaction: "Transaction" | "AsyncTransaction" | None = None):
+    def __init__(self, transaction: Transaction | AsyncTransaction | None = None):
         self.transaction = transaction
 
     def __repr__(self) -> str:
index 78f3eaeff6d0a98a1bb68851f2c7cd2c3af4b8bc..1da8f43169d9b621dc853bd43adccb64e6b7bdcc 100644 (file)
@@ -128,7 +128,7 @@ class ListDumper(BaseListDumper):
         sd = self._tx.get_dumper(item, format)
         return (self.cls, sd.get_key(item, format))
 
-    def upgrade(self, obj: list[Any], format: PyFormat) -> "BaseListDumper":
+    def upgrade(self, obj: list[Any], format: PyFormat) -> BaseListDumper:
         # If we have an oid we don't need to upgrade
         if self.oid:
             return self
@@ -232,7 +232,7 @@ class ListBinaryDumper(BaseListDumper):
         sd = self._tx.get_dumper(item, format)
         return (self.cls, sd.get_key(item, format))
 
-    def upgrade(self, obj: list[Any], format: PyFormat) -> "BaseListDumper":
+    def upgrade(self, obj: list[Any], format: PyFormat) -> BaseListDumper:
         # If we have an oid we don't need to upgrade
         if self.oid:
             return self
index 866a2fa6cbde72befae0cb62098097b64805190f..3d8db27f8e8aa644931eea2eef0542631c9cf1ca 100644 (file)
@@ -204,13 +204,13 @@ class TimedeltaDumper(Dumper):
         return self._dump_method(self, obj)
 
     @staticmethod
-    def _dump_any(self: "TimedeltaDumper", obj: timedelta) -> bytes:
+    def _dump_any(self: TimedeltaDumper, obj: timedelta) -> bytes:
         # The comma is parsed ok by PostgreSQL but it's not documented
         # and it seems brittle to rely on it. CRDB doesn't consume it well.
         return str(obj).encode().replace(b",", b"")
 
     @staticmethod
-    def _dump_sql(self: "TimedeltaDumper", obj: timedelta) -> bytes:
+    def _dump_sql(self: TimedeltaDumper, obj: timedelta) -> bytes:
         # sql_standard format needs explicit signs
         # otherwise -1 day 1 sec will mean -1 sec
         return b"%+d day %+d second %+d microsecond" % (
@@ -509,7 +509,7 @@ class TimestamptzLoader(Loader):
         return self._load_method(self, data)
 
     @staticmethod
-    def _load_iso(self: "TimestamptzLoader", data: Buffer) -> datetime:
+    def _load_iso(self: TimestamptzLoader, data: Buffer) -> datetime:
         m = self._re_format.match(data)
         if not m:
             raise _get_timestamp_load_error(self.connection, data) from None
@@ -556,7 +556,7 @@ class TimestamptzLoader(Loader):
         raise _get_timestamp_load_error(self.connection, data, ex) from None
 
     @staticmethod
-    def _load_notimpl(self: "TimestamptzLoader", data: Buffer) -> datetime:
+    def _load_notimpl(self: TimestamptzLoader, data: Buffer) -> datetime:
         s = bytes(data).decode("utf8", "replace")
         ds = _get_datestyle(self.connection).decode("ascii")
         raise NotImplementedError(
@@ -623,7 +623,7 @@ class IntervalLoader(Loader):
         return self._load_method(self, data)
 
     @staticmethod
-    def _load_postgres(self: "IntervalLoader", data: Buffer) -> timedelta:
+    def _load_postgres(self: IntervalLoader, data: Buffer) -> timedelta:
         m = self._re_interval.match(data)
         if not m:
             s = bytes(data).decode("utf8", "replace")
@@ -652,7 +652,7 @@ class IntervalLoader(Loader):
             raise DataError(f"can't parse interval {s!r}: {e}") from None
 
     @staticmethod
-    def _load_notimpl(self: "IntervalLoader", data: Buffer) -> timedelta:
+    def _load_notimpl(self: IntervalLoader, data: Buffer) -> timedelta:
         s = bytes(data).decode("utf8", "replace")
         ints = _get_intervalstyle(self.connection).decode("utf8", "replace")
         raise NotImplementedError(
index d8070279f1f1afd6c9ddeb0b3bca55efd3c3722f..735333049b3fb1e9fb7a774adf560a5768a24333 100644 (file)
@@ -63,7 +63,7 @@ WHERE t.oid = {regtype}
 """
         ).format(regtype=cls._to_regtype(conn))
 
-    def _added(self, registry: "TypesRegistry") -> None:
+    def _added(self, registry: TypesRegistry) -> None:
         # Map multiranges ranges and subtypes to info
         registry._registry[MultirangeInfo, self.range_oid] = self
         registry._registry[MultirangeInfo, self.subtype_oid] = self
@@ -95,9 +95,9 @@ class Multirange(MutableSequence[Range[T]]):
     def __getitem__(self, index: int) -> Range[T]: ...
 
     @overload
-    def __getitem__(self, index: slice) -> "Multirange[T]": ...
+    def __getitem__(self, index: slice) -> Multirange[T]: ...
 
-    def __getitem__(self, index: int | slice) -> "Range[T] | Multirange[T]":
+    def __getitem__(self, index: int | slice) -> Range[T] | Multirange[T]:
         if isinstance(index, int):
             return self._ranges[index]
         else:
@@ -199,7 +199,7 @@ class BaseMultirangeDumper(RecursiveDumper):
         else:
             return (self.cls,)
 
-    def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
+    def upgrade(self, obj: Multirange[Any], format: PyFormat) -> BaseMultirangeDumper:
         # If we are a subclass whose oid is specified we don't need upgrade
         if self.cls is not Multirange:
             return self
index 67f6adb4b27002e91b47209c0c0fd87608461d6b..a04aed08f046f8fd249c6b85a9c2c9836402cdea 100644 (file)
@@ -25,12 +25,12 @@ Network: TypeAlias = "ipaddress.IPv4Network | ipaddress.IPv6Network"
 ip_address: Callable[[str], Address] = None  # type: ignore[assignment]
 ip_interface: Callable[[str], Interface] = None  # type: ignore[assignment]
 ip_network: Callable[[str], Network] = None  # type: ignore[assignment]
-IPv4Address: "type[ipaddress.IPv4Address]" = None  # type: ignore[assignment]
-IPv6Address: "type[ipaddress.IPv6Address]" = None  # type: ignore[assignment]
-IPv4Interface: "type[ipaddress.IPv4Interface]" = None  # type: ignore[assignment]
-IPv6Interface: "type[ipaddress.IPv6Interface]" = None  # type: ignore[assignment]
-IPv4Network: "type[ipaddress.IPv4Network]" = None  # type: ignore[assignment]
-IPv6Network: "type[ipaddress.IPv6Network]" = None  # type: ignore[assignment]
+IPv4Address: type[ipaddress.IPv4Address] = None  # type: ignore[assignment]
+IPv6Address: type[ipaddress.IPv6Address] = None  # type: ignore[assignment]
+IPv4Interface: type[ipaddress.IPv4Interface] = None  # type: ignore[assignment]
+IPv6Interface: type[ipaddress.IPv6Interface] = None  # type: ignore[assignment]
+IPv4Network: type[ipaddress.IPv4Network] = None  # type: ignore[assignment]
+IPv6Network: type[ipaddress.IPv6Network] = None  # type: ignore[assignment]
 
 PGSQL_AF_INET = 2
 PGSQL_AF_INET6 = 3
index f3a1085436e1b46857e893f27ef56be80241a568..d13f16e8b06325debe25344ddb6df0a654b36c59 100644 (file)
@@ -386,11 +386,11 @@ class _MixedNumericDumper(Dumper, ABC):
                 _MixedNumericDumper.int_classes = int
 
     @abstractmethod
-    def dump(self, obj: Decimal | int | "numpy.integer[Any]") -> Buffer | None: ...
+    def dump(self, obj: Decimal | int | numpy.integer[Any]) -> Buffer | None: ...
 
 
 class NumericDumper(_MixedNumericDumper):
-    def dump(self, obj: Decimal | int | "numpy.integer[Any]") -> Buffer | None:
+    def dump(self, obj: Decimal | int | numpy.integer[Any]) -> Buffer | None:
         if isinstance(obj, self.int_classes):
             return str(obj).encode()
         elif isinstance(obj, Decimal):
@@ -404,7 +404,7 @@ class NumericDumper(_MixedNumericDumper):
 class NumericBinaryDumper(_MixedNumericDumper):
     format = Format.BINARY
 
-    def dump(self, obj: Decimal | int | "numpy.integer[Any]") -> Buffer | None:
+    def dump(self, obj: Decimal | int | numpy.integer[Any]) -> Buffer | None:
         if type(obj) is int:
             return dump_int_to_numeric_binary(obj)
         elif isinstance(obj, Decimal):
index bdce29401a2ea8f160c6d244c1d0eef775a34380..665c9a6de396596a7755d0f47548e1f0d84d7662 100644 (file)
@@ -303,7 +303,7 @@ class BaseRangeDumper(RecursiveDumper):
         else:
             return (self.cls,)
 
-    def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper":
+    def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper:
         # If we are a subclass whose oid is specified we don't need upgrade
         if self.cls is not Range:
             return self
index 7c4da657ad82ba585506245e2ff3cec42dd3f3f6..6a7e15761bfc48a5b09f9c91e664a7b53fef44c8 100644 (file)
@@ -25,14 +25,14 @@ except ImportError:
 class GeometryBinaryLoader(Loader):
     format = Format.BINARY
 
-    def load(self, data: Buffer) -> "BaseGeometry":
+    def load(self, data: Buffer) -> BaseGeometry:
         if not isinstance(data, bytes):
             data = bytes(data)
         return loads(data)
 
 
 class GeometryLoader(Loader):
-    def load(self, data: Buffer) -> "BaseGeometry":
+    def load(self, data: Buffer) -> BaseGeometry:
         # it's a hex string in binary
         if isinstance(data, memoryview):
             data = bytes(data)
@@ -42,12 +42,12 @@ class GeometryLoader(Loader):
 class BaseGeometryBinaryDumper(Dumper):
     format = Format.BINARY
 
-    def dump(self, obj: "BaseGeometry") -> Buffer | None:
+    def dump(self, obj: BaseGeometry) -> Buffer | None:
         return dumps(obj)  # type: ignore
 
 
 class BaseGeometryDumper(Dumper):
-    def dump(self, obj: "BaseGeometry") -> Buffer | None:
+    def dump(self, obj: BaseGeometry) -> Buffer | None:
         return dumps(obj, hex=True).encode()  # type: ignore
 
 
index f0369681597c0b590bca3a69a98e35ce1f88ffd7..427fd24dc1c335370d8bf2a55bbf38610e7c7d69 100644 (file)
@@ -40,7 +40,7 @@ class BasePool:
     _CONNECTIONS_ERRORS = "connections_errors"
     _CONNECTIONS_LOST = "connections_lost"
 
-    _pool: Deque["Any"]
+    _pool: Deque[Any]
 
     def __init__(
         self,
index 5c0444a501878793e16ddf07f819c3b8ec05c1eb..d69141dd1b67002b6583822d834805326547688c 100644 (file)
@@ -31,7 +31,7 @@ class MyStrDumper:
     def get_key(self, obj: str, format: PyFormat) -> type:
         return self._cls
 
-    def upgrade(self, obj: str, format: PyFormat) -> "MyStrDumper":
+    def upgrade(self, obj: str, format: PyFormat) -> MyStrDumper:
         return self
 
 
index 0d68dc1fbf9ca5718d6d8fd23d92a15f4b8b1003..1cff7e18bd6f146a95b66f6f5708da94fab1e4fb 100644 (file)
@@ -92,7 +92,7 @@ def trace(libpq):
 
 class Tracer:
     def trace(self, conn):
-        pgconn: "pq.abc.PGconn"
+        pgconn: pq.abc.PGconn
 
         if hasattr(conn, "exec_"):
             pgconn = conn
@@ -105,7 +105,7 @@ class Tracer:
 
 
 class TraceLog:
-    def __init__(self, pgconn: "pq.abc.PGconn"):
+    def __init__(self, pgconn: pq.abc.PGconn):
         self.pgconn = pgconn
         self.tempfile = TemporaryFile(buffering=0)
         pgconn.trace(self.tempfile.fileno())
@@ -116,13 +116,12 @@ class TraceLog:
             self.pgconn.untrace()
         self.tempfile.close()
 
-    def __iter__(self) -> "Iterator[TraceEntry]":
+    def __iter__(self) -> Iterator[TraceEntry]:
         self.tempfile.seek(0)
         data = self.tempfile.read()
-        for entry in self._parse_entries(data):
-            yield entry
+        yield from self._parse_entries(data)
 
-    def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]":
+    def _parse_entries(self, data: bytes) -> Iterator[TraceEntry]:
         for line in data.splitlines():
             direction, length, type, *content = line.split(b"\t")
             yield TraceEntry(
index 443fba1c79b64f9240fecc0d4e5718a803cf2f9d..98f42915d9e6a8293e6ffa6a91b755c1128e9b80 100644 (file)
@@ -85,7 +85,7 @@ class VersionCheck:
         self.postgres_rule = postgres_rule
 
     @classmethod
-    def parse(cls, spec: str, *, postgres_rule: bool = False) -> "VersionCheck":
+    def parse(cls, spec: str, *, postgres_rule: bool = False) -> VersionCheck:
         # Parse a spec like "> 9.6", "skip < 21.2.0"
         m = re.match(
             r"""(?ix)