From: Mike Bayer Date: Sun, 9 Feb 2025 23:09:21 +0000 (-0500) Subject: only use _DMLReturningColFilter for "bulk insert", not other DML X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=1c7e3f9c94b2e6c441ba635a88573bc4cd88ad7d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git only use _DMLReturningColFilter for "bulk insert", not other DML Fixed bug in ORM enabled UPDATE (and theoretically DELETE) where using a multi-table DML statement would not allow ORM mapped columns from mappers other than the primary UPDATE mapper to be named in the RETURNING clause; they would be omitted instead and cause a column not found exception. Fixes: #12328 Change-Id: I2223ee506eec447823a3a545eecad1a7a03364a9 --- diff --git a/doc/build/changelog/unreleased_20/12328.rst b/doc/build/changelog/unreleased_20/12328.rst new file mode 100644 index 0000000000..9d9b70965e --- /dev/null +++ b/doc/build/changelog/unreleased_20/12328.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm + :tickets: 12328 + + Fixed bug in ORM enabled UPDATE (and theoretically DELETE) where using a + multi-table DML statement would not allow ORM mapped columns from mappers + other than the primary UPDATE mapper to be named in the RETURNING clause; + they would be omitted instead and cause a column not found exception. diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index d86f1d0ce5..fa57bcfae8 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -155,10 +155,12 @@ class QueryContext: statement: Union[ Select[Unpack[TupleAny]], FromStatement[Unpack[TupleAny]], + UpdateBase, ], user_passed_query: Union[ Select[Unpack[TupleAny]], FromStatement[Unpack[TupleAny]], + UpdateBase, ], params: _CoreSingleExecuteParams, session: Session, @@ -420,7 +422,9 @@ class _ORMCompileState(_AbstractORMCompileState): attributes: Dict[Any, Any] global_attributes: Dict[Any, Any] - statement: Union[Select[Unpack[TupleAny]], FromStatement[Unpack[TupleAny]]] + statement: Union[ + Select[Unpack[TupleAny]], FromStatement[Unpack[TupleAny]], UpdateBase + ] select_statement: Union[ Select[Unpack[TupleAny]], FromStatement[Unpack[TupleAny]] ] @@ -663,8 +667,14 @@ class _ORMCompileState(_AbstractORMCompileState): ) -class _DMLReturningColFilter: - """an adapter used for the DML RETURNING case. +class _DMLBulkInsertReturningColFilter: + """an adapter used for the DML RETURNING case specifically + for ORM bulk insert (or any hypothetical DML that is splitting out a class + hierarchy among multiple DML statements....ORM bulk insert is the only + example right now) + + its main job is to limit the columns in a RETURNING to only a specific + mapped table in a hierarchy. Has a subset of the interface used by :class:`.ORMAdapter` and is used for :class:`._QueryEntity` @@ -860,14 +870,20 @@ class _ORMFromStatementCompileState(_ORMCompileState): return None def setup_dml_returning_compile_state(self, dml_mapper): - """used by BulkORMInsert (and Update / Delete?) to set up a handler + """used by BulkORMInsert, Update, Delete to set up a handler for RETURNING to return ORM objects and expressions """ target_mapper = self.statement._propagate_attrs.get( "plugin_subject", None ) - adapter = _DMLReturningColFilter(target_mapper, dml_mapper) + + if self.statement.is_insert: + adapter = _DMLBulkInsertReturningColFilter( + target_mapper, dml_mapper + ) + else: + adapter = None if self.compile_options._is_star and (len(self._entities) != 1): raise sa_exc.CompileError( @@ -2544,7 +2560,7 @@ class _QueryEntity: def setup_dml_returning_compile_state( self, compile_state: _ORMCompileState, - adapter: _DMLReturningColFilter, + adapter: Optional[_DMLBulkInsertReturningColFilter], ) -> None: raise NotImplementedError() @@ -2746,7 +2762,7 @@ class _MapperEntity(_QueryEntity): def setup_dml_returning_compile_state( self, compile_state: _ORMCompileState, - adapter: _DMLReturningColFilter, + adapter: Optional[_DMLBulkInsertReturningColFilter], ) -> None: loading._setup_entity_query( compile_state, @@ -2905,7 +2921,7 @@ class _BundleEntity(_QueryEntity): def setup_dml_returning_compile_state( self, compile_state: _ORMCompileState, - adapter: _DMLReturningColFilter, + adapter: Optional[_DMLBulkInsertReturningColFilter], ) -> None: return self.setup_compile_state(compile_state) @@ -3095,7 +3111,7 @@ class _RawColumnEntity(_ColumnEntity): def setup_dml_returning_compile_state( self, compile_state: _ORMCompileState, - adapter: _DMLReturningColFilter, + adapter: Optional[_DMLBulkInsertReturningColFilter], ) -> None: return self.setup_compile_state(compile_state) @@ -3212,10 +3228,13 @@ class _ORMColumnEntity(_ColumnEntity): def setup_dml_returning_compile_state( self, compile_state: _ORMCompileState, - adapter: _DMLReturningColFilter, + adapter: Optional[_DMLBulkInsertReturningColFilter], ) -> None: - self._fetch_column = self.column - column = adapter(self.column, False) + + self._fetch_column = column = self.column + if adapter: + column = adapter(column, False) + if column is not None: compile_state.dedupe_columns.add(column) compile_state.primary_columns.append(column) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 02a98fefe7..ac6746adba 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -137,6 +137,7 @@ if TYPE_CHECKING: from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import CacheableOptions from ..sql.base import ExecutableOption + from ..sql.dml import UpdateBase from ..sql.elements import ColumnElement from ..sql.elements import Label from ..sql.selectable import _ForUpdateOfArgument @@ -503,7 +504,7 @@ class Query( return cast("Select[_T]", self.statement) @property - def statement(self) -> Union[Select[_T], FromStatement[_T]]: + def statement(self) -> Union[Select[_T], FromStatement[_T], UpdateBase]: """The full SELECT statement represented by this Query. The statement by default will not have disambiguating labels @@ -531,6 +532,8 @@ class Query( # from there, it starts to look much like Query itself won't be # passed into the execute process and won't generate its own cache # key; this will all occur in terms of the ORM-enabled Select. + stmt: Union[Select[_T], FromStatement[_T], UpdateBase] + if not self._compile_options._set_base_alias: # if we don't have legacy top level aliasing features in use # then convert to a future select() directly @@ -802,7 +805,7 @@ class Query( ) @property - def selectable(self) -> Union[Select[_T], FromStatement[_T]]: + def selectable(self) -> Union[Select[_T], FromStatement[_T], UpdateBase]: """Return the :class:`_expression.Select` object emitted by this :class:`_query.Query`. @@ -813,7 +816,9 @@ class Query( """ return self.__clause_element__() - def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]: + def __clause_element__( + self, + ) -> Union[Select[_T], FromStatement[_T], UpdateBase]: return ( self._with_compile_options( _enable_eagerloads=False, _render_for_subquery=True diff --git a/test/orm/dml/test_update_delete_where.py b/test/orm/dml/test_update_delete_where.py index 7d06a8618c..387ce161b8 100644 --- a/test/orm/dml/test_update_delete_where.py +++ b/test/orm/dml/test_update_delete_where.py @@ -78,6 +78,7 @@ class UpdateDeleteTest(fixtures.MappedTest): metadata, Column("id", Integer, primary_key=True), Column("user_id", ForeignKey("users.id")), + Column("email_address", String(50)), ) m = MetaData() @@ -118,6 +119,24 @@ class UpdateDeleteTest(fixtures.MappedTest): ], ) + @testing.fixture + def addresses_data( + self, + ): + addresses = self.tables.addresses + + with testing.db.begin() as connection: + connection.execute( + addresses.insert(), + [ + dict(id=1, user_id=1, email_address="jo1"), + dict(id=2, user_id=1, email_address="jo2"), + dict(id=3, user_id=2, email_address="ja1"), + dict(id=4, user_id=3, email_address="ji1"), + dict(id=5, user_id=4, email_address="jan1"), + ], + ) + @classmethod def setup_mappers(cls): User = cls.classes.User @@ -1324,6 +1343,52 @@ class UpdateDeleteTest(fixtures.MappedTest): ), ) + @testing.requires.update_from_returning + # can't use evaluate because it can't match the col->col in the WHERE + @testing.combinations("fetch", "auto", argnames="synchronize_session") + def test_update_from_multi_returning( + self, synchronize_session, addresses_data + ): + """test #12327""" + User = self.classes.User + Address = self.classes.Address + + sess = fixture_session() + + john, jack, jill, jane = sess.query(User).order_by(User.id).all() + + with self.sql_execution_asserter() as asserter: + stmt = ( + update(User) + .where(User.id == Address.user_id) + .filter(User.age > 29) + .values({"age": User.age - 10}) + .returning( + User.id, Address.email_address, func.char_length(User.name) + ) + .execution_options(synchronize_session=synchronize_session) + ) + + rows = sess.execute(stmt).all() + eq_(set(rows), {(2, "ja1", 4), (4, "jan1", 4)}) + + # these are simple values, these are now evaluated even with + # the "fetch" strategy, new in 1.4, so there is no expiry + eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) + + asserter.assert_( + CompiledSQL( + "UPDATE users SET age_int=(users.age_int - %(age_int_1)s) " + "FROM addresses " + "WHERE users.id = addresses.user_id AND " + "users.age_int > %(age_int_2)s " + "RETURNING users.id, addresses.email_address, " + "char_length(users.name) AS char_length_1", + [{"age_int_1": 10, "age_int_2": 29}], + dialect="postgresql", + ), + ) + @testing.requires.update_returning @testing.combinations("update", "delete", argnames="crud_type") def test_fetch_w_explicit_returning(self, crud_type): diff --git a/test/requirements.py b/test/requirements.py index a37f51e8d3..69b56423df 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -493,6 +493,13 @@ class DefaultRequirements(SuiteRequirements): "Backend does not support UPDATE..FROM", ) + @property + def update_from_returning(self): + """Target must support UPDATE..FROM syntax where RETURNING can + return columns from the non-primary FROM clause""" + + return self.update_returning + self.update_from + skip_if("sqlite") + @property def update_from_using_alias(self): """Target must support UPDATE..FROM syntax against an alias"""