]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: support loading composite types with keyword arguments
authorSean Stewart <sean.stewart@mavenclinic.com>
Tue, 15 Oct 2024 21:50:51 +0000 (17:50 -0400)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 30 Oct 2025 14:41:50 +0000 (14:41 +0000)
Changes:

1. Add `KeywordComposite*` classes to support loading composite types with keyword arguments.
2. Some optimizations for the `*BinaryLoader`.
    - Use a generator to unpack (oid, value) pairs.
    - Use a membership check instead of `KeyError` to optimize hot loops when looking up transformers.
3. Generics for the `*InstanceLoader` so mypy can better-track output results.

---

Co-authored by Daniele Varrazzo, mostly rebasing the feature on the
refactoring happening in the composite loading in #1175 and applying
some suggested PR changes.

psycopg/psycopg/types/composite.py
tests/types/test_composite.py

index 572deefb798d4a62c1466e14e00804d4f497e845..06943623a7c6ca916cd6c1401a8053c460bc19a3 100644 (file)
@@ -8,7 +8,7 @@ from __future__ import annotations
 
 import re
 import struct
-from typing import TYPE_CHECKING, Any, NamedTuple, cast
+from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar, cast
 from functools import cache
 from collections import namedtuple
 from collections.abc import Callable, Sequence
@@ -30,6 +30,8 @@ _unpack_oidlen = cast(
     Callable[[abc.Buffer, int], "tuple[int, int]"], _struct_oidlen.unpack_from
 )
 
+T = TypeVar("T")
+
 
 class CompositeInfo(TypeInfo):
     """Manage information about a composite type."""
@@ -199,17 +201,23 @@ class RecordBinaryLoader(Loader):
 
     def load(self, data: abc.Buffer) -> tuple[Any, ...]:
         record, oids = _parse_binary_record(data)
-        key = tuple(oids)
-        try:
-            tx = self._txs[key]
-        except KeyError:
-            tx = self._txs[key] = Transformer(self._ctx)
-            tx.set_loader_types(oids, self.format)
+        if not record:
+            return ()
 
+        tx = self._get_transformer(tuple(oids))
         return tx.load_sequence(record)
 
+    def _get_transformer(self, key: tuple[int, ...]) -> abc.Transformer:
+        if key in self._txs:
+            return self._txs[key]
+
+        tx = Transformer(self._ctx)
+        tx.set_loader_types([*key], self.format)
+        self._txs[key] = tx
+        return tx
 
-class _CompositeLoader(Loader):
+
+class _CompositeLoader(Loader, Generic[T]):
     """
     Base class to create text loaders of specific composite types.
 
@@ -218,7 +226,7 @@ class _CompositeLoader(Loader):
     create a subclass of this class.
     """
 
-    factory: Callable[..., Any]
+    factory: Callable[..., T]
     fields_types: list[int]
 
     def __init__(self, oid: int, context: abc.AdaptContext | None = None):
@@ -229,29 +237,45 @@ class _CompositeLoader(Loader):
         self._tx = Transformer(context)
         self._tx.set_loader_types(self.fields_types, self.format)
 
-    def load(self, data: abc.Buffer) -> Any:
-        # Use `type(self).factory` instead of `self.factory` because, if
-        # `factory` is a function, `self.factory` will become bound and the
-        # first argument passed will become `self`.
+    def load(self, data: abc.Buffer) -> T:
         if data == b"()":
-            return type(self).factory()
+            args = ()
+        else:
+            args = self._tx.load_sequence(tuple(_parse_text_record(data[1:-1])))
+        return self._load_instance(args)
+
+    @classmethod
+    def _load_instance(cls, args: Sequence[Any]) -> T:
+        raise NotImplementedError
 
-        return type(self).factory(
-            *self._tx.load_sequence(_parse_text_record(data[1:-1]))
-        )
 
+class _ArgsCompositeLoader(_CompositeLoader[T]):
+
+    @classmethod
+    def _load_instance(cls, args: Sequence[Any]) -> T:
+        return cls.factory(*args)
 
-class _CompositeBinaryLoader(Loader):
+
+class _KwargsCompositeLoader(_CompositeLoader[T]):
+    fields_names: Sequence[str]
+
+    @classmethod
+    def _load_instance(cls, args: Sequence[Any]) -> T:
+        mapped = dict(zip(cls.fields_names, args))
+        return cls.factory(**mapped)
+
+
+class _CompositeBinaryLoader(Loader, Generic[T]):
     """
     Base class to create text loaders of specific composite types.
 
-    The class is complete but lack information about the fields types and
-    object factory. These will be added by register_composite(), which will
+    The class is complete but lack information about the fields types, names,
+    and object factory. These will be added by register_composite(), which will
     create a subclass of this class.
     """
 
     format = pq.Format.BINARY
-    factory: Callable[..., Any]
+    factory: Callable[..., T]
     fields_types: list[int]
 
     def __init__(self, oid: int, context: abc.AdaptContext | None = None):
@@ -259,18 +283,37 @@ class _CompositeBinaryLoader(Loader):
         self._tx = Transformer(context)
         self._tx.set_loader_types(self.fields_types, self.format)
 
-    def load(self, data: abc.Buffer) -> Any:
-        record, _ = _parse_binary_record(data)  # assume oids == self.fields_types
-        # Use `type(self).factory` instead of `self.factory` because, if
-        # `factory` is a function, `self.factory` will become bound and the
-        # first argument passed will become `self`.
-        return type(self).factory(*self._tx.load_sequence(record))
+    def load(self, data: abc.Buffer) -> T:
+        brecord, _ = _parse_binary_record(data)  # assume oids == self.fields_types
+        record = self._tx.load_sequence(brecord)
+        return self._load_instance(record)
+
+    @classmethod
+    def _load_instance(cls, args: Sequence[Any]) -> T:
+        raise NotImplementedError
+
+
+class _ArgsCompositeBinaryLoader(_CompositeBinaryLoader[T]):
+
+    @classmethod
+    def _load_instance(cls, args: Sequence[Any]) -> T:
+        return cls.factory(*args)
+
+
+class _KwargsCompositeBinaryLoader(_CompositeBinaryLoader[T]):
+    fields_names: Sequence[str]
+
+    @classmethod
+    def _load_instance(cls, args: Sequence[Any]) -> T:
+        mapped = dict(zip(cls.fields_names, args))
+        return cls.factory(**mapped)
 
 
 def register_composite(
     info: CompositeInfo,
     context: abc.AdaptContext | None = None,
-    factory: Callable[..., Any] | None = None,
+    factory: Callable[..., T] | None = None,
+    use_keywords: bool = False,
 ) -> None:
     """Register the adapters to load and dump a composite type.
 
@@ -279,6 +322,8 @@ def register_composite(
         register it globally.
     :param factory: Callable to convert the sequence of attributes read from
         the composite into a Python object.
+    :param use_keywords: If `True`, load composite types using field names as keyword
+        arguments.
 
     .. note::
 
@@ -297,23 +342,29 @@ def register_composite(
     info.register(context)
 
     if not factory:
-        factory = _nt_from_info(info)
+        factory = cast("Callable[..., T]", _nt_from_info(info))
 
     adapters = context.adapters if context else postgres.adapters
 
+    field_names = tuple(_as_python_identifier(n) for n in info.field_names)
+    field_types = tuple(info.field_types)
+
     # generate and register a customized text loader
-    loader: type[Loader]
-    loader = _make_loader(info.name, tuple(info.field_types), factory)
+    loader: type[_CompositeLoader[T]] = _make_loader(
+        info.name, factory, field_names, field_types, use_keywords
+    )
     adapters.register_loader(info.oid, loader)
 
     # generate and register a customized binary loader
-    loader = _make_binary_loader(info.name, tuple(info.field_types), factory)
-    adapters.register_loader(info.oid, loader)
+    binary_loader: type[_CompositeBinaryLoader[T]] = _make_binary_loader(
+        info.name, factory, field_names, field_types, use_keywords
+    )
+    adapters.register_loader(info.oid, binary_loader)
 
     # If the factory is a type, create and register dumpers for it
     if isinstance(factory, type):
         dumper: type[Dumper]
-        dumper = _make_binary_dumper(info.name, info.oid, tuple(info.field_types))
+        dumper = _make_binary_dumper(info.name, info.oid, field_types)
         adapters.register_dumper(factory, dumper)
 
         # Default to the text dumper because it is more flexible
@@ -406,26 +457,42 @@ def _make_nt(name: str, fields: tuple[str, ...]) -> type[NamedTuple]:
 
 @cache
 def _make_loader(
-    name: str, types: tuple[int, ...], factory: Callable[..., Any]
-) -> type[_CompositeLoader]:
+    name: str,
+    factory: Callable[..., T],
+    field_names: tuple[str, ...],
+    field_types: tuple[int, ...],
+    use_keywords: bool,
+) -> type[_CompositeLoader[T]]:
     doc = f"Text loader for the '{name}' composite."
-    return type(
-        f"{name.title()}Loader",
-        (_CompositeLoader,),
-        {"__doc__": doc, "factory": factory, "fields_types": list(types)},
-    )
+    base_cls = _KwargsCompositeLoader if use_keywords else _ArgsCompositeLoader
+    d = {
+        "__doc__": doc,
+        "factory": factory,
+        "fields_types": field_types,
+        "fields_names": field_names,
+    }
+    return type(f"{name.title()}Loader", (base_cls,), d)
 
 
 @cache
 def _make_binary_loader(
-    name: str, types: tuple[int, ...], factory: Callable[..., Any]
-) -> type[_CompositeBinaryLoader]:
+    name: str,
+    factory: Callable[..., T],
+    field_names: tuple[str, ...],
+    field_types: tuple[int, ...],
+    use_keywords: bool,
+) -> type[_CompositeBinaryLoader[T]]:
     doc = f"Binary loader for the '{name}' composite."
-    return type(
-        f"{name.title()}BinaryLoader",
-        (_CompositeBinaryLoader,),
-        {"__doc__": doc, "factory": factory, "fields_types": list(types)},
+    base_cls = (
+        _KwargsCompositeBinaryLoader if use_keywords else _ArgsCompositeBinaryLoader
     )
+    d = {
+        "__doc__": doc,
+        "factory": factory,
+        "fields_names": field_names,
+        "fields_types": field_types,
+    }
+    return type(f"{name.title()}BinaryLoader", (base_cls,), d)
 
 
 @cache
index b5dae9e228c1edab2767c4a29450b0575832d571..2276e4aa60f020fc1189a722c569c03a0030a757 100644 (file)
@@ -345,6 +345,29 @@ def test_load_composite_factory(conn, testcomp, fmt_out):
     assert isinstance(res[0].baz, float)
 
 
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_load_keyword_composite_factory(conn, testcomp, fmt_out):
+    info = CompositeInfo.fetch(conn, "testcomp")
+
+    class MyKeywordThing:
+        def __init__(self, *, foo, bar, baz):
+            self.foo, self.bar, self.baz = foo, bar, baz
+
+    register_composite(info, conn, factory=MyKeywordThing, use_keywords=True)
+    assert info.python_type is MyKeywordThing
+
+    cur = conn.cursor(binary=fmt_out)
+    res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
+    assert isinstance(res, MyKeywordThing)
+    assert res.baz == 20.0
+    assert isinstance(res.baz, float)
+
+    res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0]
+    assert len(res) == 1
+    assert res[0].baz == 30.0
+    assert isinstance(res[0].baz, float)
+
+
 def test_register_scope(conn, testcomp):
     info = CompositeInfo.fetch(conn, "testcomp")
     register_composite(info)