]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
subqueryload invokes compile() on _OverrideBinds - do robust replace of bp
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 19 Mar 2024 17:35:35 +0000 (13:35 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Mar 2024 13:21:22 +0000 (09:21 -0400)
Fixed regression from version 2.0.28 caused by the fix for :ticket:`11085`
where the newer method of adjusting post-cache bound parameter values would
interefere with the implementation for the :func:`_orm.subqueryload` loader
option, which has some more legacy patterns in use internally, when
the additional loader criteria feature were used with this loader option.

Fixes: #11173
Change-Id: I88982fbcc809d516eb7c46a00fb807aab9c3a98e
(cherry picked from commit b6f63a57ed878c1e157ecf86cb35d8b15cd7ea3b)

doc/build/changelog/unreleased_20/11173.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/orm/test_relationship_criteria.py
test/orm/test_subquery_relations.py

diff --git a/doc/build/changelog/unreleased_20/11173.rst b/doc/build/changelog/unreleased_20/11173.rst
new file mode 100644 (file)
index 0000000..900c614
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 11173
+
+    Fixed regression from version 2.0.28 caused by the fix for :ticket:`11085`
+    where the newer method of adjusting post-cache bound parameter values would
+    interefere with the implementation for the :func:`_orm.subqueryload` loader
+    option, which has some more legacy patterns in use internally, when
+    the additional loader criteria feature were used with this loader option.
index 813d3fa0a0556c67ca9ab45ecbff073206f04768..c354ba8386443820787202160bbb9ea8ee8bdb10 100644 (file)
@@ -2360,17 +2360,18 @@ class SQLCompiler(Compiled):
         the compilation was already performed, and only the bound params should
         be swapped in at execution time.
 
-        However, the test suite has some tests that exercise compilation
-        on individual elements without using the cache key version, so here we
-        modify the bound parameter collection for the given compiler based on
-        the translation.
+        However, there are test cases that exericise this object, and
+        additionally the ORM subquery loader is known to feed in expressions
+        which include this construct into new queries (discovered in #11173),
+        so it has to do the right thing at compile time as well.
 
         """
 
         # get SQL text first
         sqltext = override_binds.element._compiler_dispatch(self, **kw)
 
-        # then change binds after the fact.  note that we don't try to
+        # for a test compile that is not for caching, change binds after the
+        # fact.  note that we don't try to
         # swap the bindparam as we compile, because our element may be
         # elsewhere in the statement already (e.g. a subquery or perhaps a
         # CTE) and was already visited / compiled. See
@@ -2381,14 +2382,36 @@ class SQLCompiler(Compiled):
                 continue
             bp = self.binds[k]
 
+            # so this would work, just change the value of bp in place.
+            # but we dont want to mutate things outside.
+            # bp.value = override_binds.translate[bp.key]
+            # continue
+
+            # instead, need to replace bp with new_bp or otherwise accommodate
+            # in all internal collections
             new_bp = bp._with_value(
                 override_binds.translate[bp.key],
                 maintain_key=True,
                 required=False,
             )
+
             name = self.bind_names[bp]
             self.binds[k] = self.binds[name] = new_bp
             self.bind_names[new_bp] = name
+            self.bind_names.pop(bp, None)
+
+            if bp in self.post_compile_params:
+                self.post_compile_params |= {new_bp}
+            if bp in self.literal_execute_params:
+                self.literal_execute_params |= {new_bp}
+
+            ckbm_tuple = self._cache_key_bind_match
+            if ckbm_tuple:
+                ckbm, cksm = ckbm_tuple
+                for bp in bp._cloned_set:
+                    if bp.key in cksm:
+                        cb = cksm[bp.key]
+                        ckbm[cb].append(new_bp)
 
         return sqltext
 
index 4add92c1e725f5308d898128cd759b46e15355d7..96c178e5e22ec05aedcbbd3902c646f77b407e3f 100644 (file)
@@ -2068,6 +2068,55 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL):
                 ),
             )
 
+    @testing.combinations(
+        (selectinload,),
+        (subqueryload,),
+        (lazyload,),
+        (joinedload,),
+        argnames="opt",
+    )
+    @testing.variation("use_in", [True, False])
+    def test_opts_local_criteria_cachekey(
+        self, opt, user_address_fixture, use_in
+    ):
+        """test #11173"""
+        User, Address = user_address_fixture
+
+        s = Session(testing.db, future=True)
+
+        def go(value):
+            if use_in:
+                expr = ~Address.email_address.in_([value, "some_email"])
+            else:
+                expr = Address.email_address != value
+            stmt = (
+                select(User)
+                .options(
+                    opt(User.addresses.and_(expr)),
+                )
+                .order_by(User.id)
+            )
+            result = s.execute(stmt)
+            return result
+
+        for value in (
+            "ed@wood.com",
+            "ed@lala.com",
+            "ed@wood.com",
+            "ed@lala.com",
+        ):
+            s.close()
+            result = go(value)
+
+            eq_(
+                result.scalars().unique().all(),
+                (
+                    self._user_minus_edwood(*user_address_fixture)
+                    if value == "ed@wood.com"
+                    else self._user_minus_edlala(*user_address_fixture)
+                ),
+            )
+
     @testing.combinations(
         (joinedload, False),
         (lazyload, True),
index 00564cfb656fd5808d6ec7e5f1631c53a481955c..538c77c0cee338c1220b7c4e803a6a6f7cad4154 100644 (file)
@@ -3759,3 +3759,81 @@ class Issue6149Test(fixtures.DeclarativeMappedTest):
                 ),
             )
             s.close()
+
+
+class Issue11173Test(fixtures.DeclarativeMappedTest):
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        class SubItem(Base):
+            __tablename__ = "sub_items"
+
+            id = Column(Integer, primary_key=True, autoincrement=True)
+            item_id = Column(Integer, ForeignKey("items.id"))
+            name = Column(String(50))
+            number = Column(Integer)
+
+        class Item(Base):
+            __tablename__ = "items"
+
+            id = Column(Integer, primary_key=True, autoincrement=True)
+            name = Column(String(50))
+            number = Column(Integer)
+            sub_items = relationship("SubItem", backref="item")
+
+    @classmethod
+    def insert_data(cls, connection):
+        Item, SubItem = cls.classes("Item", "SubItem")
+
+        with Session(connection) as sess:
+            number_of_items = 50
+            number_of_sub_items = 5
+
+            items = [
+                Item(name=f"Item:{i}", number=i)
+                for i in range(number_of_items)
+            ]
+            sess.add_all(items)
+            for item in items:
+                item.sub_items = [
+                    SubItem(name=f"SubItem:{item.id}:{i}", number=i)
+                    for i in range(number_of_sub_items)
+                ]
+            sess.commit()
+
+    @testing.variation("use_in", [True, False])
+    def test_multiple_queries(self, use_in):
+        Item, SubItem = self.classes("Item", "SubItem")
+
+        for sub_item_number in (1, 2, 3):
+            s = fixture_session()
+            base_query = s.query(Item)
+
+            base_query = base_query.filter(Item.number > 5, Item.number <= 10)
+
+            if use_in:
+                base_query = base_query.options(
+                    subqueryload(
+                        Item.sub_items.and_(
+                            SubItem.number.in_([sub_item_number, 18, 12])
+                        )
+                    )
+                )
+            else:
+                base_query = base_query.options(
+                    subqueryload(
+                        Item.sub_items.and_(SubItem.number == sub_item_number)
+                    )
+                )
+
+            items = list(base_query)
+
+            eq_(len(items), 5)
+
+            for item in items:
+                sub_items = list(item.sub_items)
+                eq_(len(sub_items), 1)
+
+                for sub_item in sub_items:
+                    eq_(sub_item.number, sub_item_number)