From: Mike Bayer Date: Wed, 15 Oct 2025 18:47:38 +0000 (-0400) Subject: fully copy_internals for AnnotatedFromClause for straight cloned traverse X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=03116b8cc90986a2e597d5423c490babf49c9913;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fully copy_internals for AnnotatedFromClause for straight cloned traverse Fixed issue where using :meth:`_sql.Select.params` to replace bound parameters in a query could fail for some cases where the parameters were embedded in subqueries or CTEs when ORM classes were involved, due to issues with internal query traversal for these cases. Fixes: #12915 Change-Id: Ib63bca786a541682f6b2144fd5dd43350411ae9d --- diff --git a/doc/build/changelog/unreleased_20/12915.rst b/doc/build/changelog/unreleased_20/12915.rst new file mode 100644 index 0000000000..399c474715 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12915.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, sql + :tickets: 12915 + + Fixed issue where using :meth:`_sql.Select.params` to replace bound + parameters in a query could fail for some cases where the parameters + were embedded in subqueries or CTEs when ORM classes were involved, + due to issues with internal query traversal for these cases. diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 74b0467ebd..fe951e74c0 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -468,7 +468,9 @@ def _deep_annotate( newelem = elem newelem._copy_internals( - clone=clone, ind_cols_on_fromclause=ind_cols_on_fromclause + clone=clone, + ind_cols_on_fromclause=ind_cols_on_fromclause, + _annotations_traversal=True, ) cloned_ids[id_] = newelem @@ -508,7 +510,7 @@ def _deep_deannotate( if key not in cloned: newelem = elem._deannotate(values=values, clone=True) - newelem._copy_internals(clone=clone) + newelem._copy_internals(clone=clone, _annotations_traversal=True) cloned[key] = newelem return newelem else: @@ -529,7 +531,7 @@ def _shallow_annotate(element: _SA, annotations: _AnnotationDict) -> _SA: structure wasting time. """ element = element._annotate(annotations) - element._copy_internals() + element._copy_internals(_annotations_traversal=True) return element diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 60f062ec95..fa2079500b 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -7293,11 +7293,28 @@ TextAsFrom = TextualSelect class AnnotatedFromClause(Annotated): - def _copy_internals(self, **kw: Any) -> None: + def _copy_internals( + self, + _annotations_traversal: bool = False, + ind_cols_on_fromclause: bool = False, + **kw: Any, + ) -> None: super()._copy_internals(**kw) - if kw.get("ind_cols_on_fromclause", False): + + # passed from annotations._shallow_annotate(), _deep_annotate(), etc. + # the traversals used by annotations for these cases are not currently + # designed around expecting that inner elements inside of + # AnnotatedFromClause's element are also deep copied, so skip for these + # cases. in other cases such as plain visitors.cloned_traverse(), we + # expect this to happen. see issue #12915 + if not _annotations_traversal: ee = self._Annotated__element # type: ignore + ee._copy_internals(**kw) + if ind_cols_on_fromclause: + # passed from annotations._deep_annotate(). See that function + # for notes + ee = self._Annotated__element # type: ignore self.c = ee.__class__.c.fget(self) # type: ignore @util.ro_memoized_property diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index b49d4286bf..57c0ae36b0 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -36,6 +36,7 @@ from sqlalchemy.orm import with_loader_criteria from sqlalchemy.orm import with_polymorphic from sqlalchemy.sql import and_ from sqlalchemy.sql import sqltypes +from sqlalchemy.sql import visitors from sqlalchemy.sql.selectable import Join as core_join from sqlalchemy.sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL @@ -365,6 +366,48 @@ class SelectableTest(QueryTest, AssertsCompiledSQL): checkparams={"param_1": 6, "param_2": 5}, ) + @testing.variation("use_get_params", [True, False]) + def test_annotated_cte_params_traverse(self, use_get_params): + """test #12915 + + test that .params() applied to a statement that includes + an annotated CTE will traverse into the CTE's internal structures + to replace the bound parameters. + + """ + User = self.classes.User + + ids_param = bindparam("ids") + cte = select(User).where(User.id == ids_param).cte("cte") + + ca = cte._annotate({"foo": "bar"}) + + stmt = select(ca) + + if use_get_params: + stmt = stmt.params(ids=17) + else: + # test without using params(), in case the implementation + # for params() changes we still want to test cloned_traverse + def visit_bindparam(bind): + if bind.key == "ids": + bind.value = 17 + bind.required = False + + stmt = visitors.cloned_traverse( + stmt, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ) + + self.assert_compile( + stmt, + "WITH cte AS (SELECT users.id AS id, users.name AS name " + "FROM users WHERE users.id = :ids) " + "SELECT cte.id, cte.name FROM cte", + checkparams={"ids": 17}, + ) + class PropagateAttrsTest(QueryTest): __backend__ = True diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 7451de6747..3172c970d2 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -3136,6 +3136,42 @@ class AnnotationsTest(fixtures.TestBase): t = Table("t", MetaData(), c1) is_(c1_a.table, t) + @testing.variation("use_get_params", [True, False]) + def test_annotated_cte_params_traverse(self, use_get_params): + """test #12915""" + user = Table("user", MetaData(), Column("id", Integer)) + + ids_param = bindparam("ids") + + cte = select(user).where(user.c.id == ids_param).cte("cte") + + ca = cte._annotate({"foo": "bar"}) + + stmt = select(ca) + + if use_get_params: + stmt = stmt.params(ids=17) + else: + # test without using params(), in case the implementation + # for params() changes we still want to test cloned_traverse + def visit_bindparam(bind): + if bind.key == "ids": + bind.value = 17 + bind.required = False + + stmt = visitors.cloned_traverse( + stmt, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ) + + eq_( + stmt.selected_columns.id.table.element._where_criteria[ + 0 + ].right.value, + 17, + ) + def test_basic_attrs(self): t = Table( "t",