]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: cleaner inheritance in record/composite loaders
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 6 Oct 2025 18:17:23 +0000 (20:17 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 30 Oct 2025 13:23:34 +0000 (13:23 +0000)
- Avoid a common base class to reuse the code, use module-level
  functions.
- Make the inheritance lines of text and binary loaders similar.
- Add documentation and notes.

This refactoring is in preparation of the implementation of these
loaders in C, to clarify and simplify things. See #1175.

psycopg/psycopg/types/composite.py

index 1890fa6dc3853cfe7352fb81af780f27924f7463..572deefb798d4a62c1466e14e00804d4f497e845 100644 (file)
@@ -11,11 +11,12 @@ import struct
 from typing import TYPE_CHECKING, Any, NamedTuple, cast
 from functools import cache
 from collections import namedtuple
-from collections.abc import Callable, Iterator, Sequence
+from collections.abc import Callable, Sequence
 
 from .. import abc, postgres, pq, sql
 from .._oids import TEXT_OID
-from ..adapt import Buffer, Dumper, Loader, PyFormat, RecursiveDumper, Transformer
+from ..adapt import Buffer, Dumper, Loader, PyFormat, RecursiveDumper, RecursiveLoader
+from ..adapt import Transformer
 from .._struct import pack_len, unpack_len
 from .._typeinfo import TypeInfo
 from .._encodings import _as_python_identifier
@@ -155,55 +156,36 @@ class TupleBinaryDumper(Dumper):
         return out
 
 
-class BaseCompositeLoader(Loader):
-    def __init__(self, oid: int, context: abc.AdaptContext | None = None):
-        super().__init__(oid, context)
-        self._tx = Transformer(context)
-
-    def _parse_record(self, data: abc.Buffer) -> Iterator[bytes | None]:
-        """
-        Split a non-empty representation of a composite type into components.
-
-        Terminators shouldn't be used in `!data` (so that both record and range
-        representations can be parsed).
-        """
-        for m in self._re_tokenize.finditer(data):
-            if m.group(1):
-                yield None
-            elif m.group(2) is not None:
-                yield self._re_undouble.sub(rb"\1", m.group(2))
-            else:
-                yield m.group(3)
-
-        # If the final group ended in `,` there is a final NULL in the record
-        # that the regexp couldn't parse.
-        if m and m.group().endswith(b","):
-            yield None
-
-    _re_tokenize = re.compile(
-        rb"""(?x)
-          (,)                       # an empty token, representing NULL
-        | " ((?: [^"] | "")*) " ,?  # or a quoted string
-        | ([^",)]+) ,?              # or an unquoted string
-        """
-    )
-
-    _re_undouble = re.compile(rb'(["\\])\1')
+class RecordLoader(RecursiveLoader):
+    """
+    Load a `record` field from PostgreSQL.
 
+    In text mode we don't have type information of the composite's fields, so
+    convert every item as text. Note that in binary loading we have per-field
+    oids instead.
+    """
 
-class RecordLoader(BaseCompositeLoader):
     def load(self, data: abc.Buffer) -> tuple[Any, ...]:
         if data == b"()":
             return ()
 
         cast = self._tx.get_loader(TEXT_OID, self.format).load
-        return tuple(
-            cast(token) if token is not None else None
-            for token in self._parse_record(data[1:-1])
-        )
+        record = _parse_text_record(data[1:-1])
+        for i in range(len(record)):
+            if (f := record[i]) is not None:
+                record[i] = cast(f)
+
+        return tuple(record)
 
 
 class RecordBinaryLoader(Loader):
+    """
+    Load a `record` field from PostgreSQL.
+
+    Unlike in text mode, the composite data contains oids of the fields,
+    so we can actually parse the records in its original types.
+    """
+
     format = pq.Format.BINARY
 
     def __init__(self, oid: int, context: abc.AdaptContext | None = None):
@@ -216,20 +198,7 @@ class RecordBinaryLoader(Loader):
         self._txs: dict[tuple[int, ...], abc.Transformer] = {}
 
     def load(self, data: abc.Buffer) -> tuple[Any, ...]:
-        nfields = unpack_len(data, 0)[0]
-        offset = 4
-        oids = []
-        record: list[Buffer | None] = []
-        for _ in range(nfields):
-            oid, length = _unpack_oidlen(data, offset)
-            offset += 8
-            oids.append(oid)
-            if length >= 0:
-                record.append(data[offset : offset + length])
-                offset += length
-            else:
-                record.append(None)
-
+        record, oids = _parse_binary_record(data)
         key = tuple(oids)
         try:
             tx = self._txs[key]
@@ -240,34 +209,62 @@ class RecordBinaryLoader(Loader):
         return tx.load_sequence(record)
 
 
-class CompositeLoader(RecordLoader):
+class _CompositeLoader(Loader):
+    """
+    Base class to create text loaders of specific composite types.
+
+    The class is complete but lack information about the fields types and
+    object factory. These will be added by register_composite(), which will
+    create a subclass of this class.
+    """
+
     factory: Callable[..., Any]
     fields_types: list[int]
-    _types_set = False
 
-    def load(self, data: abc.Buffer) -> Any:
-        if not self._types_set:
-            self._config_types(data)
-            self._types_set = True
+    def __init__(self, oid: int, context: abc.AdaptContext | None = None):
+        super().__init__(oid, context)
+        # Note: we cannot use the RecursiveLoader base class here because we
+        # always want a different Transformer instance, otherwise the types
+        # loaded will conflict with the types loaded by the record.
+        self._tx = Transformer(context)
+        self._tx.set_loader_types(self.fields_types, self.format)
 
+    def load(self, data: abc.Buffer) -> Any:
+        # Use `type(self).factory` instead of `self.factory` because, if
+        # `factory` is a function, `self.factory` will become bound and the
+        # first argument passed will become `self`.
         if data == b"()":
             return type(self).factory()
 
         return type(self).factory(
-            *self._tx.load_sequence(tuple(self._parse_record(data[1:-1])))
+            *self._tx.load_sequence(_parse_text_record(data[1:-1]))
         )
 
-    def _config_types(self, data: abc.Buffer) -> None:
-        self._tx.set_loader_types(self.fields_types, self.format)
 
+class _CompositeBinaryLoader(Loader):
+    """
+    Base class to create text loaders of specific composite types.
+
+    The class is complete but lack information about the fields types and
+    object factory. These will be added by register_composite(), which will
+    create a subclass of this class.
+    """
 
-class CompositeBinaryLoader(RecordBinaryLoader):
     format = pq.Format.BINARY
     factory: Callable[..., Any]
+    fields_types: list[int]
+
+    def __init__(self, oid: int, context: abc.AdaptContext | None = None):
+        super().__init__(oid, context)
+        self._tx = Transformer(context)
+        self._tx.set_loader_types(self.fields_types, self.format)
 
     def load(self, data: abc.Buffer) -> Any:
-        r = super().load(data)
-        return type(self).factory(*r)
+        record, _ = _parse_binary_record(data)  # assume oids == self.fields_types
+        # Use `type(self).factory` instead of `self.factory` because, if
+        # `factory` is a function, `self.factory` will become bound and the
+        # first argument passed will become `self`.
+        return type(self).factory(*self._tx.load_sequence(record))
 
 
 def register_composite(
@@ -305,12 +302,12 @@ def register_composite(
     adapters = context.adapters if context else postgres.adapters
 
     # generate and register a customized text loader
-    loader: type[BaseCompositeLoader]
+    loader: type[Loader]
     loader = _make_loader(info.name, tuple(info.field_types), factory)
     adapters.register_loader(info.oid, loader)
 
     # generate and register a customized binary loader
-    loader = _make_binary_loader(info.name, factory)
+    loader = _make_binary_loader(info.name, tuple(info.field_types), factory)
     adapters.register_loader(info.oid, loader)
 
     # If the factory is a type, create and register dumpers for it
@@ -339,6 +336,65 @@ def _nt_from_info(info: CompositeInfo) -> type[NamedTuple]:
     return _make_nt(name, fields)
 
 
+def _parse_text_record(data: abc.Buffer) -> list[bytes | None]:
+    """
+    Split a non-empty representation of a composite type into components.
+
+    Terminators shouldn't be used in `!data` (so that both record and range
+    representations can be parsed).
+    """
+    record: list[bytes | None] = []
+    for m in _re_tokenize.finditer(data):
+        if m.group(1):
+            record.append(None)
+        elif m.group(2) is not None:
+            record.append(_re_undouble.sub(rb"\1", m.group(2)))
+        else:
+            record.append(m.group(3))
+
+    # If the final group ended in `,` there is a final NULL in the record
+    # that the regexp couldn't parse.
+    if m and m.group().endswith(b","):
+        record.append(None)
+
+    return record
+
+
+_re_tokenize = re.compile(
+    rb"""(?x)
+      (,)                       # an empty token, representing NULL
+    | " ((?: [^"] | "")*) " ,?  # or a quoted string
+    | ([^",)]+) ,?              # or an unquoted string
+    """
+)
+_re_undouble = re.compile(rb'(["\\])\1')
+
+
+def _parse_binary_record(data: abc.Buffer) -> tuple[list[Buffer | None], list[int]]:
+    """
+    Parse the binary representation of a composite type.
+
+    Return the sequence of fields and oids found in the type. The fields
+    are returned as buffer: they will need a Transformer to be converted
+    to Python types.
+    """
+    nfields = unpack_len(data, 0)[0]
+    offset = 4
+    oids = []
+    record: list[Buffer | None] = []
+    for _ in range(nfields):
+        oid, length = _unpack_oidlen(data, offset)
+        offset += 8
+        oids.append(oid)
+        if length >= 0:
+            record.append(data[offset : offset + length])
+            offset += length
+        else:
+            record.append(None)
+
+    return record, oids
+
+
 # Cache all dynamically-generated types to avoid leaks in case the types
 # cannot be GC'd.
 
@@ -351,20 +407,24 @@ def _make_nt(name: str, fields: tuple[str, ...]) -> type[NamedTuple]:
 @cache
 def _make_loader(
     name: str, types: tuple[int, ...], factory: Callable[..., Any]
-) -> type[BaseCompositeLoader]:
+) -> type[_CompositeLoader]:
+    doc = f"Text loader for the '{name}' composite."
     return type(
         f"{name.title()}Loader",
-        (CompositeLoader,),
-        {"factory": factory, "fields_types": list(types)},
+        (_CompositeLoader,),
+        {"__doc__": doc, "factory": factory, "fields_types": list(types)},
     )
 
 
 @cache
 def _make_binary_loader(
-    name: str, factory: Callable[..., Any]
-) -> type[BaseCompositeLoader]:
+    name: str, types: tuple[int, ...], factory: Callable[..., Any]
+) -> type[_CompositeBinaryLoader]:
+    doc = f"Binary loader for the '{name}' composite."
     return type(
-        f"{name.title()}BinaryLoader", (CompositeBinaryLoader,), {"factory": factory}
+        f"{name.title()}BinaryLoader",
+        (_CompositeBinaryLoader,),
+        {"__doc__": doc, "factory": factory, "fields_types": list(types)},
     )