]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
track_on needs to be a fixed size, support sub-tuples
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Apr 2021 15:54:52 +0000 (11:54 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 30 Apr 2021 16:35:28 +0000 (12:35 -0400)
Fixed regression in ``selectinload`` loader strategy that would cause it to
cache its internal state incorrectly when handling relationships that join
across more than one column, such as when using a composite foreign key.
The invalid caching would then cause other loader operations to fail.

Fixes: #6410
Change-Id: I9f95ccca3553e7fd5794c619be4cf85c02b04626

doc/build/changelog/unreleased_14/6410.rst [new file with mode: 0644]
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/lambdas.py
test/orm/test_selectin_relations.py
test/sql/test_lambdas.py

diff --git a/doc/build/changelog/unreleased_14/6410.rst b/doc/build/changelog/unreleased_14/6410.rst
new file mode 100644 (file)
index 0000000..be997c6
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, regression, orm
+    :tickets: 6410
+
+    Fixed regression in ``selectinload`` loader strategy that would cause it to
+    cache its internal state incorrectly when handling relationships that join
+    across more than one column, such as when using a composite foreign key.
+    The invalid caching would then cause other loader operations to fail.
+
index 0f68a3fef8f7826dac5df47c5f34fef09354752d..e43fa09a0bd6e20290d4f348bdbc7f485f471c23 100644 (file)
@@ -2865,7 +2865,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots):
             ),
             lambda_cache=self._query_cache,
             global_track_bound_values=False,
-            track_on=(self, effective_entity) + tuple(pk_cols),
+            track_on=(self, effective_entity) + (tuple(pk_cols),),
         )
 
         if not self.parent_property.bake_queries:
index 06db8f95e395ec86988caabe2d090a831500b95e..ddc4774db857dff0abdafa566af2c0bd0373f6a7 100644 (file)
@@ -772,7 +772,16 @@ class AnalyzedCode(object):
         from the "track_on" parameter passed to a :class:`.LambdaElement`.
 
         """
-        if isinstance(elem, traversals.HasCacheKey):
+
+        if isinstance(elem, tuple):
+            # tuple must contain hascachekey elements
+            def get(closure, opts, anon_map, bindparams):
+                return tuple(
+                    tup_elem._gen_cache_key(anon_map, bindparams)
+                    for tup_elem in opts.track_on[idx]
+                )
+
+        elif isinstance(elem, traversals.HasCacheKey):
 
             def get(closure, opts, anon_map, bindparams):
                 return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)
index ec642a71ce5dc32ca80b5927dc19bff184a59642..5ea259da37c1f85ba6cd1b2d0fd38377e0c2984f 100644 (file)
@@ -30,6 +30,7 @@ from sqlalchemy.testing import mock
 from sqlalchemy.testing.assertsql import AllOf
 from sqlalchemy.testing.assertsql import assert_engine
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.testing.fixtures import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -3588,3 +3589,66 @@ class TestBakedCancelsCorrectly(fixtures.DeclarativeMappedTest):
         self.assert_sql_count(testing.db, go, 2)
         self.assert_sql_count(testing.db, go, 2)
         self.assert_sql_count(testing.db, go, 2)
+
+
+class TestCompositePlusNonComposite(fixtures.DeclarativeMappedTest):
+    @classmethod
+    def setup_classes(cls):
+        Base = cls.DeclarativeBasic
+
+        from sqlalchemy.sql import lambdas
+        from sqlalchemy.orm import configure_mappers
+
+        lambdas._closure_per_cache_key.clear()
+        lambdas.AnalyzedCode._fns.clear()
+
+        class A(ComparableEntity, Base):
+            __tablename__ = "a"
+
+            id = Column(Integer, primary_key=True)
+            bs = relationship("B", lazy="selectin")
+
+        class B(ComparableEntity, Base):
+            __tablename__ = "b"
+            id = Column(Integer, primary_key=True)
+            a_id = Column(ForeignKey("a.id"))
+
+        class A2(ComparableEntity, Base):
+            __tablename__ = "a2"
+
+            id = Column(Integer, primary_key=True)
+            id2 = Column(Integer, primary_key=True)
+            bs = relationship("B2", lazy="selectin")
+
+        class B2(ComparableEntity, Base):
+            __tablename__ = "b2"
+            id = Column(Integer, primary_key=True)
+            a_id = Column(Integer)
+            a_id2 = Column(Integer)
+            __table_args__ = (
+                ForeignKeyConstraint(["a_id", "a_id2"], ["a2.id", "a2.id2"]),
+            )
+
+        configure_mappers()
+
+    @classmethod
+    def insert_data(cls, connection):
+        A, B, A2, B2 = cls.classes("A", "B", "A2", "B2")
+        s = Session(connection)
+
+        s.add(A(bs=[B()]))
+        s.add(A2(id=1, id2=1, bs=[B2()]))
+
+        s.commit()
+
+    def test_load_composite_then_non_composite(self):
+
+        A, B, A2, B2 = self.classes("A", "B", "A2", "B2")
+
+        s = fixture_session()
+
+        a2 = s.query(A2).first()
+        a1 = s.query(A).first()
+
+        eq_(a2.bs, [B2()])
+        eq_(a1.bs, [B()])
index 897c60f0039a91afa6c3dc14585177e292ce192e..cdfd92ece579b0d152ec923e9bfd0553bcd87e1f 100644 (file)
@@ -454,7 +454,13 @@ class LambdaElementTest(
             checkparams={"y_1": 18, "p_1": 12},
         )
 
-    def test_stmt_lambda_w_atonce_whereclause_customtrack_binds(self):
+    @testing.combinations(
+        (True,),
+        (False,),
+    )
+    def test_stmt_lambda_w_atonce_whereclause_customtrack_binds(
+        self, use_tuple
+    ):
         c2 = column("y")
 
         # this pattern is *completely unnecessary*, and I would prefer
@@ -463,14 +469,31 @@ class LambdaElementTest(
         # however I also can't come up with a reliable way to catch it.
         # so we will keep the use of "track_on" to be internal.
 
-        def go(col_expr, whereclause, p):
-            stmt = lambdas.lambda_stmt(lambda: select(col_expr))
-            stmt = stmt.add_criteria(
-                lambda stmt: stmt.where(whereclause).order_by(col_expr > p),
-                track_on=(whereclause, whereclause.right.value),
-            )
+        if use_tuple:
 
-            return stmt
+            def go(col_expr, whereclause, p):
+                stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+                stmt = stmt.add_criteria(
+                    lambda stmt: stmt.where(whereclause).order_by(
+                        col_expr > p
+                    ),
+                    track_on=((whereclause,), whereclause.right.value),
+                )
+
+                return stmt
+
+        else:
+
+            def go(col_expr, whereclause, p):
+                stmt = lambdas.lambda_stmt(lambda: select(col_expr))
+                stmt = stmt.add_criteria(
+                    lambda stmt: stmt.where(whereclause).order_by(
+                        col_expr > p
+                    ),
+                    track_on=(whereclause, whereclause.right.value),
+                )
+
+                return stmt
 
         c1 = column("x")
         c2 = column("y")