From: Denis Laxalde Date: Tue, 4 Jun 2024 07:39:14 +0000 (+0200) Subject: refactor: remove more types as strings X-Git-Tag: 3.2.0~18^2 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=633d6e4b456fac1af0e3526f9d878d36673882eb;p=thirdparty%2Fpsycopg.git refactor: remove more types as strings I.e. all those suggested by pyupgrade --py38-plus. --- diff --git a/psycopg/psycopg/_connection_base.py b/psycopg/psycopg/_connection_base.py index 8b0277dad..8a9a6e50d 100644 --- a/psycopg/psycopg/_connection_base.py +++ b/psycopg/psycopg/_connection_base.py @@ -278,7 +278,7 @@ class BaseConnection(Generic[Row]): return self._adapters @property - def connection(self) -> "BaseConnection[Row]": + def connection(self) -> BaseConnection[Row]: # implement the AdaptContext protocol return self @@ -337,7 +337,7 @@ class BaseConnection(Generic[Row]): @staticmethod def _notice_handler( - wself: "ReferenceType[BaseConnection[Row]]", res: PGresult + wself: ReferenceType[BaseConnection[Row]], res: PGresult ) -> None: self = wself() if not (self and self._notice_handlers): @@ -370,7 +370,7 @@ class BaseConnection(Generic[Row]): @staticmethod def _notify_handler( - wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify + wself: ReferenceType[BaseConnection[Row]], pgn: pq.PGnotify ) -> None: self = wself() if not (self and self._notify_handlers): diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py index 1d74aa470..9dd67f0b5 100644 --- a/psycopg/psycopg/_dns.py +++ b/psycopg/psycopg/_dns.py @@ -185,7 +185,7 @@ class Rfc2782Resolver: return self._get_solved_entries(hp, ans) def _get_solved_entries( - self, hp: HostPort, entries: "Sequence[SRV]" + self, hp: HostPort, entries: Sequence[SRV] ) -> list[HostPort]: if not entries: # No SRV entry found. Delegate the libpq a QNAME=target lookup @@ -216,13 +216,13 @@ class Rfc2782Resolver: out["port"] = ",".join(str(hp.port) for hp in hps) return out - def sort_rfc2782(self, ans: "Sequence[SRV]") -> "list[SRV]": + def sort_rfc2782(self, ans: Sequence[SRV]) -> list[SRV]: """ Implement the priority/weight ordering defined in RFC 2782. """ # Divide the entries by priority: - priorities: DefaultDict[int, "list[SRV]"] = defaultdict(list) - out: "list[SRV]" = [] + priorities: DefaultDict[int, list[SRV]] = defaultdict(list) + out: list[SRV] = [] for entry in ans: priorities[entry.priority].append(entry) diff --git a/psycopg/psycopg/_py_transformer.py b/psycopg/psycopg/_py_transformer.py index 0620113ee..53008949d 100644 --- a/psycopg/psycopg/_py_transformer.py +++ b/psycopg/psycopg/_py_transformer.py @@ -100,7 +100,7 @@ class Transformer(AdaptContext): self._encoding = "" @classmethod - def from_context(cls, context: AdaptContext | None) -> "Transformer": + def from_context(cls, context: AdaptContext | None) -> Transformer: """ Return a Transformer from an AdaptContext. diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py index c809d5e84..880feff3c 100644 --- a/psycopg/psycopg/_tpc.py +++ b/psycopg/psycopg/_tpc.py @@ -31,7 +31,7 @@ class Xid: database: str | None = None @classmethod - def from_string(cls, s: str) -> "Xid": + def from_string(cls, s: str) -> Xid: """Try to parse an XA triple from the string. This may fail for several reasons. In such case return an unparsed Xid. @@ -51,7 +51,7 @@ class Xid: return (self.format_id, self.gtrid, self.bqual)[index] @classmethod - def _parse_string(cls, s: str) -> "Xid": + def _parse_string(cls, s: str) -> Xid: m = _re_xid.match(s) if not m: raise ValueError("bad Xid format") @@ -62,7 +62,7 @@ class Xid: return cls.from_parts(format_id, gtrid, bqual) @classmethod - def from_parts(cls, format_id: int | None, gtrid: str, bqual: str | None) -> "Xid": + def from_parts(cls, format_id: int | None, gtrid: str, bqual: str | None) -> Xid: if format_id is not None: if bqual is None: raise TypeError("if format_id is specified, bqual must be too") @@ -107,7 +107,7 @@ class Xid: @classmethod def _from_record( cls, gid: str, prepared: dt.datetime, owner: str, database: str - ) -> "Xid": + ) -> Xid: xid = Xid.from_string(gid) return replace(xid, prepared=prepared, owner=owner, database=database) diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py index 889801de9..428042fe1 100644 --- a/psycopg/psycopg/adapt.py +++ b/psycopg/psycopg/adapt.py @@ -107,7 +107,7 @@ class Dumper(abc.Dumper, ABC): """ return self.cls - def upgrade(self, obj: Any, format: PyFormat) -> "Dumper": + def upgrade(self, obj: Any, format: PyFormat) -> Dumper: """ Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the `~psycopg.abc.Dumper` protocol. Look at its definition for details. diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py index 74b34a7db..9ca1c640d 100644 --- a/psycopg/psycopg/crdb/connection.py +++ b/psycopg/psycopg/crdb/connection.py @@ -44,7 +44,7 @@ class _CrdbConnectionMixin: return self._adapters @property - def info(self) -> "CrdbConnectionInfo": + def info(self) -> CrdbConnectionInfo: return CrdbConnectionInfo(self.pgconn) def _check_tpc(self) -> None: diff --git a/psycopg/psycopg/errors.py b/psycopg/psycopg/errors.py index 0d162feaa..00cb24c9c 100644 --- a/psycopg/psycopg/errors.py +++ b/psycopg/psycopg/errors.py @@ -292,7 +292,7 @@ class Error(Exception): return self._info if _is_pgresult(self._info) else None @property - def diag(self) -> "Diagnostic": + def diag(self) -> Diagnostic: """ A `Diagnostic` object to inspect details of the errors from the database. """ diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py index 96d518d37..282c138fc 100644 --- a/psycopg/psycopg/sql.py +++ b/psycopg/psycopg/sql.py @@ -88,7 +88,7 @@ class Composable(ABC): # buffer object return codecs.lookup(enc).decode(b)[0] - def __add__(self, other: "Composable") -> "Composed": + def __add__(self, other: Composable) -> Composed: if isinstance(other, Composed): return Composed([self]) + other if isinstance(other, Composable): @@ -96,7 +96,7 @@ class Composable(ABC): else: return NotImplemented - def __mul__(self, n: int) -> "Composed": + def __mul__(self, n: int) -> Composed: return Composed([self] * n) def __eq__(self, other: Any) -> bool: @@ -138,7 +138,7 @@ class Composed(Composable): def __iter__(self) -> Iterator[Composable]: return iter(self._obj) - def __add__(self, other: Composable) -> "Composed": + def __add__(self, other: Composable) -> Composed: if isinstance(other, Composed): return Composed(self._obj + other._obj) if isinstance(other, Composable): @@ -146,7 +146,7 @@ class Composed(Composable): else: return NotImplemented - def join(self, joiner: "SQL" | LiteralString) -> "Composed": + def join(self, joiner: SQL | LiteralString) -> Composed: """ Return a new `!Composed` interposing the `!joiner` with the `!Composed` items. diff --git a/psycopg/psycopg/transaction.py b/psycopg/psycopg/transaction.py index 9a60a5fbc..d0a3aa349 100644 --- a/psycopg/psycopg/transaction.py +++ b/psycopg/psycopg/transaction.py @@ -40,7 +40,7 @@ class Rollback(Exception): __module__ = "psycopg" - def __init__(self, transaction: "Transaction" | "AsyncTransaction" | None = None): + def __init__(self, transaction: Transaction | AsyncTransaction | None = None): self.transaction = transaction def __repr__(self) -> str: diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py index 78f3eaeff..1da8f4316 100644 --- a/psycopg/psycopg/types/array.py +++ b/psycopg/psycopg/types/array.py @@ -128,7 +128,7 @@ class ListDumper(BaseListDumper): sd = self._tx.get_dumper(item, format) return (self.cls, sd.get_key(item, format)) - def upgrade(self, obj: list[Any], format: PyFormat) -> "BaseListDumper": + def upgrade(self, obj: list[Any], format: PyFormat) -> BaseListDumper: # If we have an oid we don't need to upgrade if self.oid: return self @@ -232,7 +232,7 @@ class ListBinaryDumper(BaseListDumper): sd = self._tx.get_dumper(item, format) return (self.cls, sd.get_key(item, format)) - def upgrade(self, obj: list[Any], format: PyFormat) -> "BaseListDumper": + def upgrade(self, obj: list[Any], format: PyFormat) -> BaseListDumper: # If we have an oid we don't need to upgrade if self.oid: return self diff --git a/psycopg/psycopg/types/datetime.py b/psycopg/psycopg/types/datetime.py index 866a2fa6c..3d8db27f8 100644 --- a/psycopg/psycopg/types/datetime.py +++ b/psycopg/psycopg/types/datetime.py @@ -204,13 +204,13 @@ class TimedeltaDumper(Dumper): return self._dump_method(self, obj) @staticmethod - def _dump_any(self: "TimedeltaDumper", obj: timedelta) -> bytes: + def _dump_any(self: TimedeltaDumper, obj: timedelta) -> bytes: # The comma is parsed ok by PostgreSQL but it's not documented # and it seems brittle to rely on it. CRDB doesn't consume it well. return str(obj).encode().replace(b",", b"") @staticmethod - def _dump_sql(self: "TimedeltaDumper", obj: timedelta) -> bytes: + def _dump_sql(self: TimedeltaDumper, obj: timedelta) -> bytes: # sql_standard format needs explicit signs # otherwise -1 day 1 sec will mean -1 sec return b"%+d day %+d second %+d microsecond" % ( @@ -509,7 +509,7 @@ class TimestamptzLoader(Loader): return self._load_method(self, data) @staticmethod - def _load_iso(self: "TimestamptzLoader", data: Buffer) -> datetime: + def _load_iso(self: TimestamptzLoader, data: Buffer) -> datetime: m = self._re_format.match(data) if not m: raise _get_timestamp_load_error(self.connection, data) from None @@ -556,7 +556,7 @@ class TimestamptzLoader(Loader): raise _get_timestamp_load_error(self.connection, data, ex) from None @staticmethod - def _load_notimpl(self: "TimestamptzLoader", data: Buffer) -> datetime: + def _load_notimpl(self: TimestamptzLoader, data: Buffer) -> datetime: s = bytes(data).decode("utf8", "replace") ds = _get_datestyle(self.connection).decode("ascii") raise NotImplementedError( @@ -623,7 +623,7 @@ class IntervalLoader(Loader): return self._load_method(self, data) @staticmethod - def _load_postgres(self: "IntervalLoader", data: Buffer) -> timedelta: + def _load_postgres(self: IntervalLoader, data: Buffer) -> timedelta: m = self._re_interval.match(data) if not m: s = bytes(data).decode("utf8", "replace") @@ -652,7 +652,7 @@ class IntervalLoader(Loader): raise DataError(f"can't parse interval {s!r}: {e}") from None @staticmethod - def _load_notimpl(self: "IntervalLoader", data: Buffer) -> timedelta: + def _load_notimpl(self: IntervalLoader, data: Buffer) -> timedelta: s = bytes(data).decode("utf8", "replace") ints = _get_intervalstyle(self.connection).decode("utf8", "replace") raise NotImplementedError( diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py index d8070279f..735333049 100644 --- a/psycopg/psycopg/types/multirange.py +++ b/psycopg/psycopg/types/multirange.py @@ -63,7 +63,7 @@ WHERE t.oid = {regtype} """ ).format(regtype=cls._to_regtype(conn)) - def _added(self, registry: "TypesRegistry") -> None: + def _added(self, registry: TypesRegistry) -> None: # Map multiranges ranges and subtypes to info registry._registry[MultirangeInfo, self.range_oid] = self registry._registry[MultirangeInfo, self.subtype_oid] = self @@ -95,9 +95,9 @@ class Multirange(MutableSequence[Range[T]]): def __getitem__(self, index: int) -> Range[T]: ... @overload - def __getitem__(self, index: slice) -> "Multirange[T]": ... + def __getitem__(self, index: slice) -> Multirange[T]: ... - def __getitem__(self, index: int | slice) -> "Range[T] | Multirange[T]": + def __getitem__(self, index: int | slice) -> Range[T] | Multirange[T]: if isinstance(index, int): return self._ranges[index] else: @@ -199,7 +199,7 @@ class BaseMultirangeDumper(RecursiveDumper): else: return (self.cls,) - def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper": + def upgrade(self, obj: Multirange[Any], format: PyFormat) -> BaseMultirangeDumper: # If we are a subclass whose oid is specified we don't need upgrade if self.cls is not Multirange: return self diff --git a/psycopg/psycopg/types/net.py b/psycopg/psycopg/types/net.py index 67f6adb4b..a04aed08f 100644 --- a/psycopg/psycopg/types/net.py +++ b/psycopg/psycopg/types/net.py @@ -25,12 +25,12 @@ Network: TypeAlias = "ipaddress.IPv4Network | ipaddress.IPv6Network" ip_address: Callable[[str], Address] = None # type: ignore[assignment] ip_interface: Callable[[str], Interface] = None # type: ignore[assignment] ip_network: Callable[[str], Network] = None # type: ignore[assignment] -IPv4Address: "type[ipaddress.IPv4Address]" = None # type: ignore[assignment] -IPv6Address: "type[ipaddress.IPv6Address]" = None # type: ignore[assignment] -IPv4Interface: "type[ipaddress.IPv4Interface]" = None # type: ignore[assignment] -IPv6Interface: "type[ipaddress.IPv6Interface]" = None # type: ignore[assignment] -IPv4Network: "type[ipaddress.IPv4Network]" = None # type: ignore[assignment] -IPv6Network: "type[ipaddress.IPv6Network]" = None # type: ignore[assignment] +IPv4Address: type[ipaddress.IPv4Address] = None # type: ignore[assignment] +IPv6Address: type[ipaddress.IPv6Address] = None # type: ignore[assignment] +IPv4Interface: type[ipaddress.IPv4Interface] = None # type: ignore[assignment] +IPv6Interface: type[ipaddress.IPv6Interface] = None # type: ignore[assignment] +IPv4Network: type[ipaddress.IPv4Network] = None # type: ignore[assignment] +IPv6Network: type[ipaddress.IPv6Network] = None # type: ignore[assignment] PGSQL_AF_INET = 2 PGSQL_AF_INET6 = 3 diff --git a/psycopg/psycopg/types/numeric.py b/psycopg/psycopg/types/numeric.py index f3a108543..d13f16e8b 100644 --- a/psycopg/psycopg/types/numeric.py +++ b/psycopg/psycopg/types/numeric.py @@ -386,11 +386,11 @@ class _MixedNumericDumper(Dumper, ABC): _MixedNumericDumper.int_classes = int @abstractmethod - def dump(self, obj: Decimal | int | "numpy.integer[Any]") -> Buffer | None: ... + def dump(self, obj: Decimal | int | numpy.integer[Any]) -> Buffer | None: ... class NumericDumper(_MixedNumericDumper): - def dump(self, obj: Decimal | int | "numpy.integer[Any]") -> Buffer | None: + def dump(self, obj: Decimal | int | numpy.integer[Any]) -> Buffer | None: if isinstance(obj, self.int_classes): return str(obj).encode() elif isinstance(obj, Decimal): @@ -404,7 +404,7 @@ class NumericDumper(_MixedNumericDumper): class NumericBinaryDumper(_MixedNumericDumper): format = Format.BINARY - def dump(self, obj: Decimal | int | "numpy.integer[Any]") -> Buffer | None: + def dump(self, obj: Decimal | int | numpy.integer[Any]) -> Buffer | None: if type(obj) is int: return dump_int_to_numeric_binary(obj) elif isinstance(obj, Decimal): diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py index bdce29401..665c9a6de 100644 --- a/psycopg/psycopg/types/range.py +++ b/psycopg/psycopg/types/range.py @@ -303,7 +303,7 @@ class BaseRangeDumper(RecursiveDumper): else: return (self.cls,) - def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper": + def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper: # If we are a subclass whose oid is specified we don't need upgrade if self.cls is not Range: return self diff --git a/psycopg/psycopg/types/shapely.py b/psycopg/psycopg/types/shapely.py index 7c4da657a..6a7e15761 100644 --- a/psycopg/psycopg/types/shapely.py +++ b/psycopg/psycopg/types/shapely.py @@ -25,14 +25,14 @@ except ImportError: class GeometryBinaryLoader(Loader): format = Format.BINARY - def load(self, data: Buffer) -> "BaseGeometry": + def load(self, data: Buffer) -> BaseGeometry: if not isinstance(data, bytes): data = bytes(data) return loads(data) class GeometryLoader(Loader): - def load(self, data: Buffer) -> "BaseGeometry": + def load(self, data: Buffer) -> BaseGeometry: # it's a hex string in binary if isinstance(data, memoryview): data = bytes(data) @@ -42,12 +42,12 @@ class GeometryLoader(Loader): class BaseGeometryBinaryDumper(Dumper): format = Format.BINARY - def dump(self, obj: "BaseGeometry") -> Buffer | None: + def dump(self, obj: BaseGeometry) -> Buffer | None: return dumps(obj) # type: ignore class BaseGeometryDumper(Dumper): - def dump(self, obj: "BaseGeometry") -> Buffer | None: + def dump(self, obj: BaseGeometry) -> Buffer | None: return dumps(obj, hex=True).encode() # type: ignore diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py index f03696815..427fd24dc 100644 --- a/psycopg_pool/psycopg_pool/base.py +++ b/psycopg_pool/psycopg_pool/base.py @@ -40,7 +40,7 @@ class BasePool: _CONNECTIONS_ERRORS = "connections_errors" _CONNECTIONS_LOST = "connections_lost" - _pool: Deque["Any"] + _pool: Deque[Any] def __init__( self, diff --git a/tests/adapters_example.py b/tests/adapters_example.py index 5c0444a50..d69141dd1 100644 --- a/tests/adapters_example.py +++ b/tests/adapters_example.py @@ -31,7 +31,7 @@ class MyStrDumper: def get_key(self, obj: str, format: PyFormat) -> type: return self._cls - def upgrade(self, obj: str, format: PyFormat) -> "MyStrDumper": + def upgrade(self, obj: str, format: PyFormat) -> MyStrDumper: return self diff --git a/tests/fix_pq.py b/tests/fix_pq.py index 0d68dc1fb..1cff7e18b 100644 --- a/tests/fix_pq.py +++ b/tests/fix_pq.py @@ -92,7 +92,7 @@ def trace(libpq): class Tracer: def trace(self, conn): - pgconn: "pq.abc.PGconn" + pgconn: pq.abc.PGconn if hasattr(conn, "exec_"): pgconn = conn @@ -105,7 +105,7 @@ class Tracer: class TraceLog: - def __init__(self, pgconn: "pq.abc.PGconn"): + def __init__(self, pgconn: pq.abc.PGconn): self.pgconn = pgconn self.tempfile = TemporaryFile(buffering=0) pgconn.trace(self.tempfile.fileno()) @@ -116,13 +116,12 @@ class TraceLog: self.pgconn.untrace() self.tempfile.close() - def __iter__(self) -> "Iterator[TraceEntry]": + def __iter__(self) -> Iterator[TraceEntry]: self.tempfile.seek(0) data = self.tempfile.read() - for entry in self._parse_entries(data): - yield entry + yield from self._parse_entries(data) - def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]": + def _parse_entries(self, data: bytes) -> Iterator[TraceEntry]: for line in data.splitlines(): direction, length, type, *content = line.split(b"\t") yield TraceEntry( diff --git a/tests/utils.py b/tests/utils.py index 443fba1c7..98f42915d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -85,7 +85,7 @@ class VersionCheck: self.postgres_rule = postgres_rule @classmethod - def parse(cls, spec: str, *, postgres_rule: bool = False) -> "VersionCheck": + def parse(cls, spec: str, *, postgres_rule: bool = False) -> VersionCheck: # Parse a spec like "> 9.6", "skip < 21.2.0" m = re.match( r"""(?ix)