From: Mike Bayer Date: Fri, 30 Apr 2021 15:54:52 +0000 (-0400) Subject: track_on needs to be a fixed size, support sub-tuples X-Git-Tag: rel_1_4_13~9 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a47c158a9a3b1104698fc0bff47ca58d67cb9191;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git track_on needs to be a fixed size, support sub-tuples Fixed regression in ``selectinload`` loader strategy that would cause it to cache its internal state incorrectly when handling relationships that join across more than one column, such as when using a composite foreign key. The invalid caching would then cause other loader operations to fail. Fixes: #6410 Change-Id: I9f95ccca3553e7fd5794c619be4cf85c02b04626 --- diff --git a/doc/build/changelog/unreleased_14/6410.rst b/doc/build/changelog/unreleased_14/6410.rst new file mode 100644 index 0000000000..be997c6b8c --- /dev/null +++ b/doc/build/changelog/unreleased_14/6410.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, regression, orm + :tickets: 6410 + + Fixed regression in ``selectinload`` loader strategy that would cause it to + cache its internal state incorrectly when handling relationships that join + across more than one column, such as when using a composite foreign key. + The invalid caching would then cause other loader operations to fail. + diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 0f68a3fef8..e43fa09a0b 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -2865,7 +2865,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): ), lambda_cache=self._query_cache, global_track_bound_values=False, - track_on=(self, effective_entity) + tuple(pk_cols), + track_on=(self, effective_entity) + (tuple(pk_cols),), ) if not self.parent_property.bake_queries: diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index 06db8f95e3..ddc4774db8 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -772,7 +772,16 @@ class AnalyzedCode(object): from the "track_on" parameter passed to a :class:`.LambdaElement`. """ - if isinstance(elem, traversals.HasCacheKey): + + if isinstance(elem, tuple): + # tuple must contain hascachekey elements + def get(closure, opts, anon_map, bindparams): + return tuple( + tup_elem._gen_cache_key(anon_map, bindparams) + for tup_elem in opts.track_on[idx] + ) + + elif isinstance(elem, traversals.HasCacheKey): def get(closure, opts, anon_map, bindparams): return opts.track_on[idx]._gen_cache_key(anon_map, bindparams) diff --git a/test/orm/test_selectin_relations.py b/test/orm/test_selectin_relations.py index ec642a71ce..5ea259da37 100644 --- a/test/orm/test_selectin_relations.py +++ b/test/orm/test_selectin_relations.py @@ -30,6 +30,7 @@ from sqlalchemy.testing import mock from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import assert_engine from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.fixtures import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -3588,3 +3589,66 @@ class TestBakedCancelsCorrectly(fixtures.DeclarativeMappedTest): self.assert_sql_count(testing.db, go, 2) self.assert_sql_count(testing.db, go, 2) self.assert_sql_count(testing.db, go, 2) + + +class TestCompositePlusNonComposite(fixtures.DeclarativeMappedTest): + @classmethod + def setup_classes(cls): + Base = cls.DeclarativeBasic + + from sqlalchemy.sql import lambdas + from sqlalchemy.orm import configure_mappers + + lambdas._closure_per_cache_key.clear() + lambdas.AnalyzedCode._fns.clear() + + class A(ComparableEntity, Base): + __tablename__ = "a" + + id = Column(Integer, primary_key=True) + bs = relationship("B", lazy="selectin") + + class B(ComparableEntity, Base): + __tablename__ = "b" + id = Column(Integer, primary_key=True) + a_id = Column(ForeignKey("a.id")) + + class A2(ComparableEntity, Base): + __tablename__ = "a2" + + id = Column(Integer, primary_key=True) + id2 = Column(Integer, primary_key=True) + bs = relationship("B2", lazy="selectin") + + class B2(ComparableEntity, Base): + __tablename__ = "b2" + id = Column(Integer, primary_key=True) + a_id = Column(Integer) + a_id2 = Column(Integer) + __table_args__ = ( + ForeignKeyConstraint(["a_id", "a_id2"], ["a2.id", "a2.id2"]), + ) + + configure_mappers() + + @classmethod + def insert_data(cls, connection): + A, B, A2, B2 = cls.classes("A", "B", "A2", "B2") + s = Session(connection) + + s.add(A(bs=[B()])) + s.add(A2(id=1, id2=1, bs=[B2()])) + + s.commit() + + def test_load_composite_then_non_composite(self): + + A, B, A2, B2 = self.classes("A", "B", "A2", "B2") + + s = fixture_session() + + a2 = s.query(A2).first() + a1 = s.query(A).first() + + eq_(a2.bs, [B2()]) + eq_(a1.bs, [B()]) diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index 897c60f003..cdfd92ece5 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -454,7 +454,13 @@ class LambdaElementTest( checkparams={"y_1": 18, "p_1": 12}, ) - def test_stmt_lambda_w_atonce_whereclause_customtrack_binds(self): + @testing.combinations( + (True,), + (False,), + ) + def test_stmt_lambda_w_atonce_whereclause_customtrack_binds( + self, use_tuple + ): c2 = column("y") # this pattern is *completely unnecessary*, and I would prefer @@ -463,14 +469,31 @@ class LambdaElementTest( # however I also can't come up with a reliable way to catch it. # so we will keep the use of "track_on" to be internal. - def go(col_expr, whereclause, p): - stmt = lambdas.lambda_stmt(lambda: select(col_expr)) - stmt = stmt.add_criteria( - lambda stmt: stmt.where(whereclause).order_by(col_expr > p), - track_on=(whereclause, whereclause.right.value), - ) + if use_tuple: - return stmt + def go(col_expr, whereclause, p): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause).order_by( + col_expr > p + ), + track_on=((whereclause,), whereclause.right.value), + ) + + return stmt + + else: + + def go(col_expr, whereclause, p): + stmt = lambdas.lambda_stmt(lambda: select(col_expr)) + stmt = stmt.add_criteria( + lambda stmt: stmt.where(whereclause).order_by( + col_expr > p + ), + track_on=(whereclause, whereclause.right.value), + ) + + return stmt c1 = column("x") c2 = column("y")