From: Mike Bayer Date: Mon, 23 Jun 2025 13:21:59 +0000 (-0400) Subject: hardening against inappropriate multi-table updates X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=dc0d0817622435ea46b33575fd4f84d3959dc42d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git hardening against inappropriate multi-table updates Hardening of the compiler's actions for UPDATE statements that access multiple tables to report more specifically when tables or aliases are referenced in the SET clause; on cases where the backend does not support secondary tables in the SET clause, an explicit error is raised, and on the MySQL or similar backends that support such a SET clause, more specific checking for not-properly-included tables is performed. Overall the change is preventing these erroneous forms of UPDATE statements from being compiled, whereas previously it was relied on the database to raise an error, which was not always guaranteed to happen, or to be non-ambiguous, due to cases where the parent table included the same column name as the secondary table column being updated. Fixed bug where the ORM would pull in the wrong column into an UPDATE when a key name inside of the :meth:`.ValuesBase.values` method could be located from an ORM entity mentioned in the statement, but where that ORM entity was not the actual table that the statement was inserting or updating. An extra check for this edge case is added to avoid this problem. Fixes: #12692 Change-Id: I342832b09dda7ed494caaad0cbb81b93fc10fe18 --- diff --git a/doc/build/changelog/unreleased_20/12692.rst b/doc/build/changelog/unreleased_20/12692.rst new file mode 100644 index 0000000000..b2a48b6cef --- /dev/null +++ b/doc/build/changelog/unreleased_20/12692.rst @@ -0,0 +1,26 @@ +.. change:: + :tags: bug, sql + :tickets: 12692 + + Hardening of the compiler's actions for UPDATE statements that access + multiple tables to report more specifically when tables or aliases are + referenced in the SET clause; on cases where the backend does not support + secondary tables in the SET clause, an explicit error is raised, and on the + MySQL or similar backends that support such a SET clause, more specific + checking for not-properly-included tables is performed. Overall the change + is preventing these erroneous forms of UPDATE statements from being + compiled, whereas previously it was relied on the database to raise an + error, which was not always guaranteed to happen, or to be non-ambiguous, + due to cases where the parent table included the same column name as the + secondary table column being updated. + + +.. change:: + :tags: bug, orm + :tickets: 12692 + + Fixed bug where the ORM would pull in the wrong column into an UPDATE when + a key name inside of the :meth:`.ValuesBase.values` method could be located + from an ORM entity mentioned in the statement, but where that ORM entity + was not the actual table that the statement was inserting or updating. An + extra check for this edge case is added to avoid this problem. diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 2664c9f979..7918c3ba84 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -446,11 +446,24 @@ class _ORMDMLState(_AbstractORMCompileState): ), ) + @classmethod + def _get_dml_plugin_subject(cls, statement): + plugin_subject = statement.table._propagate_attrs.get("plugin_subject") + + if ( + not plugin_subject + or not plugin_subject.mapper + or plugin_subject + is not statement._propagate_attrs["plugin_subject"] + ): + return None + return plugin_subject + @classmethod def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): - plugin_subject = statement._propagate_attrs["plugin_subject"] + plugin_subject = cls._get_dml_plugin_subject(statement) - if not plugin_subject or not plugin_subject.mapper: + if not plugin_subject: return UpdateDMLState._get_multi_crud_kv_pairs( statement, kv_iterator ) @@ -470,13 +483,12 @@ class _ORMDMLState(_AbstractORMCompileState): needs_to_be_cacheable ), "no test coverage for needs_to_be_cacheable=False" - plugin_subject = statement._propagate_attrs["plugin_subject"] + plugin_subject = cls._get_dml_plugin_subject(statement) - if not plugin_subject or not plugin_subject.mapper: + if not plugin_subject: return UpdateDMLState._get_crud_kv_pairs( statement, kv_iterator, needs_to_be_cacheable ) - return list( cls._get_orm_crud_kv_pairs( plugin_subject.mapper, diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index 265b15c1e9..e75a3ea1c9 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -327,6 +327,52 @@ def _get_crud_params( .difference(check_columns) ) if check: + + if dml.isupdate(compile_state): + tables_mentioned = set( + c.table + for c, v in stmt_parameter_tuples + if isinstance(c, ColumnClause) and c.table is not None + ).difference([compile_state.dml_table]) + + multi_not_in_from = tables_mentioned.difference( + compile_state._extra_froms + ) + + if tables_mentioned and ( + not compile_state.is_multitable + or not compiler.render_table_with_column_in_update_from + ): + if not compiler.render_table_with_column_in_update_from: + preamble = ( + "Backend does not support additional " + "tables in the SET clause" + ) + else: + preamble = ( + "Statement is not a multi-table UPDATE statement" + ) + + raise exc.CompileError( + f"{preamble}; cannot " + f"""include columns from table(s) { + ", ".join(f"'{t.description}'" + for t in tables_mentioned) + } in SET clause""" + ) + + elif multi_not_in_from: + assert compiler.render_table_with_column_in_update_from + raise exc.CompileError( + f"Multi-table UPDATE statement does not include " + "table(s) " + f"""{ + ", ".join( + f"'{t.description}'" for + t in multi_not_in_from) + }""" + ) + raise exc.CompileError( "Unconsumed column names: %s" % (", ".join("%s" % (c,) for c in check)) @@ -1364,9 +1410,28 @@ def _get_update_multitable_params( affected_tables = set() for t in compile_state._extra_froms: + # extra gymnastics to support the probably-shouldnt-have-supported + # case of "UPDATE table AS alias SET table.foo = bar", but it's + # supported + we_shouldnt_be_here_if_columns_found = ( + not include_table + and not compile_state.dml_table.is_derived_from(t) + ) + for c in t.c: if c in normalized_params: + + if we_shouldnt_be_here_if_columns_found: + raise exc.CompileError( + "Backend does not support additional tables " + "in the SET " + "clause; cannot include columns from table(s) " + f"'{t.description}' in " + "SET clause" + ) + affected_tables.add(t) + check_columns[_getattr_col_key(c)] = c value = normalized_params[c] @@ -1392,6 +1457,7 @@ def _get_update_multitable_params( value = compiler.process(value.self_group(), **kw) accumulated_bind_names = () values.append((c, col_value, value, accumulated_bind_names)) + # determine tables which are actually to be updated - process onupdate # and server_onupdate for these for t in affected_tables: diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index c7ca0ba795..73b936d24f 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -1752,7 +1752,7 @@ class AliasedReturnsRows(NoInit, NamedFromClause): def description(self) -> str: name = self.name if isinstance(name, _anonymous_label): - name = "anon_1" + return "anon_1" return name @@ -1792,6 +1792,14 @@ class AliasedReturnsRows(NoInit, NamedFromClause): class FromClauseAlias(AliasedReturnsRows): element: FromClause + @util.ro_non_memoized_property + def description(self) -> str: + name = self.name + if isinstance(name, _anonymous_label): + return f"Anonymous alias of {self.element.description}" + + return name + class Alias(roles.DMLTableRole, FromClauseAlias): """Represents an table or selectable alias (AS). diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index adb4037065..6ba130add3 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -2478,7 +2478,8 @@ class IncludeColsFksTest(AssertsCompiledSQL, fixtures.TestBase): # the existing alias doesn't know about it with expect_raises_message( sa.exc.InvalidRequestError, - "Foreign key associated with column 'anon_1.r' could not find " + "Foreign key associated with column 'Anonymous alias of b.r' " + "could not find " "table 'a' with which to generate a foreign key to target " "column 'x'", ): diff --git a/test/orm/dml/test_bulk_statements.py b/test/orm/dml/test_bulk_statements.py index 6d69b2250c..fcc908377b 100644 --- a/test/orm/dml/test_bulk_statements.py +++ b/test/orm/dml/test_bulk_statements.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import datetime from typing import Any from typing import List from typing import Optional @@ -39,18 +40,22 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.orm import subqueryload +from sqlalchemy.sql import coercions +from sqlalchemy.sql import roles from sqlalchemy.testing import config from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_deprecated from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ from sqlalchemy.testing import mock from sqlalchemy.testing import provision from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import Conditional from sqlalchemy.testing.entities import ComparableEntity from sqlalchemy.testing.fixtures import fixture_session +from sqlalchemy.types import NullType class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase): @@ -2736,3 +2741,76 @@ class EagerLoadTest( b3 = sess.scalar(stmt, [{"a_id": 3}]) eq_({c.id for c in b3.a.cs}, {3, 4}) + + +class DMLCompileScenariosTest(testing.AssertsCompiledSQL, fixtures.TestBase): + __dialect__ = "default_enhanced" # for UPDATE..FROM + + @testing.variation("style", ["insert", "upsert"]) + def test_insert_values_from_primary_table_only(self, decl_base, style): + """test for #12692""" + + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + data: Mapped[int] + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + data: Mapped[str] + + stmt = insert(A.__table__) + + # we're trying to exercise codepaths in orm/bulk_persistence.py that + # would only apply to an insert() statement against the ORM entity, + # e.g. insert(A). In the update() case, the WHERE clause can also + # pull in the ORM entity, which is how we found the issue here, but + # for INSERT there's no current method that does this; returning() + # could do this in theory but currently doesnt. So for now, cheat, + # and pretend there's some conversion that's going to propagate + # from an ORM expression + coercions.expect( + roles.WhereHavingRole, B.id == 5, apply_propagate_attrs=stmt + ) + + if style.insert: + stmt = stmt.values(data=123) + + # assert that the ORM did not get involved, putting B.data as the + # key in the dictionary + is_(stmt._values["data"].type._type_affinity, NullType) + elif style.upsert: + stmt = stmt.values([{"data": 123}, {"data": 456}]) + + # assert that the ORM did not get involved, putting B.data as the + # keys in the dictionaries + eq_(stmt._multi_values, ([{"data": 123}, {"data": 456}],)) + else: + style.fail() + + def test_update_values_from_primary_table_only(self, decl_base): + """test for #12692""" + + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + data: Mapped[str] + updated_at: Mapped[datetime.datetime] = mapped_column( + onupdate=func.now() + ) + + class B(decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(Identity(), primary_key=True) + data: Mapped[str] + updated_at: Mapped[datetime.datetime] = mapped_column( + onupdate=func.now() + ) + + stmt = update(A.__table__).where(B.id == 1).values(data="some data") + self.assert_compile( + stmt, + "UPDATE a SET data=:data, updated_at=now() " + "FROM b WHERE b.id = :id_1", + ) diff --git a/test/sql/test_update.py b/test/sql/test_update.py index b381cb010e..9a533040e9 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -3,6 +3,7 @@ import random from sqlalchemy import bindparam from sqlalchemy import column +from sqlalchemy import DateTime from sqlalchemy import exc from sqlalchemy import exists from sqlalchemy import ForeignKey @@ -41,6 +42,14 @@ class _UpdateFromTestBase: Column("name", String(30)), Column("description", String(50)), ) + Table( + "mytable_with_onupdate", + metadata, + Column("myid", Integer), + Column("name", String(30)), + Column("description", String(50)), + Column("updated_at", DateTime, onupdate=func.now()), + ) Table( "myothertable", metadata, @@ -626,19 +635,36 @@ class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): t.update().values(x=5, z=5).compile, ) - def test_unconsumed_names_values_dict(self): + @testing.variation("include_in_from", [True, False]) + @testing.variation("use_mysql", [True, False]) + def test_unconsumed_names_values_dict(self, include_in_from, use_mysql): t = table("t", column("x"), column("y")) t2 = table("t2", column("q"), column("z")) - assert_raises_message( - exc.CompileError, - "Unconsumed column names: j", - t.update() - .values(x=5, j=7) - .values({t2.c.z: 5}) - .where(t.c.x == t2.c.q) - .compile, - ) + stmt = t.update().values(x=5, j=7).values({t2.c.z: 5}) + if include_in_from: + stmt = stmt.where(t.c.x == t2.c.q) + + if use_mysql: + if not include_in_from: + msg = ( + "Statement is not a multi-table UPDATE statement; cannot " + r"include columns from table\(s\) 't2' in SET clause" + ) + else: + msg = "Unconsumed column names: j" + else: + msg = ( + "Backend does not support additional tables in the SET " + r"clause; cannot include columns from table\(s\) 't2' in " + "SET clause" + ) + + with expect_raises_message(exc.CompileError, msg): + if use_mysql: + stmt.compile(dialect=mysql.dialect()) + else: + stmt.compile() def test_unconsumed_names_kwargs_w_keys(self): t = table("t", column("x"), column("y")) @@ -999,55 +1025,165 @@ class UpdateFromCompileTest( run_create_tables = run_inserts = run_deletes = None - def test_alias_one(self): - table1 = self.tables.mytable + @testing.variation("use_onupdate", [True, False]) + def test_alias_one(self, use_onupdate): + + if use_onupdate: + table1 = self.tables.mytable_with_onupdate + tname = "mytable_with_onupdate" + else: + table1 = self.tables.mytable + tname = "mytable" talias1 = table1.alias("t1") # this case is nonsensical. the UPDATE is entirely # against the alias, but we name the table-bound column - # in values. The behavior here isn't really defined + # in values. The behavior here isn't really defined. + # onupdates get skipped. self.assert_compile( update(talias1) .where(talias1.c.myid == 7) .values({table1.c.name: "fred"}), - "UPDATE mytable AS t1 " + f"UPDATE {tname} AS t1 " "SET name=:name " "WHERE t1.myid = :myid_1", ) - def test_alias_two(self): - table1 = self.tables.mytable + @testing.variation("use_onupdate", [True, False]) + def test_alias_two(self, use_onupdate): + """test a multi-table UPDATE/SET is actually supported on SQLite, PG + if we are only using an alias of the main table + + """ + if use_onupdate: + table1 = self.tables.mytable_with_onupdate + tname = "mytable_with_onupdate" + onupdate = ", updated_at=now() " + else: + table1 = self.tables.mytable + tname = "mytable" + onupdate = " " talias1 = table1.alias("t1") # Here, compared to # test_alias_one(), here we actually have UPDATE..FROM, # which is causing the "table1.c.name" param to be handled - # as an "extra table", hence we see the full table name rendered. + # as an "extra table", hence we see the full table name rendered + # as well as ON UPDATEs coming in nicely. self.assert_compile( update(talias1) .where(table1.c.myid == 7) .values({table1.c.name: "fred"}), - "UPDATE mytable AS t1 " - "SET name=:mytable_name " - "FROM mytable " - "WHERE mytable.myid = :myid_1", - checkparams={"mytable_name": "fred", "myid_1": 7}, - ) - - def test_alias_two_mysql(self): - table1 = self.tables.mytable + f"UPDATE {tname} AS t1 " + f"SET name=:{tname}_name{onupdate}" + f"FROM {tname} " + f"WHERE {tname}.myid = :myid_1", + checkparams={f"{tname}_name": "fred", "myid_1": 7}, + ) + + @testing.variation("use_onupdate", [True, False]) + def test_alias_two_mysql(self, use_onupdate): + if use_onupdate: + table1 = self.tables.mytable_with_onupdate + tname = "mytable_with_onupdate" + onupdate = ", mytable_with_onupdate.updated_at=now() " + else: + table1 = self.tables.mytable + tname = "mytable" + onupdate = " " talias1 = table1.alias("t1") self.assert_compile( update(talias1) .where(table1.c.myid == 7) .values({table1.c.name: "fred"}), - "UPDATE mytable AS t1, mytable SET mytable.name=%s " - "WHERE mytable.myid = %s", - checkparams={"mytable_name": "fred", "myid_1": 7}, + f"UPDATE {tname} AS t1, {tname} SET {tname}.name=%s{onupdate}" + f"WHERE {tname}.myid = %s", + checkparams={f"{tname}_name": "fred", "myid_1": 7}, dialect="mysql", ) + @testing.variation("use_alias", [True, False]) + @testing.variation("use_alias_in_set", [True, False]) + @testing.variation("include_in_from", [True, False]) + @testing.variation("use_mysql", [True, False]) + def test_raise_if_totally_different_table( + self, use_alias, include_in_from, use_alias_in_set, use_mysql + ): + """test cases for #12962""" + table1 = self.tables.mytable + table2 = self.tables.myothertable + + if use_alias: + target = table1.alias("t1") + else: + target = table1 + + stmt = update(target).where(table1.c.myid == 7) + + if use_alias_in_set: + stmt = stmt.values({table2.alias().c.othername: "fred"}) + else: + stmt = stmt.values({table2.c.othername: "fred"}) + + if include_in_from: + stmt = stmt.where(table2.c.otherid == 12) + + if use_mysql and include_in_from and not use_alias_in_set: + if not use_alias: + self.assert_compile( + stmt, + "UPDATE mytable, myothertable " + "SET myothertable.othername=%s " + "WHERE mytable.myid = %s AND myothertable.otherid = %s", + dialect="mysql", + ) + else: + self.assert_compile( + stmt, + "UPDATE mytable AS t1, mytable, myothertable " + "SET myothertable.othername=%s WHERE mytable.myid = %s " + "AND myothertable.otherid = %s", + dialect="mysql", + ) + return + + if use_alias_in_set: + tabledesc = "Anonymous alias of myothertable" + else: + tabledesc = "myothertable" + + if use_mysql: + if include_in_from: + msg = ( + r"Multi-table UPDATE statement does not include " + rf"table\(s\) '{tabledesc}'" + ) + else: + if use_alias: + msg = ( + rf"Multi-table UPDATE statement does not include " + rf"table\(s\) '{tabledesc}'" + ) + else: + msg = ( + rf"Statement is not a multi-table UPDATE statement; " + r"cannot include columns from table\(s\) " + rf"'{tabledesc}' in SET clause" + ) + else: + msg = ( + r"Backend does not support additional tables in the " + r"SET clause; cannot include columns from table\(s\) " + rf"'{tabledesc}' in SET clause" + ) + + with expect_raises_message(exc.CompileError, msg): + if use_mysql: + stmt.compile(dialect=mysql.dialect()) + else: + stmt.compile() + def test_update_from_multitable_same_name_mysql(self): users, addresses = self.tables.users, self.tables.addresses