]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
expand column options for composites up front at the attribute level
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 May 2025 17:39:36 +0000 (13:39 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 May 2025 19:07:23 +0000 (15:07 -0400)
Implemented the :func:`_orm.defer`, :func:`_orm.undefer` and
:func:`_orm.load_only` loader options to work for composite attributes, a
use case that had never been supported previously.

Fixes: #12593
Change-Id: Ie7892a710f30b69c83f586f7492174a3b8198f80

doc/build/changelog/unreleased_20/12593.rst [new file with mode: 0644]
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/strategy_options.py
test/orm/test_composites.py

diff --git a/doc/build/changelog/unreleased_20/12593.rst b/doc/build/changelog/unreleased_20/12593.rst
new file mode 100644 (file)
index 0000000..945e0d6
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 12593
+
+    Implemented the :func:`_orm.defer`, :func:`_orm.undefer` and
+    :func:`_orm.load_only` loader options to work for composite attributes, a
+    use case that had never been supported previously.
index 1722de484859c03ce7f9a96aa419177b20de891c..952140575df697eea098125aad24b4a1653ffa6c 100644 (file)
@@ -463,6 +463,9 @@ class QueryableAttribute(
     ) -> bool:
         return self.impl.hasparent(state, optimistic=optimistic) is not False
 
+    def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]:
+        return (self,)
+
     def __getattr__(self, key: str) -> Any:
         try:
             return util.MemoizedSlots.__getattr__(self, key)
@@ -596,7 +599,7 @@ def _create_proxied_attribute(
     # TODO: can move this to descriptor_props if the need for this
     # function is removed from ext/hybrid.py
 
-    class Proxy(QueryableAttribute[Any]):
+    class Proxy(QueryableAttribute[_T_co]):
         """Presents the :class:`.QueryableAttribute` interface as a
         proxy on top of a Python descriptor / :class:`.PropComparator`
         combination.
@@ -611,13 +614,13 @@ def _create_proxied_attribute(
 
         def __init__(
             self,
-            class_,
-            key,
-            descriptor,
-            comparator,
-            adapt_to_entity=None,
-            doc=None,
-            original_property=None,
+            class_: _ExternalEntityType[Any],
+            key: str,
+            descriptor: Any,
+            comparator: interfaces.PropComparator[_T_co],
+            adapt_to_entity: Optional[AliasedInsp[Any]] = None,
+            doc: Optional[str] = None,
+            original_property: Optional[QueryableAttribute[_T_co]] = None,
         ):
             self.class_ = class_
             self.key = key
@@ -642,6 +645,13 @@ def _create_proxied_attribute(
             ("_parententity", visitors.ExtendedInternalTraversal.dp_multi),
         ]
 
+        def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]:
+            prop = self.original_property
+            if prop is None:
+                return ()
+            else:
+                return prop._column_strategy_attrs()
+
         @property
         def _impl_uses_objects(self):
             return (
index 6842cd149a435fd71e6457db82820bec8af486d4..d5f7bcc876413e728e9dc8c82d7ccf3ed182b8b8 100644 (file)
@@ -104,6 +104,11 @@ class DescriptorProperty(MapperProperty[_T]):
 
     descriptor: DescriptorReference[Any]
 
+    def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]:
+        raise NotImplementedError(
+            "This MapperProperty does not implement column loader strategies"
+        )
+
     def get_history(
         self,
         state: InstanceState[Any],
@@ -509,6 +514,9 @@ class CompositeProperty(
             props.append(prop)
         return props
 
+    def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]:
+        return self._comparable_elements
+
     @util.non_memoized_property
     @util.preload_module("orm.properties")
     def columns(self) -> Sequence[Column[Any]]:
@@ -1008,6 +1016,9 @@ class SynonymProperty(DescriptorProperty[_T]):
             )
         return attr.property
 
+    def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]:
+        return (getattr(self.parent.class_, self.name),)
+
     def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]:
         prop = self._proxied_object
 
index c2a44e899e8cd487be6e37a6cd02264d4a057d3f..d41eaec0b2b9c0aff23898c9e001f28360a3ffb8 100644 (file)
@@ -6,9 +6,7 @@
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: allow-untyped-defs, allow-untyped-calls
 
-"""
-
-"""
+""" """
 
 from __future__ import annotations
 
@@ -224,7 +222,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
 
         """
         cloned = self._set_column_strategy(
-            attrs,
+            _expand_column_strategy_attrs(attrs),
             {"deferred": False, "instrument": True},
         )
 
@@ -637,7 +635,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         strategy = {"deferred": True, "instrument": True}
         if raiseload:
             strategy["raiseload"] = True
-        return self._set_column_strategy((key,), strategy)
+        return self._set_column_strategy(
+            _expand_column_strategy_attrs((key,)), strategy
+        )
 
     def undefer(self, key: _AttrType) -> Self:
         r"""Indicate that the given column-oriented attribute should be
@@ -676,7 +676,8 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
 
         """  # noqa: E501
         return self._set_column_strategy(
-            (key,), {"deferred": False, "instrument": True}
+            _expand_column_strategy_attrs((key,)),
+            {"deferred": False, "instrument": True},
         )
 
     def undefer_group(self, name: str) -> Self:
@@ -2387,6 +2388,23 @@ See :func:`_orm.{fn.__name__}` for usage examples.
     return fn
 
 
+def _expand_column_strategy_attrs(
+    attrs: Tuple[_AttrType, ...],
+) -> Tuple[_AttrType, ...]:
+    return cast(
+        "Tuple[_AttrType, ...]",
+        tuple(
+            a
+            for attr in attrs
+            for a in (
+                cast("QueryableAttribute[Any]", attr)._column_strategy_attrs()
+                if hasattr(attr, "_column_strategy_attrs")
+                else (attr,)
+            )
+        ),
+    )
+
+
 # standalone functions follow.  docstrings are filled in
 # by the ``@loader_unbound_fn`` decorator.
 
@@ -2400,6 +2418,7 @@ def contains_eager(*keys: _AttrType, **kw: Any) -> _AbstractLoad:
 def load_only(*attrs: _AttrType, raiseload: bool = False) -> _AbstractLoad:
     # TODO: attrs against different classes.  we likely have to
     # add some extra state to Load of some kind
+    attrs = _expand_column_strategy_attrs(attrs)
     _, lead_element, _ = _parse_attr_argument(attrs[0])
     return Load(lead_element).load_only(*attrs, raiseload=raiseload)
 
index f9a1ba386595684e404f5b0796c57f8b9ae83e37..cd205be5b48e1ff6b01397bfa34384882cf2beb5 100644 (file)
@@ -16,9 +16,13 @@ from sqlalchemy.orm import aliased
 from sqlalchemy.orm import Composite
 from sqlalchemy.orm import composite
 from sqlalchemy.orm import configure_mappers
+from sqlalchemy.orm import defer
+from sqlalchemy.orm import load_only
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
+from sqlalchemy.orm import undefer
+from sqlalchemy.orm import undefer_group
 from sqlalchemy.orm.attributes import LoaderCallableStatus
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
@@ -1470,7 +1474,7 @@ class ManyToOneTest(fixtures.MappedTest):
         eq_(sess.query(ae).filter(ae.c == C("a2b1", b2)).one(), a2)
 
 
-class ConfigurationTest(fixtures.MappedTest):
+class ConfigAndDeferralTest(fixtures.MappedTest):
     @classmethod
     def define_tables(cls, metadata):
         Table(
@@ -1508,7 +1512,7 @@ class ConfigurationTest(fixtures.MappedTest):
         class Edge(cls.Comparable):
             pass
 
-    def _test_roundtrip(self):
+    def _test_roundtrip(self, *, assert_deferred=False, options=()):
         Edge, Point = self.classes.Edge, self.classes.Point
 
         e1 = Edge(start=Point(3, 4), end=Point(5, 6))
@@ -1516,7 +1520,19 @@ class ConfigurationTest(fixtures.MappedTest):
         sess.add(e1)
         sess.commit()
 
-        eq_(sess.query(Edge).one(), Edge(start=Point(3, 4), end=Point(5, 6)))
+        stmt = select(Edge)
+        if options:
+            stmt = stmt.options(*options)
+        e1 = sess.execute(stmt).scalar_one()
+
+        names = ["start", "end", "x1", "x2", "y1", "y2"]
+        for name in names:
+            if assert_deferred:
+                assert name not in e1.__dict__
+            else:
+                assert name in e1.__dict__
+
+        eq_(e1, Edge(start=Point(3, 4), end=Point(5, 6)))
 
     def test_columns(self):
         edge, Edge, Point = (
@@ -1562,7 +1578,7 @@ class ConfigurationTest(fixtures.MappedTest):
 
         self._test_roundtrip()
 
-    def test_deferred(self):
+    def test_deferred_config(self):
         edge, Edge, Point = (
             self.tables.edge,
             self.classes.Edge,
@@ -1580,7 +1596,121 @@ class ConfigurationTest(fixtures.MappedTest):
                 ),
             },
         )
-        self._test_roundtrip()
+        self._test_roundtrip(assert_deferred=True)
+
+    def test_defer_option_on_cols(self):
+        edge, Edge, Point = (
+            self.tables.edge,
+            self.classes.Edge,
+            self.classes.Point,
+        )
+        self.mapper_registry.map_imperatively(
+            Edge,
+            edge,
+            properties={
+                "start": sa.orm.composite(
+                    Point,
+                    edge.c.x1,
+                    edge.c.y1,
+                ),
+                "end": sa.orm.composite(
+                    Point,
+                    edge.c.x2,
+                    edge.c.y2,
+                ),
+            },
+        )
+        self._test_roundtrip(
+            assert_deferred=True,
+            options=(
+                defer(Edge.x1),
+                defer(Edge.x2),
+                defer(Edge.y1),
+                defer(Edge.y2),
+            ),
+        )
+
+    def test_defer_option_on_composite(self):
+        edge, Edge, Point = (
+            self.tables.edge,
+            self.classes.Edge,
+            self.classes.Point,
+        )
+        self.mapper_registry.map_imperatively(
+            Edge,
+            edge,
+            properties={
+                "start": sa.orm.composite(
+                    Point,
+                    edge.c.x1,
+                    edge.c.y1,
+                ),
+                "end": sa.orm.composite(
+                    Point,
+                    edge.c.x2,
+                    edge.c.y2,
+                ),
+            },
+        )
+        self._test_roundtrip(
+            assert_deferred=True, options=(defer(Edge.start), defer(Edge.end))
+        )
+
+    @testing.variation("composite_only", [True, False])
+    def test_load_only_option_on_composite(self, composite_only):
+        edge, Edge, Point = (
+            self.tables.edge,
+            self.classes.Edge,
+            self.classes.Point,
+        )
+        self.mapper_registry.map_imperatively(
+            Edge,
+            edge,
+            properties={
+                "start": sa.orm.composite(
+                    Point, edge.c.x1, edge.c.y1, deferred=True
+                ),
+                "end": sa.orm.composite(
+                    Point,
+                    edge.c.x2,
+                    edge.c.y2,
+                ),
+            },
+        )
+
+        if composite_only:
+            self._test_roundtrip(
+                assert_deferred=False,
+                options=(load_only(Edge.start, Edge.end),),
+            )
+        else:
+            self._test_roundtrip(
+                assert_deferred=False,
+                options=(load_only(Edge.start, Edge.x2, Edge.y2),),
+            )
+
+    def test_defer_option_on_composite_via_group(self):
+        edge, Edge, Point = (
+            self.tables.edge,
+            self.classes.Edge,
+            self.classes.Point,
+        )
+        self.mapper_registry.map_imperatively(
+            Edge,
+            edge,
+            properties={
+                "start": sa.orm.composite(
+                    Point, edge.c.x1, edge.c.y1, deferred=True, group="s"
+                ),
+                "end": sa.orm.composite(
+                    Point, edge.c.x2, edge.c.y2, deferred=True
+                ),
+            },
+        )
+        self._test_roundtrip(
+            assert_deferred=False,
+            options=(undefer_group("s"), undefer(Edge.end)),
+        )
 
     def test_check_prop_type(self):
         edge, Edge, Point = (