]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
include setup_joins targets when scanning for FROM objects to clone
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Sep 2021 17:43:21 +0000 (13:43 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Sep 2021 19:50:07 +0000 (15:50 -0400)
Fixed a two issues where combinations of ``select()`` and ``join()`` when
adapted to form a copy of the element would not completely copy the state
of all column objects associated with subqueries. A key problem this caused
is that usage of the :meth:`_sql.ClauseElement.params` method (which should
probably be moved into a legacy category as it is inefficient and error
prone) would leave copies of the old :class:`_sql.BindParameter` objects
around, leading to issues in correctly setting the parameters at execution
time.

Fixes: #7055
Change-Id: Ib822a978a99561b4402da3fb727b370f5c58210b

doc/build/changelog/unreleased_14/7055.rst [new file with mode: 0644]
lib/sqlalchemy/sql/selectable.py
test/sql/test_external_traversal.py

diff --git a/doc/build/changelog/unreleased_14/7055.rst b/doc/build/changelog/unreleased_14/7055.rst
new file mode 100644 (file)
index 0000000..50d0c4e
--- /dev/null
@@ -0,0 +1,14 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 7055
+
+    Fixed a two issues where combinations of ``select()`` and ``join()`` when
+    adapted to form a copy of the element would not completely copy the state
+    of all column objects associated with subqueries. A key problem this caused
+    is that usage of the :meth:`_sql.ClauseElement.params` method (which should
+    probably be moved into a legacy category as it is inefficient and error
+    prone) would leave copies of the old :class:`_sql.BindParameter` objects
+    around, leading to issues in correctly setting the parameters at execution
+    time.
+
+
index aa218052eaab38084d221925d2325c3ac943ada3..970c7a0c567b636fff24345346065dbd5988c4b4 100644 (file)
@@ -1138,6 +1138,39 @@ class Join(roles.DMLTableRole, FromClause):
             itertools.chain(*[col.foreign_keys for col in columns])
         )
 
+    def _copy_internals(self, clone=_clone, **kw):
+        # see Select._copy_internals() for similar concept
+
+        # here we pre-clone "left" and "right" so that we can
+        # determine the new FROM clauses
+        all_the_froms = set(
+            itertools.chain(
+                _from_objects(self.left),
+                _from_objects(self.right),
+            )
+        )
+
+        # run the clone on those.  these will be placed in the
+        # cache used by the clone function
+        new_froms = {f: clone(f, **kw) for f in all_the_froms}
+
+        # set up a special replace function that will replace for
+        # ColumnClause with parent table referring to those
+        # replaced FromClause objects
+        def replace(obj, **kw):
+            if isinstance(obj, ColumnClause) and obj.table in new_froms:
+                newelem = new_froms[obj.table].corresponding_column(obj)
+                return newelem
+
+        kw["replace"] = replace
+
+        # run normal _copy_internals.  the clones for
+        # left and right will come from the clone function's
+        # cache
+        super(Join, self)._copy_internals(clone=clone, **kw)
+
+        self._reset_memoizations()
+
     def _refresh_for_new_column(self, column):
         super(Join, self)._refresh_for_new_column(column)
         self.left._refresh_for_new_column(column)
@@ -5519,6 +5552,7 @@ class Select(
             itertools.chain(
                 _from_objects(*self._raw_columns),
                 _from_objects(*self._where_criteria),
+                _from_objects(*[elem[0] for elem in self._setup_joins]),
             )
         )
 
index 764bcd6d4acd315540ebd17cef036663f23474e0..3d1b4fe85ec1b889bafcc574744e71cb45b2b782 100644 (file)
@@ -9,6 +9,7 @@ from sqlalchemy import extract
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
+from sqlalchemy import join
 from sqlalchemy import literal
 from sqlalchemy import literal_column
 from sqlalchemy import MetaData
@@ -42,6 +43,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
+from sqlalchemy.testing.schema import eq_clause_element
 from sqlalchemy.util import pickle
 
 A = B = t1 = t2 = t3 = table1 = table2 = table3 = table4 = None
@@ -795,6 +797,109 @@ class ClauseTest(fixtures.TestBase, AssertsCompiledSQL):
         eq_(u2.compile().params, {"id_param": 7})
         eq_(u3.compile().params, {"id_param": 10})
 
+    def test_params_elements_in_setup_joins(self):
+        """test #7055"""
+
+        meta = MetaData()
+
+        X = Table("x", meta, Column("a", Integer), Column("b", Integer))
+        Y = Table("y", meta, Column("a", Integer), Column("b", Integer))
+        s1 = select(X.c.a).where(X.c.b == bindparam("xb")).alias("s1")
+        jj = (
+            select(Y)
+            .join(s1, Y.c.a == s1.c.a)
+            .where(Y.c.b == bindparam("yb"))
+            .alias("s2")
+        )
+
+        params = {"xb": 42, "yb": 33}
+        sel = select(Y).select_from(jj).params(params)
+
+        eq_(
+            [
+                eq_clause_element(bindparam("yb", value=33)),
+                eq_clause_element(bindparam("xb", value=42)),
+            ],
+            sel._generate_cache_key()[1],
+        )
+
+    def test_params_subqueries_in_joins_one(self):
+        """test #7055"""
+
+        meta = MetaData()
+
+        Pe = Table(
+            "pe",
+            meta,
+            Column("c", Integer),
+            Column("p", Integer),
+            Column("pid", Integer),
+        )
+        S = Table(
+            "s",
+            meta,
+            Column("c", Integer),
+            Column("p", Integer),
+            Column("sid", Integer),
+        )
+        Ps = Table("ps", meta, Column("c", Integer), Column("p", Integer))
+        params = {"pid": 42, "sid": 33}
+
+        pe_s = select(Pe).where(Pe.c.pid == bindparam("pid")).alias("pe_s")
+        s_s = select(S).where(S.c.sid == bindparam("sid")).alias("s_s")
+        jj = join(
+            Ps,
+            join(pe_s, s_s, and_(pe_s.c.c == s_s.c.c, pe_s.c.p == s_s.c.p)),
+            and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p),
+        ).params(params)
+
+        eq_(
+            [
+                eq_clause_element(bindparam("pid", value=42)),
+                eq_clause_element(bindparam("sid", value=33)),
+            ],
+            jj._generate_cache_key()[1],
+        )
+
+    def test_params_subqueries_in_joins_two(self):
+        """test #7055"""
+
+        meta = MetaData()
+
+        Pe = Table(
+            "pe",
+            meta,
+            Column("c", Integer),
+            Column("p", Integer),
+            Column("pid", Integer),
+        )
+        S = Table(
+            "s",
+            meta,
+            Column("c", Integer),
+            Column("p", Integer),
+            Column("sid", Integer),
+        )
+        Ps = Table("ps", meta, Column("c", Integer), Column("p", Integer))
+
+        params = {"pid": 42, "sid": 33}
+
+        pe_s = select(Pe).where(Pe.c.pid == bindparam("pid")).alias("pe_s")
+        s_s = select(S).where(S.c.sid == bindparam("sid")).alias("s_s")
+        jj = (
+            join(Ps, pe_s, and_(Ps.c.c == pe_s.c.c, Ps.c.p == Ps.c.p))
+            .join(s_s, and_(Ps.c.c == s_s.c.c, Ps.c.p == s_s.c.p))
+            .params(params)
+        )
+
+        eq_(
+            [
+                eq_clause_element(bindparam("pid", value=42)),
+                eq_clause_element(bindparam("sid", value=33)),
+            ],
+            jj._generate_cache_key()[1],
+        )
+
     def test_in(self):
         expr = t1.c.col1.in_(["foo", "bar"])
         expr2 = CloningVisitor().traverse(expr)