From: Mike Bayer Date: Mon, 20 Sep 2021 17:43:21 +0000 (-0400) Subject: include setup_joins targets when scanning for FROM objects to clone X-Git-Tag: rel_1_4_24~4^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5b05440b6778b8505988265dd49e968f30c900e0;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git include setup_joins targets when scanning for FROM objects to clone Fixed a two issues where combinations of ``select()`` and ``join()`` when adapted to form a copy of the element would not completely copy the state of all column objects associated with subqueries. A key problem this caused is that usage of the :meth:`_sql.ClauseElement.params` method (which should probably be moved into a legacy category as it is inefficient and error prone) would leave copies of the old :class:`_sql.BindParameter` objects around, leading to issues in correctly setting the parameters at execution time. Fixes: #7055 Change-Id: Ib822a978a99561b4402da3fb727b370f5c58210b --- diff --git a/doc/build/changelog/unreleased_14/7055.rst b/doc/build/changelog/unreleased_14/7055.rst new file mode 100644 index 0000000000..50d0c4e495 --- /dev/null +++ b/doc/build/changelog/unreleased_14/7055.rst @@ -0,0 +1,14 @@ +.. change:: + :tags: bug, sql + :tickets: 7055 + + Fixed a two issues where combinations of ``select()`` and ``join()`` when + adapted to form a copy of the element would not completely copy the state + of all column objects associated with subqueries. A key problem this caused + is that usage of the :meth:`_sql.ClauseElement.params` method (which should + probably be moved into a legacy category as it is inefficient and error + prone) would leave copies of the old :class:`_sql.BindParameter` objects + around, leading to issues in correctly setting the parameters at execution + time. + + diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index aa218052ea..970c7a0c56 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1138,6 +1138,39 @@ class Join(roles.DMLTableRole, FromClause): itertools.chain(*[col.foreign_keys for col in columns]) ) + def _copy_internals(self, clone=_clone, **kw): + # see Select._copy_internals() for similar concept + + # here we pre-clone "left" and "right" so that we can + # determine the new FROM clauses + all_the_froms = set( + itertools.chain( + _from_objects(self.left), + _from_objects(self.right), + ) + ) + + # run the clone on those. these will be placed in the + # cache used by the clone function + new_froms = {f: clone(f, **kw) for f in all_the_froms} + + # set up a special replace function that will replace for + # ColumnClause with parent table referring to those + # replaced FromClause objects + def replace(obj, **kw): + if isinstance(obj, ColumnClause) and obj.table in new_froms: + newelem = new_froms[obj.table].corresponding_column(obj) + return newelem + + kw["replace"] = replace + + # run normal _copy_internals. the clones for + # left and right will come from the clone function's + # cache + super(Join, self)._copy_internals(clone=clone, **kw) + + self._reset_memoizations() + def _refresh_for_new_column(self, column): super(Join, self)._refresh_for_new_column(column) self.left._refresh_for_new_column(column) @@ -5519,6 +5552,7 @@ class Select( itertools.chain( _from_objects(*self._raw_columns), _from_objects(*self._where_criteria), + _from_objects(*[elem[0] for elem in self._setup_joins]), ) ) diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 764bcd6d4a..3d1b4fe85e 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -9,6 +9,7 @@ from sqlalchemy import extract from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer +from sqlalchemy import join from sqlalchemy import literal from sqlalchemy import literal_column from sqlalchemy import MetaData @@ -42,6 +43,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not +from sqlalchemy.testing.schema import eq_clause_element from sqlalchemy.util import pickle A = B = t1 = t2 = t3 = table1 = table2 = table3 = table4 = None @@ -795,6 +797,109 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): eq_(u2.compile().params, {"id_param": 7}) eq_(u3.compile().params, {"id_param": 10}) + def test_params_elements_in_setup_joins(self): + """test #7055""" + + meta = MetaData() + + X = Table("x", meta, Column("a", Integer), Column("b", Integer)) + Y = Table("y", meta, Column("a", Integer), Column("b", Integer)) + s1 = select(X.c.a).where(X.c.b == bindparam("xb")).alias("s1") + jj = ( + select(Y) + .join(s1, Y.c.a == s1.c.a) + .where(Y.c.b == bindparam("yb")) + .alias("s2") + ) + + params = {"xb": 42, "yb": 33} + sel = select(Y).select_from(jj).params(params) + + eq_( + [ + eq_clause_element(bindparam("yb", value=33)), + eq_clause_element(bindparam("xb", value=42)), + ], + sel._generate_cache_key()[1], + ) + + def test_params_subqueries_in_joins_one(self): + """test #7055""" + + meta = MetaData() + + Pe = Table( + "pe", + meta, + Column("c", Integer), + Column("p", Integer), + Column("pid", Integer), + ) + S = Table( + "s", + meta, + Column("c", Integer), + Column("p", Integer), + Column("sid", Integer), + ) + Ps = Table("ps", meta, Column("c", Integer), Column("p", Integer)) + params = {"pid": 42, "sid": 33} + + pe_s = select(Pe).where(Pe.c.pid == bindparam("pid")).alias("pe_s") + s_s = select(S).where(S.c.sid == bindparam("sid")).alias("s_s") + jj = join( + Ps, + join(pe_s, s_s, and_(pe_s.c.c == s_s.c.c, pe_s.c.p == s_s.c.p)), + and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p), + ).params(params) + + eq_( + [ + eq_clause_element(bindparam("pid", value=42)), + eq_clause_element(bindparam("sid", value=33)), + ], + jj._generate_cache_key()[1], + ) + + def test_params_subqueries_in_joins_two(self): + """test #7055""" + + meta = MetaData() + + Pe = Table( + "pe", + meta, + Column("c", Integer), + Column("p", Integer), + Column("pid", Integer), + ) + S = Table( + "s", + meta, + Column("c", Integer), + Column("p", Integer), + Column("sid", Integer), + ) + Ps = Table("ps", meta, Column("c", Integer), Column("p", Integer)) + + params = {"pid": 42, "sid": 33} + + pe_s = select(Pe).where(Pe.c.pid == bindparam("pid")).alias("pe_s") + s_s = select(S).where(S.c.sid == bindparam("sid")).alias("s_s") + jj = ( + join(Ps, pe_s, and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p)) + .join(s_s, and_(Ps.c.c == s_s.c.c, Ps.c.p == s_s.c.p)) + .params(params) + ) + + eq_( + [ + eq_clause_element(bindparam("pid", value=42)), + eq_clause_element(bindparam("sid", value=33)), + ], + jj._generate_cache_key()[1], + ) + def test_in(self): expr = t1.c.col1.in_(["foo", "bar"]) expr2 = CloningVisitor().traverse(expr)