git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
hardening against inappropriate multi-table updates
author Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 23 Jun 2025 13:21:59 +0000 (09:21 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 24 Jun 2025 14:19:48 +0000 (10:19 -0400)
Hardening of the compiler's actions for UPDATE statements that access
multiple tables to report more specifically when tables or aliases are
referenced in the SET clause; in cases where the backend does not support
secondary tables in the SET clause, an explicit error is raised, and on
MySQL or similar backends that support such a SET clause, more specific
checking for not-properly-included tables is performed.  Overall the change
prevents these erroneous forms of UPDATE statements from being compiled,
whereas previously the database was relied upon to raise an error, which
was not always guaranteed to happen, or to be unambiguous, due to cases
where the parent table included the same column name as the secondary
table column being updated.

Fixed bug where the ORM would pull in the wrong column into an UPDATE when
a key name inside of the :meth:`.ValuesBase.values` method could be located
from an ORM entity mentioned in the statement, but where that ORM entity
was not the actual table that the statement was inserting or updating.  An
extra check for this edge case is added to avoid this problem.

Fixes: #12692
Change-Id: I342832b09dda7ed494caaad0cbb81b93fc10fe18

doc/build/changelog/unreleased_20/12692.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/sql/selectable.py
test/engine/test_reflection.py
test/orm/dml/test_bulk_statements.py
test/sql/test_update.py

diff --git a/doc/build/changelog/unreleased_20/12692.rst b/doc/build/changelog/unreleased_20/12692.rst
new file mode 100644 (file)
index 0000000..b2a48b6
--- /dev/null
@@ -0,0 +1,26 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 12692
+
+    Hardening of the compiler's actions for UPDATE statements that access
+    multiple tables to report more specifically when tables or aliases are
+    referenced in the SET clause; in cases where the backend does not support
+    secondary tables in the SET clause, an explicit error is raised, and on
+    MySQL or similar backends that support such a SET clause, more specific
+    checking for not-properly-included tables is performed.  Overall the change
+    prevents these erroneous forms of UPDATE statements from being compiled,
+    whereas previously the database was relied upon to raise an error, which
+    was not always guaranteed to happen, or to be unambiguous, due to cases
+    where the parent table included the same column name as the secondary
+    table column being updated.
+
+
+.. change::
+    :tags: bug, orm
+    :tickets: 12692
+
+    Fixed bug where the ORM would pull in the wrong column into an UPDATE when
+    a key name inside of the :meth:`.ValuesBase.values` method could be located
+    from an ORM entity mentioned in the statement, but where that ORM entity
+    was not the actual table that the statement was inserting or updating.  An
+    extra check for this edge case is added to avoid this problem.
index 2664c9f9798db1dfdc4b2c09b8b1a92670fc8e93..7918c3ba84af2f71b4f1b75d4b484a27d99cec73 100644 (file)
@@ -446,11 +446,24 @@ class _ORMDMLState(_AbstractORMCompileState):
                     ),
                 )
 
+    @classmethod
+    def _get_dml_plugin_subject(cls, statement):
+        plugin_subject = statement.table._propagate_attrs.get("plugin_subject")
+
+        if (
+            not plugin_subject
+            or not plugin_subject.mapper
+            or plugin_subject
+            is not statement._propagate_attrs["plugin_subject"]
+        ):
+            return None
+        return plugin_subject
+
     @classmethod
     def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
-        plugin_subject = statement._propagate_attrs["plugin_subject"]
+        plugin_subject = cls._get_dml_plugin_subject(statement)
 
-        if not plugin_subject or not plugin_subject.mapper:
+        if not plugin_subject:
             return UpdateDMLState._get_multi_crud_kv_pairs(
                 statement, kv_iterator
             )
@@ -470,13 +483,12 @@ class _ORMDMLState(_AbstractORMCompileState):
             needs_to_be_cacheable
         ), "no test coverage for needs_to_be_cacheable=False"
 
-        plugin_subject = statement._propagate_attrs["plugin_subject"]
+        plugin_subject = cls._get_dml_plugin_subject(statement)
 
-        if not plugin_subject or not plugin_subject.mapper:
+        if not plugin_subject:
             return UpdateDMLState._get_crud_kv_pairs(
                 statement, kv_iterator, needs_to_be_cacheable
             )
-
         return list(
             cls._get_orm_crud_kv_pairs(
                 plugin_subject.mapper,
index 265b15c1e9fb0e3f19149ff6e754b7c6284981ac..e75a3ea1c96ee68ad652c67c4ba5d3ffaf6fb4e2 100644 (file)
@@ -327,6 +327,52 @@ def _get_crud_params(
             .difference(check_columns)
         )
         if check:
+
+            if dml.isupdate(compile_state):
+                tables_mentioned = set(
+                    c.table
+                    for c, v in stmt_parameter_tuples
+                    if isinstance(c, ColumnClause) and c.table is not None
+                ).difference([compile_state.dml_table])
+
+                multi_not_in_from = tables_mentioned.difference(
+                    compile_state._extra_froms
+                )
+
+                if tables_mentioned and (
+                    not compile_state.is_multitable
+                    or not compiler.render_table_with_column_in_update_from
+                ):
+                    if not compiler.render_table_with_column_in_update_from:
+                        preamble = (
+                            "Backend does not support additional "
+                            "tables in the SET clause"
+                        )
+                    else:
+                        preamble = (
+                            "Statement is not a multi-table UPDATE statement"
+                        )
+
+                    raise exc.CompileError(
+                        f"{preamble}; cannot "
+                        f"""include columns from table(s) {
+                            ", ".join(f"'{t.description}'"
+                                      for t in tables_mentioned)
+                        } in SET clause"""
+                    )
+
+                elif multi_not_in_from:
+                    assert compiler.render_table_with_column_in_update_from
+                    raise exc.CompileError(
+                        f"Multi-table UPDATE statement does not include "
+                        "table(s) "
+                        f"""{
+                            ", ".join(
+                                f"'{t.description}'" for
+                                t in multi_not_in_from)
+                        }"""
+                    )
+
             raise exc.CompileError(
                 "Unconsumed column names: %s"
                 % (", ".join("%s" % (c,) for c in check))
@@ -1364,9 +1410,28 @@ def _get_update_multitable_params(
 
     affected_tables = set()
     for t in compile_state._extra_froms:
+        # extra gymnastics to support the probably-shouldnt-have-supported
+        # case of "UPDATE table AS alias SET table.foo = bar", but it's
+        # supported
+        we_shouldnt_be_here_if_columns_found = (
+            not include_table
+            and not compile_state.dml_table.is_derived_from(t)
+        )
+
         for c in t.c:
             if c in normalized_params:
+
+                if we_shouldnt_be_here_if_columns_found:
+                    raise exc.CompileError(
+                        "Backend does not support additional tables "
+                        "in the SET "
+                        "clause; cannot include columns from table(s) "
+                        f"'{t.description}' in "
+                        "SET clause"
+                    )
+
                 affected_tables.add(t)
+
                 check_columns[_getattr_col_key(c)] = c
                 value = normalized_params[c]
 
@@ -1392,6 +1457,7 @@ def _get_update_multitable_params(
                     value = compiler.process(value.self_group(), **kw)
                     accumulated_bind_names = ()
                 values.append((c, col_value, value, accumulated_bind_names))
+
     # determine tables which are actually to be updated - process onupdate
     # and server_onupdate for these
     for t in affected_tables:
index c7ca0ba795b46a2238a7ffa1dd6fc1652247875c..73b936d24fafe6a7efb104181113bb38b008e837 100644 (file)
@@ -1752,7 +1752,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
     def description(self) -> str:
         name = self.name
         if isinstance(name, _anonymous_label):
-            name = "anon_1"
+            return "anon_1"
 
         return name
 
@@ -1792,6 +1792,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause):
 class FromClauseAlias(AliasedReturnsRows):
     element: FromClause
 
+    @util.ro_non_memoized_property
+    def description(self) -> str:
+        name = self.name
+        if isinstance(name, _anonymous_label):
+            return f"Anonymous alias of {self.element.description}"
+
+        return name
+
 
 class Alias(roles.DMLTableRole, FromClauseAlias):
     """Represents an table or selectable alias (AS).
index adb4037065512c6f6d3637ffdb568e74e43941ad..6ba130add341816742f1f69a3845d8bdf123dac9 100644 (file)
@@ -2478,7 +2478,8 @@ class IncludeColsFksTest(AssertsCompiledSQL, fixtures.TestBase):
         # the existing alias doesn't know about it
         with expect_raises_message(
             sa.exc.InvalidRequestError,
-            "Foreign key associated with column 'anon_1.r' could not find "
+            "Foreign key associated with column 'Anonymous alias of b.r' "
+            "could not find "
             "table 'a' with which to generate a foreign key to target "
             "column 'x'",
         ):
index 6d69b2250c3fab6033b097800dfbc06d21df3e3b..fcc908377b98506d699f07d0e0c8831d88cc57b2 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import contextlib
+import datetime
 from typing import Any
 from typing import List
 from typing import Optional
@@ -39,18 +40,22 @@ from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import subqueryload
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import roles
 from sqlalchemy.testing import config
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_deprecated
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import expect_warnings
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import provision
 from sqlalchemy.testing.assertsql import CompiledSQL
 from sqlalchemy.testing.assertsql import Conditional
 from sqlalchemy.testing.entities import ComparableEntity
 from sqlalchemy.testing.fixtures import fixture_session
+from sqlalchemy.types import NullType
 
 
 class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
@@ -2736,3 +2741,76 @@ class EagerLoadTest(
                 b3 = sess.scalar(stmt, [{"a_id": 3}])
 
                 eq_({c.id for c in b3.a.cs}, {3, 4})
+
+
+class DMLCompileScenariosTest(testing.AssertsCompiledSQL, fixtures.TestBase):
+    __dialect__ = "default_enhanced"  # for UPDATE..FROM
+
+    @testing.variation("style", ["insert", "upsert"])
+    def test_insert_values_from_primary_table_only(self, decl_base, style):
+        """test for #12692"""
+
+        class A(decl_base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+            data: Mapped[int]
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+            data: Mapped[str]
+
+        stmt = insert(A.__table__)
+
+        # we're trying to exercise codepaths in orm/bulk_persistence.py that
+        # would only apply to an insert() statement against the ORM entity,
+        # e.g. insert(A).  In the update() case, the WHERE clause can also
+        # pull in the ORM entity, which is how we found the issue here, but
+        # for INSERT there's no current method that does this; returning()
+        # could do this in theory but currently doesn't.  So for now, cheat,
+        # and pretend there's some conversion that's going to propagate
+        # from an ORM expression
+        coercions.expect(
+            roles.WhereHavingRole, B.id == 5, apply_propagate_attrs=stmt
+        )
+
+        if style.insert:
+            stmt = stmt.values(data=123)
+
+            # assert that the ORM did not get involved, putting B.data as the
+            # key in the dictionary
+            is_(stmt._values["data"].type._type_affinity, NullType)
+        elif style.upsert:
+            stmt = stmt.values([{"data": 123}, {"data": 456}])
+
+            # assert that the ORM did not get involved, putting B.data as the
+            # keys in the dictionaries
+            eq_(stmt._multi_values, ([{"data": 123}, {"data": 456}],))
+        else:
+            style.fail()
+
+    def test_update_values_from_primary_table_only(self, decl_base):
+        """test for #12692"""
+
+        class A(decl_base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+            data: Mapped[str]
+            updated_at: Mapped[datetime.datetime] = mapped_column(
+                onupdate=func.now()
+            )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+            data: Mapped[str]
+            updated_at: Mapped[datetime.datetime] = mapped_column(
+                onupdate=func.now()
+            )
+
+        stmt = update(A.__table__).where(B.id == 1).values(data="some data")
+        self.assert_compile(
+            stmt,
+            "UPDATE a SET data=:data, updated_at=now() "
+            "FROM b WHERE b.id = :id_1",
+        )
index b381cb010e8695ebf0cca050a2ffbee000c69503..9a533040e925440ba2314cfa70a7094d51247e4b 100644 (file)
@@ -3,6 +3,7 @@ import random
 
 from sqlalchemy import bindparam
 from sqlalchemy import column
+from sqlalchemy import DateTime
 from sqlalchemy import exc
 from sqlalchemy import exists
 from sqlalchemy import ForeignKey
@@ -41,6 +42,14 @@ class _UpdateFromTestBase:
             Column("name", String(30)),
             Column("description", String(50)),
         )
+        Table(
+            "mytable_with_onupdate",
+            metadata,
+            Column("myid", Integer),
+            Column("name", String(30)),
+            Column("description", String(50)),
+            Column("updated_at", DateTime, onupdate=func.now()),
+        )
         Table(
             "myothertable",
             metadata,
@@ -626,19 +635,36 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
             t.update().values(x=5, z=5).compile,
         )
 
-    def test_unconsumed_names_values_dict(self):
+    @testing.variation("include_in_from", [True, False])
+    @testing.variation("use_mysql", [True, False])
+    def test_unconsumed_names_values_dict(self, include_in_from, use_mysql):
         t = table("t", column("x"), column("y"))
         t2 = table("t2", column("q"), column("z"))
 
-        assert_raises_message(
-            exc.CompileError,
-            "Unconsumed column names: j",
-            t.update()
-            .values(x=5, j=7)
-            .values({t2.c.z: 5})
-            .where(t.c.x == t2.c.q)
-            .compile,
-        )
+        stmt = t.update().values(x=5, j=7).values({t2.c.z: 5})
+        if include_in_from:
+            stmt = stmt.where(t.c.x == t2.c.q)
+
+        if use_mysql:
+            if not include_in_from:
+                msg = (
+                    "Statement is not a multi-table UPDATE statement; cannot "
+                    r"include columns from table\(s\) 't2' in SET clause"
+                )
+            else:
+                msg = "Unconsumed column names: j"
+        else:
+            msg = (
+                "Backend does not support additional tables in the SET "
+                r"clause; cannot include columns from table\(s\) 't2' in "
+                "SET clause"
+            )
+
+        with expect_raises_message(exc.CompileError, msg):
+            if use_mysql:
+                stmt.compile(dialect=mysql.dialect())
+            else:
+                stmt.compile()
 
     def test_unconsumed_names_kwargs_w_keys(self):
         t = table("t", column("x"), column("y"))
@@ -999,55 +1025,165 @@ class UpdateFromCompileTest(
 
     run_create_tables = run_inserts = run_deletes = None
 
-    def test_alias_one(self):
-        table1 = self.tables.mytable
+    @testing.variation("use_onupdate", [True, False])
+    def test_alias_one(self, use_onupdate):
+
+        if use_onupdate:
+            table1 = self.tables.mytable_with_onupdate
+            tname = "mytable_with_onupdate"
+        else:
+            table1 = self.tables.mytable
+            tname = "mytable"
         talias1 = table1.alias("t1")
 
         # this case is nonsensical.  the UPDATE is entirely
         # against the alias, but we name the table-bound column
-        # in values.   The behavior here isn't really defined
+        # in values.   The behavior here isn't really defined.
+        # onupdates get skipped.
         self.assert_compile(
             update(talias1)
             .where(talias1.c.myid == 7)
             .values({table1.c.name: "fred"}),
-            "UPDATE mytable AS t1 "
+            f"UPDATE {tname} AS t1 "
             "SET name=:name "
             "WHERE t1.myid = :myid_1",
         )
 
-    def test_alias_two(self):
-        table1 = self.tables.mytable
+    @testing.variation("use_onupdate", [True, False])
+    def test_alias_two(self, use_onupdate):
+        """test a multi-table UPDATE/SET is actually supported on SQLite, PG
+        if we are only using an alias of the main table
+
+        """
+        if use_onupdate:
+            table1 = self.tables.mytable_with_onupdate
+            tname = "mytable_with_onupdate"
+            onupdate = ", updated_at=now() "
+        else:
+            table1 = self.tables.mytable
+            tname = "mytable"
+            onupdate = " "
         talias1 = table1.alias("t1")
 
         # Here, compared to
         # test_alias_one(), here we actually have UPDATE..FROM,
         # which is causing the "table1.c.name" param to be handled
-        # as an "extra table", hence we see the full table name rendered.
+        # as an "extra table", hence we see the full table name rendered
+        # as well as ON UPDATEs coming in nicely.
         self.assert_compile(
             update(talias1)
             .where(table1.c.myid == 7)
             .values({table1.c.name: "fred"}),
-            "UPDATE mytable AS t1 "
-            "SET name=:mytable_name "
-            "FROM mytable "
-            "WHERE mytable.myid = :myid_1",
-            checkparams={"mytable_name": "fred", "myid_1": 7},
-        )
-
-    def test_alias_two_mysql(self):
-        table1 = self.tables.mytable
+            f"UPDATE {tname} AS t1 "
+            f"SET name=:{tname}_name{onupdate}"
+            f"FROM {tname} "
+            f"WHERE {tname}.myid = :myid_1",
+            checkparams={f"{tname}_name": "fred", "myid_1": 7},
+        )
+
+    @testing.variation("use_onupdate", [True, False])
+    def test_alias_two_mysql(self, use_onupdate):
+        if use_onupdate:
+            table1 = self.tables.mytable_with_onupdate
+            tname = "mytable_with_onupdate"
+            onupdate = ", mytable_with_onupdate.updated_at=now() "
+        else:
+            table1 = self.tables.mytable
+            tname = "mytable"
+            onupdate = " "
         talias1 = table1.alias("t1")
 
         self.assert_compile(
             update(talias1)
             .where(table1.c.myid == 7)
             .values({table1.c.name: "fred"}),
-            "UPDATE mytable AS t1, mytable SET mytable.name=%s "
-            "WHERE mytable.myid = %s",
-            checkparams={"mytable_name": "fred", "myid_1": 7},
+            f"UPDATE {tname} AS t1, {tname} SET {tname}.name=%s{onupdate}"
+            f"WHERE {tname}.myid = %s",
+            checkparams={f"{tname}_name": "fred", "myid_1": 7},
             dialect="mysql",
         )
 
+    @testing.variation("use_alias", [True, False])
+    @testing.variation("use_alias_in_set", [True, False])
+    @testing.variation("include_in_from", [True, False])
+    @testing.variation("use_mysql", [True, False])
+    def test_raise_if_totally_different_table(
+        self, use_alias, include_in_from, use_alias_in_set, use_mysql
+    ):
+        """test cases for #12692"""
+        table1 = self.tables.mytable
+        table2 = self.tables.myothertable
+
+        if use_alias:
+            target = table1.alias("t1")
+        else:
+            target = table1
+
+        stmt = update(target).where(table1.c.myid == 7)
+
+        if use_alias_in_set:
+            stmt = stmt.values({table2.alias().c.othername: "fred"})
+        else:
+            stmt = stmt.values({table2.c.othername: "fred"})
+
+        if include_in_from:
+            stmt = stmt.where(table2.c.otherid == 12)
+
+        if use_mysql and include_in_from and not use_alias_in_set:
+            if not use_alias:
+                self.assert_compile(
+                    stmt,
+                    "UPDATE mytable, myothertable "
+                    "SET myothertable.othername=%s "
+                    "WHERE mytable.myid = %s AND myothertable.otherid = %s",
+                    dialect="mysql",
+                )
+            else:
+                self.assert_compile(
+                    stmt,
+                    "UPDATE mytable AS t1, mytable, myothertable "
+                    "SET myothertable.othername=%s WHERE mytable.myid = %s "
+                    "AND myothertable.otherid = %s",
+                    dialect="mysql",
+                )
+            return
+
+        if use_alias_in_set:
+            tabledesc = "Anonymous alias of myothertable"
+        else:
+            tabledesc = "myothertable"
+
+        if use_mysql:
+            if include_in_from:
+                msg = (
+                    r"Multi-table UPDATE statement does not include "
+                    rf"table\(s\) '{tabledesc}'"
+                )
+            else:
+                if use_alias:
+                    msg = (
+                        rf"Multi-table UPDATE statement does not include "
+                        rf"table\(s\) '{tabledesc}'"
+                    )
+                else:
+                    msg = (
+                        rf"Statement is not a multi-table UPDATE statement; "
+                        r"cannot include columns from table\(s\) "
+                        rf"'{tabledesc}' in SET clause"
+                    )
+        else:
+            msg = (
+                r"Backend does not support additional tables in the "
+                r"SET clause; cannot include columns from table\(s\) "
+                rf"'{tabledesc}' in SET clause"
+            )
+
+        with expect_raises_message(exc.CompileError, msg):
+            if use_mysql:
+                stmt.compile(dialect=mysql.dialect())
+            else:
+                stmt.compile()
+
     def test_update_from_multitable_same_name_mysql(self):
         users, addresses = self.tables.users, self.tables.addresses