From: Daniele Varrazzo Date: Wed, 7 Feb 2024 01:16:06 +0000 (+0000) Subject: fix: avoid to create reference loops in datetime adapters X-Git-Tag: 3.1.19~9 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=2695a506975767a16132d25fbe09cd778f3365b6;p=thirdparty%2Fpsycopg.git fix: avoid to create reference loops in datetime adapters Setting the reference to a bound method in the state creates a reference loop. The issue is minimal because the gc will be able to break these loops anyway and because it mostly happens with exotic or unsupported date/interval styles. --- diff --git a/psycopg/psycopg/types/datetime.py b/psycopg/psycopg/types/datetime.py index fa1da8eff..57f702b0a 100644 --- a/psycopg/psycopg/types/datetime.py +++ b/psycopg/psycopg/types/datetime.py @@ -193,19 +193,22 @@ class TimedeltaDumper(Dumper): def __init__(self, cls: type, context: Optional[AdaptContext] = None): super().__init__(cls, context) - if self.connection: - if ( - self.connection.pgconn.parameter_status(b"IntervalStyle") - == b"sql_standard" - ): - setattr(self, "dump", self._dump_sql) + if _get_intervalstyle(self.connection) == b"sql_standard": + self._dump_method = self._dump_sql + else: + self._dump_method = self._dump_any def dump(self, obj: timedelta) -> bytes: + return self._dump_method(self, obj) + + @staticmethod + def _dump_any(self: "TimedeltaDumper", obj: timedelta) -> bytes: # The comma is parsed ok by PostgreSQL but it's not documented # and it seems brittle to rely on it. CRDB doesn't consume it well. return str(obj).encode().replace(b",", b"") - def _dump_sql(self, obj: timedelta) -> bytes: + @staticmethod + def _dump_sql(self: "TimedeltaDumper", obj: timedelta) -> bytes: # sql_standard format needs explicit signs # otherwise -1 day 1 sec will mean -1 sec return b"%+d day %+d second %+d microsecond" % ( @@ -495,10 +498,16 @@ class TimestamptzLoader(Loader): self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None) ds = _get_datestyle(self.connection) - if not ds.startswith(b"I"): # not ISO - setattr(self, "load", self._load_notimpl) + if ds.startswith(b"I"): # ISO + self._load_method = self._load_iso + else: + self._load_method = self._load_notimpl def load(self, data: Buffer) -> datetime: + return self._load_method(self, data) + + @staticmethod + def _load_iso(self: "TimestamptzLoader", data: Buffer) -> datetime: m = self._re_format.match(data) if not m: raise _get_timestamp_load_error(self.connection, data) from None @@ -544,7 +553,8 @@ class TimestamptzLoader(Loader): raise _get_timestamp_load_error(self.connection, data, ex) from None - def _load_notimpl(self, data: Buffer) -> datetime: + @staticmethod + def _load_notimpl(self: "TimestamptzLoader", data: Buffer) -> datetime: s = bytes(data).decode("utf8", "replace") ds = _get_datestyle(self.connection).decode("ascii") raise NotImplementedError( @@ -602,12 +612,16 @@ class IntervalLoader(Loader): def __init__(self, oid: int, context: Optional[AdaptContext] = None): super().__init__(oid, context) - if self.connection: - ints = self.connection.pgconn.parameter_status(b"IntervalStyle") - if ints != b"postgres": - setattr(self, "load", self._load_notimpl) + if _get_intervalstyle(self.connection) == b"postgres": + self._load_method = self._load_postgres + else: + self._load_method = self._load_notimpl def load(self, data: Buffer) -> timedelta: + return self._load_method(self, data) + + @staticmethod + def _load_postgres(self: "IntervalLoader", data: Buffer) -> timedelta: m = self._re_interval.match(data) if not m: s = bytes(data).decode("utf8", "replace") @@ -635,13 +649,10 @@ class IntervalLoader(Loader): s = bytes(data).decode("utf8", "replace") raise DataError(f"can't parse interval {s!r}: {e}") from None - def _load_notimpl(self, data: Buffer) -> timedelta: + @staticmethod + def _load_notimpl(self: "IntervalLoader", data: Buffer) -> timedelta: s = bytes(data).decode("utf8", "replace") - ints = ( - self.connection - and self.connection.pgconn.parameter_status(b"IntervalStyle") - or b"unknown" - ).decode("utf8", "replace") + ints = _get_intervalstyle(self.connection).decode("utf8", "replace") raise NotImplementedError( f"can't parse interval with IntervalStyle {ints}: {s!r}" ) @@ -674,6 +685,15 @@ def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes: return b"ISO, DMY" +def _get_intervalstyle(conn: Optional["BaseConnection[Any]"]) -> bytes: + if conn: + ints = conn.pgconn.parameter_status(b"IntervalStyle") + if ints: + return ints + + return b"unknown" + + def _get_timestamp_load_error( conn: Optional["BaseConnection[Any]"], data: Buffer, ex: Optional[Exception] = None ) -> Exception: