I.e. all those suggested by pyupgrade --py38-plus.
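For illustration, this is the kind of rewrite pyupgrade applies once a module has from __future__ import annotations (PEP 563): annotations are stored as strings and never evaluated eagerly, so forward references no longer need quoting. The Node class below is illustrative, not from the codebase.

    from __future__ import annotations  # PEP 563: lazy annotation evaluation

    class Node:
        # Before: the forward reference had to be quoted, because "Node"
        # is not yet defined while the class body is being evaluated:
        #     def clone(self) -> "Node": ...
        # After: the bare name is fine, since the annotation is kept as a
        # string and only resolved on demand (e.g. by typing.get_type_hints):
        def clone(self) -> Node:
            return Node()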
return self._adapters
@property
- def connection(self) -> "BaseConnection[Row]":
+ def connection(self) -> BaseConnection[Row]:
# implement the AdaptContext protocol
return self
@staticmethod
def _notice_handler(
- wself: "ReferenceType[BaseConnection[Row]]", res: PGresult
+ wself: ReferenceType[BaseConnection[Row]], res: PGresult
) -> None:
self = wself()
if not (self and self._notice_handlers):
@staticmethod
def _notify_handler(
- wself: "ReferenceType[BaseConnection[Row]]", pgn: pq.PGnotify
+ wself: ReferenceType[BaseConnection[Row]], pgn: pq.PGnotify
) -> None:
self = wself()
if not (self and self._notify_handlers):
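Both handlers rely on the same weak-reference trick: the callback registered with the libpq holds only a weakref to the connection, so it neither creates a reference cycle nor keeps a closed connection alive. A minimal sketch of the pattern, with illustrative names (Conn and its attributes are hypothetical):

    import weakref
    from functools import partial

    class Conn:
        def __init__(self):
            self._notice_handlers = []
            # Bind a weakref, not self: the callback does not own the object.
            self._callback = partial(Conn._notice_handler, weakref.ref(self))

        @staticmethod
        def _notice_handler(wself, res):
            self = wself()  # None if the connection was garbage collected
            if not (self and self._notice_handlers):
                return
            for handler in self._notice_handlers:
                handler(res)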
return self._get_solved_entries(hp, ans)
def _get_solved_entries(
- self, hp: HostPort, entries: "Sequence[SRV]"
+ self, hp: HostPort, entries: Sequence[SRV]
) -> list[HostPort]:
if not entries:
# No SRV entry found. Delegate the QNAME=target lookup to the libpq
out["port"] = ",".join(str(hp.port) for hp in hps)
return out
- def sort_rfc2782(self, ans: "Sequence[SRV]") -> "list[SRV]":
+ def sort_rfc2782(self, ans: Sequence[SRV]) -> list[SRV]:
"""
Implement the priority/weight ordering defined in RFC 2782.
"""
# Divide the entries by priority:
- priorities: DefaultDict[int, "list[SRV]"] = defaultdict(list)
- out: "list[SRV]" = []
+ priorities: DefaultDict[int, list[SRV]] = defaultdict(list)
+ out: list[SRV] = []
for entry in ans:
priorities[entry.priority].append(entry)
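For reference, a self-contained sketch of the RFC 2782 ordering this method implements; the grouping step matches the code above, while the weighted-selection loop is a generic reconstruction, not necessarily psycopg's exact code:

    import random
    from collections import defaultdict

    def rfc2782_order(entries):  # each entry has .priority and .weight
        priorities = defaultdict(list)
        for entry in entries:
            priorities[entry.priority].append(entry)

        out = []
        for _, group in sorted(priorities.items()):
            # Weighted random selection without replacement in each group:
            # lower priorities first, higher weights more likely earlier.
            while group:
                r = random.randint(0, sum(e.weight for e in group))
                running = 0
                for i, entry in enumerate(group):
                    running += entry.weight
                    if running >= r:
                        out.append(group.pop(i))
                        break
        return out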
self._encoding = ""
@classmethod
- def from_context(cls, context: AdaptContext | None) -> "Transformer":
+ def from_context(cls, context: AdaptContext | None) -> Transformer:
"""
Return a Transformer from an AdaptContext.
database: str | None = None
@classmethod
- def from_string(cls, s: str) -> "Xid":
+ def from_string(cls, s: str) -> Xid:
"""Try to parse an XA triple from the string.
This may fail for several reasons. In such a case, return an unparsed Xid.
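Roughly how this behaves; the serialized triple format shown ("formatid_base64(gtrid)_base64(bqual)", as used by psycopg2) is an assumption here, since the actual grammar lives in the _re_xid pattern:

    # A string matching the serialized XA format is parsed into its parts:
    xid = Xid.from_string("1_Z3RyaWQ=_YnF1YWw=")
    # Anything else is returned wrapped as an "unparsed" Xid:
    raw = Xid.from_string("not-an-xa-triple")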
return (self.format_id, self.gtrid, self.bqual)[index]
@classmethod
- def _parse_string(cls, s: str) -> "Xid":
+ def _parse_string(cls, s: str) -> Xid:
m = _re_xid.match(s)
if not m:
raise ValueError("bad Xid format")
return cls.from_parts(format_id, gtrid, bqual)
@classmethod
- def from_parts(cls, format_id: int | None, gtrid: str, bqual: str | None) -> "Xid":
+ def from_parts(cls, format_id: int | None, gtrid: str, bqual: str | None) -> Xid:
if format_id is not None:
if bqual is None:
raise TypeError("if format_id is specified, bqual must be too")
@classmethod
def _from_record(
cls, gid: str, prepared: dt.datetime, owner: str, database: str
- ) -> "Xid":
+ ) -> Xid:
xid = Xid.from_string(gid)
return replace(xid, prepared=prepared, owner=owner, database=database)
"""
return self.cls
- def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
+ def upgrade(self, obj: Any, format: PyFormat) -> Dumper:
"""
Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the
`~psycopg.abc.Dumper` protocol. Look at its definition for details.
return self._adapters
@property
- def info(self) -> "CrdbConnectionInfo":
+ def info(self) -> CrdbConnectionInfo:
return CrdbConnectionInfo(self.pgconn)
def _check_tpc(self) -> None:
return self._info if _is_pgresult(self._info) else None
@property
- def diag(self) -> "Diagnostic":
+ def diag(self) -> Diagnostic:
"""
A `Diagnostic` object to inspect details of the errors from the database.
"""
# buffer object
return codecs.lookup(enc).decode(b)[0]
- def __add__(self, other: "Composable") -> "Composed":
+ def __add__(self, other: Composable) -> Composed:
if isinstance(other, Composed):
return Composed([self]) + other
if isinstance(other, Composable):
else:
return NotImplemented
- def __mul__(self, n: int) -> "Composed":
+ def __mul__(self, n: int) -> Composed:
return Composed([self] * n)
def __eq__(self, other: Any) -> bool:
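These operators let Composable objects be combined directly; a couple of standard psycopg.sql idioms:

    from psycopg import sql

    stmt = sql.SQL("SELECT ") + sql.Identifier("name")  # a two-item Composed
    marks = sql.SQL(", ").join(sql.Placeholder() * 3)   # renders as "%s, %s, %s"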
def __iter__(self) -> Iterator[Composable]:
return iter(self._obj)
- def __add__(self, other: Composable) -> "Composed":
+ def __add__(self, other: Composable) -> Composed:
if isinstance(other, Composed):
return Composed(self._obj + other._obj)
if isinstance(other, Composable):
else:
return NotImplemented
- def join(self, joiner: "SQL" | LiteralString) -> "Composed":
+ def join(self, joiner: SQL | LiteralString) -> Composed:
"""
Return a new `!Composed` interposing the `!joiner` with the `!Composed` items.
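As in the psycopg documentation (conn is assumed to be an open connection); a plain string joiner is wrapped in a SQL object internally:

    from psycopg import sql

    fields = sql.Composed([sql.Identifier("foo"), sql.Identifier("bar")])
    fields.join(", ").as_string(conn)  # '"foo", "bar"'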
__module__ = "psycopg"
- def __init__(self, transaction: "Transaction" | "AsyncTransaction" | None = None):
+ def __init__(self, transaction: Transaction | AsyncTransaction | None = None):
self.transaction = transaction
def __repr__(self) -> str:
sd = self._tx.get_dumper(item, format)
return (self.cls, sd.get_key(item, format))
- def upgrade(self, obj: list[Any], format: PyFormat) -> "BaseListDumper":
+ def upgrade(self, obj: list[Any], format: PyFormat) -> BaseListDumper:
# If we have an oid we don't need to upgrade
if self.oid:
return self
sd = self._tx.get_dumper(item, format)
return (self.cls, sd.get_key(item, format))
- def upgrade(self, obj: list[Any], format: PyFormat) -> "BaseListDumper":
+ def upgrade(self, obj: list[Any], format: PyFormat) -> BaseListDumper:
# If we have an oid we don't need to upgrade
if self.oid:
return self
return self._dump_method(self, obj)
@staticmethod
- def _dump_any(self: "TimedeltaDumper", obj: timedelta) -> bytes:
+ def _dump_any(self: TimedeltaDumper, obj: timedelta) -> bytes:
# The comma is parsed ok by PostgreSQL but it's not documented
# and it seems brittle to rely on it. CRDB doesn't consume it well.
return str(obj).encode().replace(b",", b"")
@staticmethod
- def _dump_sql(self: "TimedeltaDumper", obj: timedelta) -> bytes:
+ def _dump_sql(self: TimedeltaDumper, obj: timedelta) -> bytes:
# sql_standard format needs explicit signs
# otherwise -1 day 1 sec will mean -1 sec
return b"%+d day %+d second %+d microsecond" % (
return self._load_method(self, data)
@staticmethod
- def _load_iso(self: "TimestamptzLoader", data: Buffer) -> datetime:
+ def _load_iso(self: TimestamptzLoader, data: Buffer) -> datetime:
m = self._re_format.match(data)
if not m:
raise _get_timestamp_load_error(self.connection, data) from None
raise _get_timestamp_load_error(self.connection, data, ex) from None
@staticmethod
- def _load_notimpl(self: "TimestamptzLoader", data: Buffer) -> datetime:
+ def _load_notimpl(self: TimestamptzLoader, data: Buffer) -> datetime:
s = bytes(data).decode("utf8", "replace")
ds = _get_datestyle(self.connection).decode("ascii")
raise NotImplementedError(
return self._load_method(self, data)
@staticmethod
- def _load_postgres(self: "IntervalLoader", data: Buffer) -> timedelta:
+ def _load_postgres(self: IntervalLoader, data: Buffer) -> timedelta:
m = self._re_interval.match(data)
if not m:
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse interval {s!r}: {e}") from None
@staticmethod
- def _load_notimpl(self: "IntervalLoader", data: Buffer) -> timedelta:
+ def _load_notimpl(self: IntervalLoader, data: Buffer) -> timedelta:
s = bytes(data).decode("utf8", "replace")
ints = _get_intervalstyle(self.connection).decode("utf8", "replace")
raise NotImplementedError(
"""
).format(regtype=cls._to_regtype(conn))
- def _added(self, registry: "TypesRegistry") -> None:
+ def _added(self, registry: TypesRegistry) -> None:
# Map multiranges ranges and subtypes to info
registry._registry[MultirangeInfo, self.range_oid] = self
registry._registry[MultirangeInfo, self.subtype_oid] = self
def __getitem__(self, index: int) -> Range[T]: ...
@overload
- def __getitem__(self, index: slice) -> "Multirange[T]": ...
+ def __getitem__(self, index: slice) -> Multirange[T]: ...
- def __getitem__(self, index: int | slice) -> "Range[T] | Multirange[T]":
+ def __getitem__(self, index: int | slice) -> Range[T] | Multirange[T]:
if isinstance(index, int):
return self._ranges[index]
else:
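A Multirange is essentially a mutable sequence of Range objects, so both overloads are exercised by ordinary indexing:

    from psycopg.types.multirange import Multirange
    from psycopg.types.range import Range

    mr = Multirange([Range(2, 4), Range(6, 8)])
    first = mr[0]   # a single Range
    tail = mr[1:]   # a new Multirange with the remaining ranges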
else:
return (self.cls,)
- def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
+ def upgrade(self, obj: Multirange[Any], format: PyFormat) -> BaseMultirangeDumper:
# If we are a subclass whose oid is specified we don't need to upgrade
if self.cls is not Multirange:
return self
ip_address: Callable[[str], Address] = None # type: ignore[assignment]
ip_interface: Callable[[str], Interface] = None # type: ignore[assignment]
ip_network: Callable[[str], Network] = None # type: ignore[assignment]
-IPv4Address: "type[ipaddress.IPv4Address]" = None # type: ignore[assignment]
-IPv6Address: "type[ipaddress.IPv6Address]" = None # type: ignore[assignment]
-IPv4Interface: "type[ipaddress.IPv4Interface]" = None # type: ignore[assignment]
-IPv6Interface: "type[ipaddress.IPv6Interface]" = None # type: ignore[assignment]
-IPv4Network: "type[ipaddress.IPv4Network]" = None # type: ignore[assignment]
-IPv6Network: "type[ipaddress.IPv6Network]" = None # type: ignore[assignment]
+IPv4Address: type[ipaddress.IPv4Address] = None # type: ignore[assignment]
+IPv6Address: type[ipaddress.IPv6Address] = None # type: ignore[assignment]
+IPv4Interface: type[ipaddress.IPv4Interface] = None # type: ignore[assignment]
+IPv6Interface: type[ipaddress.IPv6Interface] = None # type: ignore[assignment]
+IPv4Network: type[ipaddress.IPv4Network] = None # type: ignore[assignment]
+IPv6Network: type[ipaddress.IPv6Network] = None # type: ignore[assignment]
PGSQL_AF_INET = 2
PGSQL_AF_INET6 = 3
_MixedNumericDumper.int_classes = int
@abstractmethod
- def dump(self, obj: Decimal | int | "numpy.integer[Any]") -> Buffer | None: ...
+ def dump(self, obj: Decimal | int | numpy.integer[Any]) -> Buffer | None: ...
class NumericDumper(_MixedNumericDumper):
- def dump(self, obj: Decimal | int | "numpy.integer[Any]") -> Buffer | None:
+ def dump(self, obj: Decimal | int | numpy.integer[Any]) -> Buffer | None:
if isinstance(obj, self.int_classes):
return str(obj).encode()
elif isinstance(obj, Decimal):
class NumericBinaryDumper(_MixedNumericDumper):
format = Format.BINARY
- def dump(self, obj: Decimal | int | "numpy.integer[Any]") -> Buffer | None:
+ def dump(self, obj: Decimal | int | numpy.integer[Any]) -> Buffer | None:
if type(obj) is int:
return dump_int_to_numeric_binary(obj)
elif isinstance(obj, Decimal):
else:
return (self.cls,)
- def upgrade(self, obj: Range[Any], format: PyFormat) -> "BaseRangeDumper":
+ def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper:
# If we are a subclass whose oid is specified we don't need to upgrade
if self.cls is not Range:
return self
class GeometryBinaryLoader(Loader):
format = Format.BINARY
- def load(self, data: Buffer) -> "BaseGeometry":
+ def load(self, data: Buffer) -> BaseGeometry:
if not isinstance(data, bytes):
data = bytes(data)
return loads(data)
class GeometryLoader(Loader):
- def load(self, data: Buffer) -> "BaseGeometry":
+ def load(self, data: Buffer) -> BaseGeometry:
# even in text format, the WKB arrives as a hex-encoded string
if isinstance(data, memoryview):
data = bytes(data)
class BaseGeometryBinaryDumper(Dumper):
format = Format.BINARY
- def dump(self, obj: "BaseGeometry") -> Buffer | None:
+ def dump(self, obj: BaseGeometry) -> Buffer | None:
return dumps(obj) # type: ignore
class BaseGeometryDumper(Dumper):
- def dump(self, obj: "BaseGeometry") -> Buffer | None:
+ def dump(self, obj: BaseGeometry) -> Buffer | None:
return dumps(obj, hex=True).encode() # type: ignore
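These adapters are activated by registering them for the PostGIS geometry type, per the psycopg docs (conn is assumed to be a connection to a PostGIS-enabled database):

    from psycopg.types import TypeInfo
    from psycopg.types.shapely import register_shapely

    info = TypeInfo.fetch(conn, "geometry")
    register_shapely(info, conn)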
_CONNECTIONS_ERRORS = "connections_errors"
_CONNECTIONS_LOST = "connections_lost"
- _pool: Deque["Any"]
+ _pool: Deque[Any]
def __init__(
self,
def get_key(self, obj: str, format: PyFormat) -> type:
return self._cls
- def upgrade(self, obj: str, format: PyFormat) -> "MyStrDumper":
+ def upgrade(self, obj: str, format: PyFormat) -> MyStrDumper:
return self
class Tracer:
def trace(self, conn):
- pgconn: "pq.abc.PGconn"
+ pgconn: pq.abc.PGconn
if hasattr(conn, "exec_"):
pgconn = conn
class TraceLog:
- def __init__(self, pgconn: "pq.abc.PGconn"):
+ def __init__(self, pgconn: pq.abc.PGconn):
self.pgconn = pgconn
self.tempfile = TemporaryFile(buffering=0)
pgconn.trace(self.tempfile.fileno())
self.pgconn.untrace()
self.tempfile.close()
- def __iter__(self) -> "Iterator[TraceEntry]":
+ def __iter__(self) -> Iterator[TraceEntry]:
self.tempfile.seek(0)
data = self.tempfile.read()
- for entry in self._parse_entries(data):
- yield entry
+ yield from self._parse_entries(data)
- def _parse_entries(self, data: bytes) -> "Iterator[TraceEntry]":
+ def _parse_entries(self, data: bytes) -> Iterator[TraceEntry]:
for line in data.splitlines():
direction, length, type, *content = line.split(b"\t")
yield TraceEntry(
self.postgres_rule = postgres_rule
@classmethod
- def parse(cls, spec: str, *, postgres_rule: bool = False) -> "VersionCheck":
+ def parse(cls, spec: str, *, postgres_rule: bool = False) -> VersionCheck:
# Parse a spec like "> 9.6", "skip < 21.2.0"
m = re.match(
r"""(?ix)