]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use _extra_criteria to store with_expression() expression
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 31 Oct 2023 18:56:37 +0000 (14:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 31 Oct 2023 22:48:07 +0000 (18:48 -0400)
this is an alternate version of the first patch, which adds extra
handling for the "expression" in local_opts.  this patch has
with_expression() use _extra_criteria directly, as this attribute
is currently unpurposed for column-based attributes.

Fixed caching bug where using the :func:`_orm.with_expression` construct in
conjunction with loader options :func:`_orm.selectinload`,
:func:`_orm.lazyload` would fail to substitute bound parameter values
correctly on subsequent caching runs.

Fixes: #10570
Change-Id: If6c74755580fe5b108056eebcb461d984410ff46

doc/build/changelog/unreleased_20/10570.rst [new file with mode: 0644]
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/sql/cache_key.py
test/orm/test_cache_key.py

diff --git a/doc/build/changelog/unreleased_20/10570.rst b/doc/build/changelog/unreleased_20/10570.rst
new file mode 100644 (file)
index 0000000..0043e08
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 10570
+
+    Fixed caching bug where using the :func:`_orm.with_expression` construct in
+    conjunction with loader options :func:`_orm.selectinload`,
+    :func:`_orm.lazyload` would fail to substitute bound parameter values
+    correctly on subsequent caching runs.
index a0e092988386c95630f1610bbc201e3c89724884..1e58f4091a68422d7bb85f475d9c129edf07b2ba 100644 (file)
@@ -306,8 +306,9 @@ class ExpressionColumnLoader(ColumnLoader):
         **kwargs,
     ):
         columns = None
-        if loadopt and "expression" in loadopt.local_opts:
-            columns = [loadopt.local_opts["expression"]]
+        if loadopt and loadopt._extra_criteria:
+            columns = loadopt._extra_criteria
+
         elif self._have_default_expression:
             columns = self.parent_property.columns
 
@@ -343,8 +344,8 @@ class ExpressionColumnLoader(ColumnLoader):
     ):
         # look through list of columns represented here
         # to see which, if any, is present in the row.
-        if loadopt and "expression" in loadopt.local_opts:
-            columns = [loadopt.local_opts["expression"]]
+        if loadopt and loadopt._extra_criteria:
+            columns = loadopt._extra_criteria
 
             for col in columns:
                 if adapter:
index c9ccd9cde92cd7afc1a8c21517885956de01329e..6c81e8fe7370e48c512c69641dd46cc40f010c2d 100644 (file)
@@ -747,7 +747,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         )
 
         return self._set_column_strategy(
-            (key,), {"query_expression": True}, opts={"expression": expression}
+            (key,), {"query_expression": True}, extra_criteria=(expression,)
         )
 
     def selectin_polymorphic(self, classes: Iterable[Type[Any]]) -> Self:
@@ -819,6 +819,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         attrs: Tuple[_AttrType, ...],
         strategy: Optional[_StrategySpec],
         opts: Optional[_OptsType] = None,
+        extra_criteria: Optional[Tuple[Any, ...]] = None,
     ) -> Self:
         strategy_key = self._coerce_strat(strategy)
 
@@ -828,6 +829,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
             _COLUMN_TOKEN,
             opts=opts,
             attr_group=attrs,
+            extra_criteria=extra_criteria,
         )
         return self
 
@@ -884,6 +886,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption):
         attr_group: Optional[_AttrGroupType] = None,
         propagate_to_loaders: bool = True,
         reconcile_to_other: Optional[bool] = None,
+        extra_criteria: Optional[Tuple[Any, ...]] = None,
     ) -> Self:
         raise NotImplementedError()
 
@@ -1052,16 +1055,10 @@ class Load(_AbstractLoad):
         found_crit = False
 
         def process(opt: _LoadElement) -> _LoadElement:
-            if not opt._extra_criteria:
-                return opt
-
             nonlocal orig_cache_key, replacement_cache_key, found_crit
 
             found_crit = True
 
-            # avoid generating cache keys for the queries if we don't
-            # actually have any extra_criteria options, which is the
-            # common case
             if orig_cache_key is None or replacement_cache_key is None:
                 orig_cache_key = orig_query._generate_cache_key()
                 replacement_cache_key = context.query._generate_cache_key()
@@ -1075,8 +1072,12 @@ class Load(_AbstractLoad):
                 )
                 for crit in opt._extra_criteria
             )
+
             return opt
 
+        # avoid generating cache keys for the queries if we don't
+        # actually have any extra_criteria options, which is the
+        # common case
         new_context = tuple(
             process(value._clone()) if value._extra_criteria else value
             for value in self.context
@@ -1220,6 +1221,7 @@ class Load(_AbstractLoad):
         attr_group: Optional[_AttrGroupType] = None,
         propagate_to_loaders: bool = True,
         reconcile_to_other: Optional[bool] = None,
+        extra_criteria: Optional[Tuple[Any, ...]] = None,
     ) -> Self:
         # for individual strategy that needs to propagate, set the whole
         # Load container to also propagate, so that it shows up in
@@ -1253,6 +1255,7 @@ class Load(_AbstractLoad):
                 propagate_to_loaders,
                 attr_group=attr_group,
                 reconcile_to_other=reconcile_to_other,
+                extra_criteria=extra_criteria,
             )
             if load_element:
                 self.context += (load_element,)
@@ -1273,6 +1276,7 @@ class Load(_AbstractLoad):
                         propagate_to_loaders,
                         attr_group=attr_group,
                         reconcile_to_other=reconcile_to_other,
+                        extra_criteria=extra_criteria,
                     )
                 else:
                     load_element = _AttributeStrategyLoad.create(
@@ -1284,6 +1288,7 @@ class Load(_AbstractLoad):
                         propagate_to_loaders,
                         attr_group=attr_group,
                         reconcile_to_other=reconcile_to_other,
+                        extra_criteria=extra_criteria,
                     )
 
                 if load_element:
@@ -1347,6 +1352,7 @@ class _WildcardLoad(_AbstractLoad):
         attr_group=None,
         propagate_to_loaders=True,
         reconcile_to_other=None,
+        extra_criteria=None,
     ):
         assert attrs is not None
         attr = attrs[0]
@@ -1363,6 +1369,8 @@ class _WildcardLoad(_AbstractLoad):
         if opts:
             self.local_opts = util.immutabledict(opts)
 
+        assert extra_criteria is None
+
     def options(self, *opts: _AbstractLoad) -> Self:
         raise NotImplementedError("Star option does not support sub-options")
 
@@ -1649,7 +1657,9 @@ class _LoadElement(
 
         return effective_path
 
-    def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr):
+    def _init_path(
+        self, path, attr, wildcard_key, attr_group, raiseerr, extra_criteria
+    ):
         """Apply ORM attributes and/or wildcard to an existing path, producing
         a new path.
 
@@ -1709,6 +1719,7 @@ class _LoadElement(
         raiseerr: bool = True,
         attr_group: Optional[_AttrGroupType] = None,
         reconcile_to_other: Optional[bool] = None,
+        extra_criteria: Optional[Tuple[Any, ...]] = None,
     ) -> _LoadElement:
         """Create a new :class:`._LoadElement` object."""
 
@@ -1728,7 +1739,9 @@ class _LoadElement(
         else:
             opt._reconcile_to_other = None
 
-        path = opt._init_path(path, attr, wildcard_key, attr_group, raiseerr)
+        path = opt._init_path(
+            path, attr, wildcard_key, attr_group, raiseerr, extra_criteria
+        )
 
         if not path:
             return None  # type: ignore
@@ -1828,7 +1841,7 @@ class _LoadElement(
             replacement.local_opts = replacement.local_opts.union(
                 existing.local_opts
             )
-            replacement._extra_criteria += replacement._extra_criteria
+            replacement._extra_criteria += existing._extra_criteria
             return replacement
         elif replacement.path.is_token:
             # use 'last one wins' logic for wildcard options.  this is also
@@ -1870,7 +1883,9 @@ class _AttributeStrategyLoad(_LoadElement):
     is_class_strategy = False
     is_token_strategy = False
 
-    def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr):
+    def _init_path(
+        self, path, attr, wildcard_key, attr_group, raiseerr, extra_criteria
+    ):
         assert attr is not None
         self._of_type = None
         self._path_with_polymorphic_path = None
@@ -1916,7 +1931,11 @@ class _AttributeStrategyLoad(_LoadElement):
         # from an attribute.   This appears to have been an artifact of how
         # _UnboundLoad / Load interacted together, which was opaque and
         # poorly defined.
-        self._extra_criteria = attr._extra_criteria
+        if extra_criteria:
+            assert not attr._extra_criteria
+            self._extra_criteria = extra_criteria
+        else:
+            self._extra_criteria = attr._extra_criteria
 
         if getattr(attr, "_of_type", None):
             ac = attr._of_type
@@ -2109,7 +2128,9 @@ class _TokenStrategyLoad(_LoadElement):
     is_class_strategy = False
     is_token_strategy = True
 
-    def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr):
+    def _init_path(
+        self, path, attr, wildcard_key, attr_group, raiseerr, extra_criteria
+    ):
         # assert isinstance(attr, str) or attr is None
         if attr is not None:
             default_token = attr.endswith(_DEFAULT_TOKEN)
@@ -2195,7 +2216,9 @@ class _ClassStrategyLoad(_LoadElement):
 
     __visit_name__ = "class_strategy_load_element"
 
-    def _init_path(self, path, attr, wildcard_key, attr_group, raiseerr):
+    def _init_path(
+        self, path, attr, wildcard_key, attr_group, raiseerr, extra_criteria
+    ):
         return path
 
     def _prepare_for_compile_state(
index 8c21be1b414f6e7c3bb9545507200bad34847fe8..500e3e4dd72b9fb600d5cc8016b1cc5f81835bfd 100644 (file)
@@ -546,6 +546,9 @@ class CacheKey(NamedTuple):
     def _apply_params_to_element(
         self, original_cache_key: CacheKey, target_element: ClauseElement
     ) -> ClauseElement:
+        if target_element._is_immutable:
+            return target_element
+
         translate = {
             k.key: v.value
             for k, v in zip(original_cache_key.bindparams, self.bindparams)
index a54652f4dbed4e4a005e5fcf3cc1497dad08f841..209b6537ec173628dd6e6cbf3cf6e5336ac8e15c 100644 (file)
@@ -3,6 +3,7 @@ import random
 import sqlalchemy as sa
 from sqlalchemy import Column
 from sqlalchemy import column
+from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
@@ -28,6 +29,7 @@ from sqlalchemy.orm import lazyload
 from sqlalchemy.orm import Load
 from sqlalchemy.orm import load_only
 from sqlalchemy.orm import Query
+from sqlalchemy.orm import query_expression
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
@@ -44,6 +46,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import int_within_variance
 from sqlalchemy.testing import ne_
+from sqlalchemy.testing.entities import ComparableMixin
 from sqlalchemy.testing.fixtures import DeclarativeMappedTest
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.util import count_cache_key_tuples
@@ -1171,3 +1174,66 @@ class EmbeddedSubqTest(
                 int_within_variance(29796, total_size(ck), 0.05)
             else:
                 testing.skip_test("python platform not available")
+
+
+class WithExpresionLoaderOptTest(DeclarativeMappedTest):
+    """test #10570"""
+
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class A(ComparableMixin, Base):
+            __tablename__ = "a"
+
+            id = Column(Integer, primary_key=True)
+            data = Column(String(30))
+            bs = relationship("B")
+
+        class B(ComparableMixin, Base):
+            __tablename__ = "b"
+            id = Column(Integer, primary_key=True)
+            a_id = Column(ForeignKey("a.id"))
+            boolean = query_expression()
+            data = Column(String(30))
+
+    @classmethod
+    def insert_data(cls, connection):
+        A, B = cls.classes("A", "B")
+
+        with Session(connection) as s:
+            s.add(A(bs=[B(data="a"), B(data="b"), B(data="c")]))
+            s.commit()
+
+    @testing.combinations(
+        joinedload, lazyload, defaultload, selectinload, subqueryload
+    )
+    def test_from_opt(self, loadopt):
+        A, B = self.classes("A", "B")
+
+        def go(value):
+            with Session(testing.db) as sess:
+                objects = sess.execute(
+                    select(A).options(
+                        loadopt(A.bs).options(
+                            with_expression(B.boolean, B.data == value)
+                        )
+                    )
+                ).scalars()
+                if loadopt is joinedload:
+                    objects = objects.unique()
+                eq_(
+                    objects.all(),
+                    [
+                        A(
+                            bs=[
+                                B(data="a", boolean=value == "a"),
+                                B(data="b", boolean=value == "b"),
+                                B(data="c", boolean=value == "c"),
+                            ]
+                        )
+                    ],
+                )
+
+        go("b")
+        go("c")