--- /dev/null
+.. change::
+ :tags: bug, sql
+ :tickets: 12692
+
+ Hardening of the compiler's actions for UPDATE statements that access
+ multiple tables to report more specifically when tables or aliases are
+ referenced in the SET clause; on cases where the backend does not support
+ secondary tables in the SET clause, an explicit error is raised, and on the
+ MySQL or similar backends that support such a SET clause, more specific
+ checking for not-properly-included tables is performed. Overall the change
+ is preventing these erroneous forms of UPDATE statements from being
+ compiled, whereas previously it was relied on the database to raise an
+ error, which was not always guaranteed to happen, or to be non-ambiguous,
+ due to cases where the parent table included the same column name as the
+ secondary table column being updated.
+
+
+.. change::
+ :tags: bug, orm
+ :tickets: 12692
+
+ Fixed bug where the ORM would pull in the wrong column into an UPDATE when
+ a key name inside of the :meth:`.ValuesBase.values` method could be located
+ from an ORM entity mentioned in the statement, but where that ORM entity
+ was not the actual table that the statement was inserting or updating. An
+ extra check for this edge case is added to avoid this problem.
),
)
+ @classmethod
+ def _get_dml_plugin_subject(cls, statement):
+ plugin_subject = statement.table._propagate_attrs.get("plugin_subject")
+
+ if (
+ not plugin_subject
+ or not plugin_subject.mapper
+ or plugin_subject
+ is not statement._propagate_attrs["plugin_subject"]
+ ):
+ return None
+ return plugin_subject
+
@classmethod
def _get_multi_crud_kv_pairs(cls, statement, kv_iterator):
- plugin_subject = statement._propagate_attrs["plugin_subject"]
+ plugin_subject = cls._get_dml_plugin_subject(statement)
- if not plugin_subject or not plugin_subject.mapper:
+ if not plugin_subject:
return UpdateDMLState._get_multi_crud_kv_pairs(
statement, kv_iterator
)
needs_to_be_cacheable
), "no test coverage for needs_to_be_cacheable=False"
- plugin_subject = statement._propagate_attrs["plugin_subject"]
+ plugin_subject = cls._get_dml_plugin_subject(statement)
- if not plugin_subject or not plugin_subject.mapper:
+ if not plugin_subject:
return UpdateDMLState._get_crud_kv_pairs(
statement, kv_iterator, needs_to_be_cacheable
)
-
return list(
cls._get_orm_crud_kv_pairs(
plugin_subject.mapper,
.difference(check_columns)
)
if check:
+
+ if dml.isupdate(compile_state):
+ tables_mentioned = set(
+ c.table
+ for c, v in stmt_parameter_tuples
+ if isinstance(c, ColumnClause) and c.table is not None
+ ).difference([compile_state.dml_table])
+
+ multi_not_in_from = tables_mentioned.difference(
+ compile_state._extra_froms
+ )
+
+ if tables_mentioned and (
+ not compile_state.is_multitable
+ or not compiler.render_table_with_column_in_update_from
+ ):
+ if not compiler.render_table_with_column_in_update_from:
+ preamble = (
+ "Backend does not support additional "
+ "tables in the SET clause"
+ )
+ else:
+ preamble = (
+ "Statement is not a multi-table UPDATE statement"
+ )
+
+ raise exc.CompileError(
+ f"{preamble}; cannot "
+ f"""include columns from table(s) {
+ ", ".join(f"'{t.description}'"
+ for t in tables_mentioned)
+ } in SET clause"""
+ )
+
+ elif multi_not_in_from:
+ assert compiler.render_table_with_column_in_update_from
+ raise exc.CompileError(
+ f"Multi-table UPDATE statement does not include "
+ "table(s) "
+ f"""{
+ ", ".join(
+ f"'{t.description}'" for
+ t in multi_not_in_from)
+ }"""
+ )
+
raise exc.CompileError(
"Unconsumed column names: %s"
% (", ".join("%s" % (c,) for c in check))
affected_tables = set()
for t in compile_state._extra_froms:
+ # extra gymnastics to support the probably-shouldnt-have-supported
+ # case of "UPDATE table AS alias SET table.foo = bar", but it's
+ # supported
+ we_shouldnt_be_here_if_columns_found = (
+ not include_table
+ and not compile_state.dml_table.is_derived_from(t)
+ )
+
for c in t.c:
if c in normalized_params:
+
+ if we_shouldnt_be_here_if_columns_found:
+ raise exc.CompileError(
+ "Backend does not support additional tables "
+ "in the SET "
+ "clause; cannot include columns from table(s) "
+ f"'{t.description}' in "
+ "SET clause"
+ )
+
affected_tables.add(t)
+
check_columns[_getattr_col_key(c)] = c
value = normalized_params[c]
value = compiler.process(value.self_group(), **kw)
accumulated_bind_names = ()
values.append((c, col_value, value, accumulated_bind_names))
+
# determine tables which are actually to be updated - process onupdate
# and server_onupdate for these
for t in affected_tables:
def description(self) -> str:
name = self.name
if isinstance(name, _anonymous_label):
- name = "anon_1"
+ return "anon_1"
return name
class FromClauseAlias(AliasedReturnsRows):
element: FromClause
+ @util.ro_non_memoized_property
+ def description(self) -> str:
+ name = self.name
+ if isinstance(name, _anonymous_label):
+ return f"Anonymous alias of {self.element.description}"
+
+ return name
+
class Alias(roles.DMLTableRole, FromClauseAlias):
"""Represents an table or selectable alias (AS).
# the existing alias doesn't know about it
with expect_raises_message(
sa.exc.InvalidRequestError,
- "Foreign key associated with column 'anon_1.r' could not find "
+ "Foreign key associated with column 'Anonymous alias of b.r' "
+ "could not find "
"table 'a' with which to generate a foreign key to target "
"column 'x'",
):
from __future__ import annotations
import contextlib
+import datetime
from typing import Any
from typing import List
from typing import Optional
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from sqlalchemy.orm import subqueryload
+from sqlalchemy.sql import coercions
+from sqlalchemy.sql import roles
from sqlalchemy.testing import config
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_deprecated
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import expect_warnings
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from sqlalchemy.testing import provision
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.assertsql import Conditional
from sqlalchemy.testing.entities import ComparableEntity
from sqlalchemy.testing.fixtures import fixture_session
+from sqlalchemy.types import NullType
class InsertStmtTest(testing.AssertsExecutionResults, fixtures.TestBase):
b3 = sess.scalar(stmt, [{"a_id": 3}])
eq_({c.id for c in b3.a.cs}, {3, 4})
+
+
+class DMLCompileScenariosTest(testing.AssertsCompiledSQL, fixtures.TestBase):
+ __dialect__ = "default_enhanced" # for UPDATE..FROM
+
+ @testing.variation("style", ["insert", "upsert"])
+ def test_insert_values_from_primary_table_only(self, decl_base, style):
+ """test for #12692"""
+
+ class A(decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ data: Mapped[int]
+
+ class B(decl_base):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ data: Mapped[str]
+
+ stmt = insert(A.__table__)
+
+ # we're trying to exercise codepaths in orm/bulk_persistence.py that
+ # would only apply to an insert() statement against the ORM entity,
+ # e.g. insert(A). In the update() case, the WHERE clause can also
+ # pull in the ORM entity, which is how we found the issue here, but
+ # for INSERT there's no current method that does this; returning()
+ # could do this in theory but currently doesnt. So for now, cheat,
+ # and pretend there's some conversion that's going to propagate
+ # from an ORM expression
+ coercions.expect(
+ roles.WhereHavingRole, B.id == 5, apply_propagate_attrs=stmt
+ )
+
+ if style.insert:
+ stmt = stmt.values(data=123)
+
+ # assert that the ORM did not get involved, putting B.data as the
+ # key in the dictionary
+ is_(stmt._values["data"].type._type_affinity, NullType)
+ elif style.upsert:
+ stmt = stmt.values([{"data": 123}, {"data": 456}])
+
+ # assert that the ORM did not get involved, putting B.data as the
+ # keys in the dictionaries
+ eq_(stmt._multi_values, ([{"data": 123}, {"data": 456}],))
+ else:
+ style.fail()
+
+ def test_update_values_from_primary_table_only(self, decl_base):
+ """test for #12692"""
+
+ class A(decl_base):
+ __tablename__ = "a"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ data: Mapped[str]
+ updated_at: Mapped[datetime.datetime] = mapped_column(
+ onupdate=func.now()
+ )
+
+ class B(decl_base):
+ __tablename__ = "b"
+ id: Mapped[int] = mapped_column(Identity(), primary_key=True)
+ data: Mapped[str]
+ updated_at: Mapped[datetime.datetime] = mapped_column(
+ onupdate=func.now()
+ )
+
+ stmt = update(A.__table__).where(B.id == 1).values(data="some data")
+ self.assert_compile(
+ stmt,
+ "UPDATE a SET data=:data, updated_at=now() "
+ "FROM b WHERE b.id = :id_1",
+ )
from sqlalchemy import bindparam
from sqlalchemy import column
+from sqlalchemy import DateTime
from sqlalchemy import exc
from sqlalchemy import exists
from sqlalchemy import ForeignKey
Column("name", String(30)),
Column("description", String(50)),
)
+ Table(
+ "mytable_with_onupdate",
+ metadata,
+ Column("myid", Integer),
+ Column("name", String(30)),
+ Column("description", String(50)),
+ Column("updated_at", DateTime, onupdate=func.now()),
+ )
Table(
"myothertable",
metadata,
t.update().values(x=5, z=5).compile,
)
- def test_unconsumed_names_values_dict(self):
+ @testing.variation("include_in_from", [True, False])
+ @testing.variation("use_mysql", [True, False])
+ def test_unconsumed_names_values_dict(self, include_in_from, use_mysql):
t = table("t", column("x"), column("y"))
t2 = table("t2", column("q"), column("z"))
- assert_raises_message(
- exc.CompileError,
- "Unconsumed column names: j",
- t.update()
- .values(x=5, j=7)
- .values({t2.c.z: 5})
- .where(t.c.x == t2.c.q)
- .compile,
- )
+ stmt = t.update().values(x=5, j=7).values({t2.c.z: 5})
+ if include_in_from:
+ stmt = stmt.where(t.c.x == t2.c.q)
+
+ if use_mysql:
+ if not include_in_from:
+ msg = (
+ "Statement is not a multi-table UPDATE statement; cannot "
+ r"include columns from table\(s\) 't2' in SET clause"
+ )
+ else:
+ msg = "Unconsumed column names: j"
+ else:
+ msg = (
+ "Backend does not support additional tables in the SET "
+ r"clause; cannot include columns from table\(s\) 't2' in "
+ "SET clause"
+ )
+
+ with expect_raises_message(exc.CompileError, msg):
+ if use_mysql:
+ stmt.compile(dialect=mysql.dialect())
+ else:
+ stmt.compile()
def test_unconsumed_names_kwargs_w_keys(self):
t = table("t", column("x"), column("y"))
run_create_tables = run_inserts = run_deletes = None
- def test_alias_one(self):
- table1 = self.tables.mytable
+ @testing.variation("use_onupdate", [True, False])
+ def test_alias_one(self, use_onupdate):
+
+ if use_onupdate:
+ table1 = self.tables.mytable_with_onupdate
+ tname = "mytable_with_onupdate"
+ else:
+ table1 = self.tables.mytable
+ tname = "mytable"
talias1 = table1.alias("t1")
# this case is nonsensical. the UPDATE is entirely
# against the alias, but we name the table-bound column
- # in values. The behavior here isn't really defined
+ # in values. The behavior here isn't really defined.
+ # onupdates get skipped.
self.assert_compile(
update(talias1)
.where(talias1.c.myid == 7)
.values({table1.c.name: "fred"}),
- "UPDATE mytable AS t1 "
+ f"UPDATE {tname} AS t1 "
"SET name=:name "
"WHERE t1.myid = :myid_1",
)
- def test_alias_two(self):
- table1 = self.tables.mytable
+ @testing.variation("use_onupdate", [True, False])
+ def test_alias_two(self, use_onupdate):
+ """test a multi-table UPDATE/SET is actually supported on SQLite, PG
+ if we are only using an alias of the main table
+
+ """
+ if use_onupdate:
+ table1 = self.tables.mytable_with_onupdate
+ tname = "mytable_with_onupdate"
+ onupdate = ", updated_at=now() "
+ else:
+ table1 = self.tables.mytable
+ tname = "mytable"
+ onupdate = " "
talias1 = table1.alias("t1")
# Here, compared to
# test_alias_one(), here we actually have UPDATE..FROM,
# which is causing the "table1.c.name" param to be handled
- # as an "extra table", hence we see the full table name rendered.
+ # as an "extra table", hence we see the full table name rendered
+ # as well as ON UPDATEs coming in nicely.
self.assert_compile(
update(talias1)
.where(table1.c.myid == 7)
.values({table1.c.name: "fred"}),
- "UPDATE mytable AS t1 "
- "SET name=:mytable_name "
- "FROM mytable "
- "WHERE mytable.myid = :myid_1",
- checkparams={"mytable_name": "fred", "myid_1": 7},
- )
-
- def test_alias_two_mysql(self):
- table1 = self.tables.mytable
+ f"UPDATE {tname} AS t1 "
+ f"SET name=:{tname}_name{onupdate}"
+ f"FROM {tname} "
+ f"WHERE {tname}.myid = :myid_1",
+ checkparams={f"{tname}_name": "fred", "myid_1": 7},
+ )
+
+ @testing.variation("use_onupdate", [True, False])
+ def test_alias_two_mysql(self, use_onupdate):
+ if use_onupdate:
+ table1 = self.tables.mytable_with_onupdate
+ tname = "mytable_with_onupdate"
+ onupdate = ", mytable_with_onupdate.updated_at=now() "
+ else:
+ table1 = self.tables.mytable
+ tname = "mytable"
+ onupdate = " "
talias1 = table1.alias("t1")
self.assert_compile(
update(talias1)
.where(table1.c.myid == 7)
.values({table1.c.name: "fred"}),
- "UPDATE mytable AS t1, mytable SET mytable.name=%s "
- "WHERE mytable.myid = %s",
- checkparams={"mytable_name": "fred", "myid_1": 7},
+ f"UPDATE {tname} AS t1, {tname} SET {tname}.name=%s{onupdate}"
+ f"WHERE {tname}.myid = %s",
+ checkparams={f"{tname}_name": "fred", "myid_1": 7},
dialect="mysql",
)
+ @testing.variation("use_alias", [True, False])
+ @testing.variation("use_alias_in_set", [True, False])
+ @testing.variation("include_in_from", [True, False])
+ @testing.variation("use_mysql", [True, False])
+ def test_raise_if_totally_different_table(
+ self, use_alias, include_in_from, use_alias_in_set, use_mysql
+ ):
+ """test cases for #12962"""
+ table1 = self.tables.mytable
+ table2 = self.tables.myothertable
+
+ if use_alias:
+ target = table1.alias("t1")
+ else:
+ target = table1
+
+ stmt = update(target).where(table1.c.myid == 7)
+
+ if use_alias_in_set:
+ stmt = stmt.values({table2.alias().c.othername: "fred"})
+ else:
+ stmt = stmt.values({table2.c.othername: "fred"})
+
+ if include_in_from:
+ stmt = stmt.where(table2.c.otherid == 12)
+
+ if use_mysql and include_in_from and not use_alias_in_set:
+ if not use_alias:
+ self.assert_compile(
+ stmt,
+ "UPDATE mytable, myothertable "
+ "SET myothertable.othername=%s "
+ "WHERE mytable.myid = %s AND myothertable.otherid = %s",
+ dialect="mysql",
+ )
+ else:
+ self.assert_compile(
+ stmt,
+ "UPDATE mytable AS t1, mytable, myothertable "
+ "SET myothertable.othername=%s WHERE mytable.myid = %s "
+ "AND myothertable.otherid = %s",
+ dialect="mysql",
+ )
+ return
+
+ if use_alias_in_set:
+ tabledesc = "Anonymous alias of myothertable"
+ else:
+ tabledesc = "myothertable"
+
+ if use_mysql:
+ if include_in_from:
+ msg = (
+ r"Multi-table UPDATE statement does not include "
+ rf"table\(s\) '{tabledesc}'"
+ )
+ else:
+ if use_alias:
+ msg = (
+ rf"Multi-table UPDATE statement does not include "
+ rf"table\(s\) '{tabledesc}'"
+ )
+ else:
+ msg = (
+ rf"Statement is not a multi-table UPDATE statement; "
+ r"cannot include columns from table\(s\) "
+ rf"'{tabledesc}' in SET clause"
+ )
+ else:
+ msg = (
+ r"Backend does not support additional tables in the "
+ r"SET clause; cannot include columns from table\(s\) "
+ rf"'{tabledesc}' in SET clause"
+ )
+
+ with expect_raises_message(exc.CompileError, msg):
+ if use_mysql:
+ stmt.compile(dialect=mysql.dialect())
+ else:
+ stmt.compile()
+
def test_update_from_multitable_same_name_mysql(self):
users, addresses = self.tables.users, self.tables.addresses