From: Mike Bayer Date: Wed, 15 Oct 2025 18:47:38 +0000 (-0400) Subject: fully copy_internals for AnnotatedFromClause for straight cloned traverse X-Git-Tag: rel_2_0_45~52^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=64a6278ffe2658b0dab7d11ffeac77ea8a09fd8b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fully copy_internals for AnnotatedFromClause for straight cloned traverse Fixed issue where using :meth:`_sql.Select.params` to replace bound parameters in a query could fail for some cases where the parameters were embedded in subqueries or CTEs when ORM classes were involved, due to issues with internal query traversal for these cases. Fixes: #12915 Change-Id: Ib63bca786a541682f6b2144fd5dd43350411ae9d (cherry picked from commit 03116b8cc90986a2e597d5423c490babf49c9913) --- diff --git a/doc/build/changelog/unreleased_20/12915.rst b/doc/build/changelog/unreleased_20/12915.rst new file mode 100644 index 0000000000..399c474715 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12915.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, sql + :tickets: 12915 + + Fixed issue where using :meth:`_sql.Select.params` to replace bound + parameters in a query could fail for some cases where the parameters + were embedded in subqueries or CTEs when ORM classes were involved, + due to issues with internal query traversal for these cases. diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index bf445ff330..27aa7cd2bd 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -465,7 +465,9 @@ def _deep_annotate( newelem = elem newelem._copy_internals( - clone=clone, ind_cols_on_fromclause=ind_cols_on_fromclause + clone=clone, + ind_cols_on_fromclause=ind_cols_on_fromclause, + _annotations_traversal=True, ) cloned_ids[id_] = newelem @@ -505,7 +507,7 @@ def _deep_deannotate( if key not in cloned: newelem = elem._deannotate(values=values, clone=True) - newelem._copy_internals(clone=clone) + newelem._copy_internals(clone=clone, _annotations_traversal=True) cloned[key] = newelem return newelem else: @@ -526,7 +528,7 @@ def _shallow_annotate(element: _SA, annotations: _AnnotationDict) -> _SA: structure wasting time. """ element = element._annotate(annotations) - element._copy_internals() + element._copy_internals(_annotations_traversal=True) return element diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 07f17570d7..8d1f60db12 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -7208,11 +7208,28 @@ TextAsFrom = TextualSelect class AnnotatedFromClause(Annotated): - def _copy_internals(self, **kw: Any) -> None: + def _copy_internals( + self, + _annotations_traversal: bool = False, + ind_cols_on_fromclause: bool = False, + **kw: Any, + ) -> None: super()._copy_internals(**kw) - if kw.get("ind_cols_on_fromclause", False): + + # passed from annotations._shallow_annotate(), _deep_annotate(), etc. + # the traversals used by annotations for these cases are not currently + # designed around expecting that inner elements inside of + # AnnotatedFromClause's element are also deep copied, so skip for these + # cases. in other cases such as plain visitors.cloned_traverse(), we + # expect this to happen. see issue #12915 + if not _annotations_traversal: ee = self._Annotated__element # type: ignore + ee._copy_internals(**kw) + if ind_cols_on_fromclause: + # passed from annotations._deep_annotate(). See that function + # for notes + ee = self._Annotated__element # type: ignore self.c = ee.__class__.c.fget(self) # type: ignore @util.ro_memoized_property diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index a961962d91..af72783aa1 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -35,6 +35,7 @@ from sqlalchemy.orm import with_loader_criteria from sqlalchemy.orm import with_polymorphic from sqlalchemy.sql import and_ from sqlalchemy.sql import sqltypes +from sqlalchemy.sql import visitors from sqlalchemy.sql.selectable import Join as core_join from sqlalchemy.sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL @@ -363,6 +364,48 @@ class SelectableTest(QueryTest, AssertsCompiledSQL): checkparams={"param_1": 6, "param_2": 5}, ) + @testing.variation("use_get_params", [True, False]) + def test_annotated_cte_params_traverse(self, use_get_params): + """test #12915 + + test that .params() applied to a statement that includes + an annotated CTE will traverse into the CTE's internal structures + to replace the bound parameters. + + """ + User = self.classes.User + + ids_param = bindparam("ids") + cte = select(User).where(User.id == ids_param).cte("cte") + + ca = cte._annotate({"foo": "bar"}) + + stmt = select(ca) + + if use_get_params: + stmt = stmt.params(ids=17) + else: + # test without using params(), in case the implementation + # for params() changes we still want to test cloned_traverse + def visit_bindparam(bind): + if bind.key == "ids": + bind.value = 17 + bind.required = False + + stmt = visitors.cloned_traverse( + stmt, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ) + + self.assert_compile( + stmt, + "WITH cte AS (SELECT users.id AS id, users.name AS name " + "FROM users WHERE users.id = :ids) " + "SELECT cte.id, cte.name FROM cte", + checkparams={"ids": 17}, + ) + class PropagateAttrsTest(QueryTest): def propagate_cases(): diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index e69b09dfb8..49645e559b 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -3154,6 +3154,42 @@ class AnnotationsTest(fixtures.TestBase): t = Table("t", MetaData(), c1) is_(c1_a.table, t) + @testing.variation("use_get_params", [True, False]) + def test_annotated_cte_params_traverse(self, use_get_params): + """test #12915""" + user = Table("user", MetaData(), Column("id", Integer)) + + ids_param = bindparam("ids") + + cte = select(user).where(user.c.id == ids_param).cte("cte") + + ca = cte._annotate({"foo": "bar"}) + + stmt = select(ca) + + if use_get_params: + stmt = stmt.params(ids=17) + else: + # test without using params(), in case the implementation + # for params() changes we still want to test cloned_traverse + def visit_bindparam(bind): + if bind.key == "ids": + bind.value = 17 + bind.required = False + + stmt = visitors.cloned_traverse( + stmt, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ) + + eq_( + stmt.selected_columns.id.table.element._where_criteria[ + 0 + ].right.value, + 17, + ) + def test_basic_attrs(self): t = Table( "t",