]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: avoid to create reference loops in datetime adapters
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 7 Feb 2024 01:16:06 +0000 (01:16 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 8 Feb 2024 00:33:18 +0000 (00:33 +0000)
Setting the reference to a bound method in the state creates a reference
loop.

The issue is minimal because the gc will be able to break these loops
anyway and because it mostly happens with exotic or unsupported
date/interval styles.

psycopg/psycopg/types/datetime.py

index fa1da8effbe43717122f21f994e29fe445ba9b80..57f702b0afc1afde08199d5082c5cde8e37804fb 100644 (file)
@@ -193,19 +193,22 @@ class TimedeltaDumper(Dumper):
 
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         super().__init__(cls, context)
-        if self.connection:
-            if (
-                self.connection.pgconn.parameter_status(b"IntervalStyle")
-                == b"sql_standard"
-            ):
-                setattr(self, "dump", self._dump_sql)
+        if _get_intervalstyle(self.connection) == b"sql_standard":
+            self._dump_method = self._dump_sql
+        else:
+            self._dump_method = self._dump_any
 
     def dump(self, obj: timedelta) -> bytes:
+        return self._dump_method(self, obj)
+
+    @staticmethod
+    def _dump_any(self: "TimedeltaDumper", obj: timedelta) -> bytes:
         # The comma is parsed ok by PostgreSQL but it's not documented
         # and it seems brittle to rely on it. CRDB doesn't consume it well.
         return str(obj).encode().replace(b",", b"")
 
-    def _dump_sql(self, obj: timedelta) -> bytes:
+    @staticmethod
+    def _dump_sql(self: "TimedeltaDumper", obj: timedelta) -> bytes:
         # sql_standard format needs explicit signs
         # otherwise -1 day 1 sec will mean -1 sec
         return b"%+d day %+d second %+d microsecond" % (
@@ -495,10 +498,16 @@ class TimestamptzLoader(Loader):
         self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
 
         ds = _get_datestyle(self.connection)
-        if not ds.startswith(b"I"):  # not ISO
-            setattr(self, "load", self._load_notimpl)
+        if ds.startswith(b"I"):  # ISO
+            self._load_method = self._load_iso
+        else:
+            self._load_method = self._load_notimpl
 
     def load(self, data: Buffer) -> datetime:
+        return self._load_method(self, data)
+
+    @staticmethod
+    def _load_iso(self: "TimestamptzLoader", data: Buffer) -> datetime:
         m = self._re_format.match(data)
         if not m:
             raise _get_timestamp_load_error(self.connection, data) from None
@@ -544,7 +553,8 @@ class TimestamptzLoader(Loader):
 
         raise _get_timestamp_load_error(self.connection, data, ex) from None
 
-    def _load_notimpl(self, data: Buffer) -> datetime:
+    @staticmethod
+    def _load_notimpl(self: "TimestamptzLoader", data: Buffer) -> datetime:
         s = bytes(data).decode("utf8", "replace")
         ds = _get_datestyle(self.connection).decode("ascii")
         raise NotImplementedError(
@@ -602,12 +612,16 @@ class IntervalLoader(Loader):
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
-        if self.connection:
-            ints = self.connection.pgconn.parameter_status(b"IntervalStyle")
-            if ints != b"postgres":
-                setattr(self, "load", self._load_notimpl)
+        if _get_intervalstyle(self.connection) == b"postgres":
+            self._load_method = self._load_postgres
+        else:
+            self._load_method = self._load_notimpl
 
     def load(self, data: Buffer) -> timedelta:
+        return self._load_method(self, data)
+
+    @staticmethod
+    def _load_postgres(self: "IntervalLoader", data: Buffer) -> timedelta:
         m = self._re_interval.match(data)
         if not m:
             s = bytes(data).decode("utf8", "replace")
@@ -635,13 +649,10 @@ class IntervalLoader(Loader):
             s = bytes(data).decode("utf8", "replace")
             raise DataError(f"can't parse interval {s!r}: {e}") from None
 
-    def _load_notimpl(self, data: Buffer) -> timedelta:
+    @staticmethod
+    def _load_notimpl(self: "IntervalLoader", data: Buffer) -> timedelta:
         s = bytes(data).decode("utf8", "replace")
-        ints = (
-            self.connection
-            and self.connection.pgconn.parameter_status(b"IntervalStyle")
-            or b"unknown"
-        ).decode("utf8", "replace")
+        ints = _get_intervalstyle(self.connection).decode("utf8", "replace")
         raise NotImplementedError(
             f"can't parse interval with IntervalStyle {ints}: {s!r}"
         )
@@ -674,6 +685,15 @@ def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes:
     return b"ISO, DMY"
 
 
+def _get_intervalstyle(conn: Optional["BaseConnection[Any]"]) -> bytes:
+    if conn:
+        ints = conn.pgconn.parameter_status(b"IntervalStyle")
+        if ints:
+            return ints
+
+    return b"unknown"
+
+
 def _get_timestamp_load_error(
     conn: Optional["BaseConnection[Any]"], data: Buffer, ex: Optional[Exception] = None
 ) -> Exception: