]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
only use _DMLReturningColFilter for "bulk insert", not other DML
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 9 Feb 2025 23:09:21 +0000 (18:09 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 10 Feb 2025 01:40:30 +0000 (20:40 -0500)
Fixed bug in ORM enabled UPDATE (and theoretically DELETE) where using a
multi-table DML statement would not allow ORM mapped columns from mappers
other than the primary UPDATE mapper to be named in the RETURNING clause;
they would be omitted instead and cause a column not found exception.

Fixes: #12328
Change-Id: I2223ee506eec447823a3a545eecad1a7a03364a9

doc/build/changelog/unreleased_20/12328.rst [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/query.py
test/orm/dml/test_update_delete_where.py
test/requirements.py

diff --git a/doc/build/changelog/unreleased_20/12328.rst b/doc/build/changelog/unreleased_20/12328.rst
new file mode 100644 (file)
index 0000000..9d9b709
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 12328
+
+    Fixed bug in ORM enabled UPDATE (and theoretically DELETE) where using a
+    multi-table DML statement would not allow ORM mapped columns from mappers
+    other than the primary UPDATE mapper to be named in the RETURNING clause;
+    they would be omitted instead and cause a column not found exception.
index d86f1d0ce578bc672728ee3b1ab6b106e7172598..fa57bcfae834318400f6ae4a76ceb048007d5818 100644 (file)
@@ -155,10 +155,12 @@ class QueryContext:
         statement: Union[
             Select[Unpack[TupleAny]],
             FromStatement[Unpack[TupleAny]],
+            UpdateBase,
         ],
         user_passed_query: Union[
             Select[Unpack[TupleAny]],
             FromStatement[Unpack[TupleAny]],
+            UpdateBase,
         ],
         params: _CoreSingleExecuteParams,
         session: Session,
@@ -420,7 +422,9 @@ class _ORMCompileState(_AbstractORMCompileState):
     attributes: Dict[Any, Any]
     global_attributes: Dict[Any, Any]
 
-    statement: Union[Select[Unpack[TupleAny]], FromStatement[Unpack[TupleAny]]]
+    statement: Union[
+        Select[Unpack[TupleAny]], FromStatement[Unpack[TupleAny]], UpdateBase
+    ]
     select_statement: Union[
         Select[Unpack[TupleAny]], FromStatement[Unpack[TupleAny]]
     ]
@@ -663,8 +667,14 @@ class _ORMCompileState(_AbstractORMCompileState):
         )
 
 
-class _DMLReturningColFilter:
-    """an adapter used for the DML RETURNING case.
+class _DMLBulkInsertReturningColFilter:
+    """an adapter used for the DML RETURNING case specifically
+    for ORM bulk insert (or any hypothetical DML that is splitting out a class
+    hierarchy among multiple DML statements....ORM bulk insert is the only
+    example right now)
+
+    its main job is to limit the columns in a RETURNING to only a specific
+    mapped table in a hierarchy.
 
     Has a subset of the interface used by
     :class:`.ORMAdapter` and is used for :class:`._QueryEntity`
@@ -860,14 +870,20 @@ class _ORMFromStatementCompileState(_ORMCompileState):
         return None
 
     def setup_dml_returning_compile_state(self, dml_mapper):
-        """used by BulkORMInsert (and Update / Delete?) to set up a handler
+        """used by BulkORMInsert, Update, Delete to set up a handler
         for RETURNING to return ORM objects and expressions
 
         """
         target_mapper = self.statement._propagate_attrs.get(
             "plugin_subject", None
         )
-        adapter = _DMLReturningColFilter(target_mapper, dml_mapper)
+
+        if self.statement.is_insert:
+            adapter = _DMLBulkInsertReturningColFilter(
+                target_mapper, dml_mapper
+            )
+        else:
+            adapter = None
 
         if self.compile_options._is_star and (len(self._entities) != 1):
             raise sa_exc.CompileError(
@@ -2544,7 +2560,7 @@ class _QueryEntity:
     def setup_dml_returning_compile_state(
         self,
         compile_state: _ORMCompileState,
-        adapter: _DMLReturningColFilter,
+        adapter: Optional[_DMLBulkInsertReturningColFilter],
     ) -> None:
         raise NotImplementedError()
 
@@ -2746,7 +2762,7 @@ class _MapperEntity(_QueryEntity):
     def setup_dml_returning_compile_state(
         self,
         compile_state: _ORMCompileState,
-        adapter: _DMLReturningColFilter,
+        adapter: Optional[_DMLBulkInsertReturningColFilter],
     ) -> None:
         loading._setup_entity_query(
             compile_state,
@@ -2905,7 +2921,7 @@ class _BundleEntity(_QueryEntity):
     def setup_dml_returning_compile_state(
         self,
         compile_state: _ORMCompileState,
-        adapter: _DMLReturningColFilter,
+        adapter: Optional[_DMLBulkInsertReturningColFilter],
     ) -> None:
         return self.setup_compile_state(compile_state)
 
@@ -3095,7 +3111,7 @@ class _RawColumnEntity(_ColumnEntity):
     def setup_dml_returning_compile_state(
         self,
         compile_state: _ORMCompileState,
-        adapter: _DMLReturningColFilter,
+        adapter: Optional[_DMLBulkInsertReturningColFilter],
     ) -> None:
         return self.setup_compile_state(compile_state)
 
@@ -3212,10 +3228,13 @@ class _ORMColumnEntity(_ColumnEntity):
     def setup_dml_returning_compile_state(
         self,
         compile_state: _ORMCompileState,
-        adapter: _DMLReturningColFilter,
+        adapter: Optional[_DMLBulkInsertReturningColFilter],
     ) -> None:
-        self._fetch_column = self.column
-        column = adapter(self.column, False)
+
+        self._fetch_column = column = self.column
+        if adapter:
+            column = adapter(column, False)
+
         if column is not None:
             compile_state.dedupe_columns.add(column)
             compile_state.primary_columns.append(column)
index 02a98fefe7c9cfb5df12c95ead2562494bc22996..ac6746adba941e3c68e95be33a6b566d34baaa4f 100644 (file)
@@ -137,6 +137,7 @@ if TYPE_CHECKING:
     from ..sql._typing import _TypedColumnClauseArgument as _TCCA
     from ..sql.base import CacheableOptions
     from ..sql.base import ExecutableOption
+    from ..sql.dml import UpdateBase
     from ..sql.elements import ColumnElement
     from ..sql.elements import Label
     from ..sql.selectable import _ForUpdateOfArgument
@@ -503,7 +504,7 @@ class Query(
         return cast("Select[_T]", self.statement)
 
     @property
-    def statement(self) -> Union[Select[_T], FromStatement[_T]]:
+    def statement(self) -> Union[Select[_T], FromStatement[_T], UpdateBase]:
         """The full SELECT statement represented by this Query.
 
         The statement by default will not have disambiguating labels
@@ -531,6 +532,8 @@ class Query(
         # from there, it starts to look much like Query itself won't be
         # passed into the execute process and won't generate its own cache
         # key; this will all occur in terms of the ORM-enabled Select.
+        stmt: Union[Select[_T], FromStatement[_T], UpdateBase]
+
         if not self._compile_options._set_base_alias:
             # if we don't have legacy top level aliasing features in use
             # then convert to a future select() directly
@@ -802,7 +805,7 @@ class Query(
         )
 
     @property
-    def selectable(self) -> Union[Select[_T], FromStatement[_T]]:
+    def selectable(self) -> Union[Select[_T], FromStatement[_T], UpdateBase]:
         """Return the :class:`_expression.Select` object emitted by this
         :class:`_query.Query`.
 
@@ -813,7 +816,9 @@ class Query(
         """
         return self.__clause_element__()
 
-    def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]:
+    def __clause_element__(
+        self,
+    ) -> Union[Select[_T], FromStatement[_T], UpdateBase]:
         return (
             self._with_compile_options(
                 _enable_eagerloads=False, _render_for_subquery=True
index 7d06a8618cd44440c9d45d4905ac4bfb1c66cca4..387ce161b867c5080e0734fd354716a0462fbd03 100644 (file)
@@ -78,6 +78,7 @@ class UpdateDeleteTest(fixtures.MappedTest):
             metadata,
             Column("id", Integer, primary_key=True),
             Column("user_id", ForeignKey("users.id")),
+            Column("email_address", String(50)),
         )
 
         m = MetaData()
@@ -118,6 +119,24 @@ class UpdateDeleteTest(fixtures.MappedTest):
             ],
         )
 
+    @testing.fixture
+    def addresses_data(
+        self,
+    ):
+        addresses = self.tables.addresses
+
+        with testing.db.begin() as connection:
+            connection.execute(
+                addresses.insert(),
+                [
+                    dict(id=1, user_id=1, email_address="jo1"),
+                    dict(id=2, user_id=1, email_address="jo2"),
+                    dict(id=3, user_id=2, email_address="ja1"),
+                    dict(id=4, user_id=3, email_address="ji1"),
+                    dict(id=5, user_id=4, email_address="jan1"),
+                ],
+            )
+
     @classmethod
     def setup_mappers(cls):
         User = cls.classes.User
@@ -1324,6 +1343,52 @@ class UpdateDeleteTest(fixtures.MappedTest):
             ),
         )
 
+    @testing.requires.update_from_returning
+    # can't use evaluate because it can't match the col->col in the WHERE
+    @testing.combinations("fetch", "auto", argnames="synchronize_session")
+    def test_update_from_multi_returning(
+        self, synchronize_session, addresses_data
+    ):
+        """test #12327"""
+        User = self.classes.User
+        Address = self.classes.Address
+
+        sess = fixture_session()
+
+        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
+
+        with self.sql_execution_asserter() as asserter:
+            stmt = (
+                update(User)
+                .where(User.id == Address.user_id)
+                .filter(User.age > 29)
+                .values({"age": User.age - 10})
+                .returning(
+                    User.id, Address.email_address, func.char_length(User.name)
+                )
+                .execution_options(synchronize_session=synchronize_session)
+            )
+
+            rows = sess.execute(stmt).all()
+            eq_(set(rows), {(2, "ja1", 4), (4, "jan1", 4)})
+
+            # these are simple values, these are now evaluated even with
+            # the "fetch" strategy, new in 1.4, so there is no expiry
+            eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
+
+        asserter.assert_(
+            CompiledSQL(
+                "UPDATE users SET age_int=(users.age_int - %(age_int_1)s) "
+                "FROM addresses "
+                "WHERE users.id = addresses.user_id AND "
+                "users.age_int > %(age_int_2)s "
+                "RETURNING users.id, addresses.email_address, "
+                "char_length(users.name) AS char_length_1",
+                [{"age_int_1": 10, "age_int_2": 29}],
+                dialect="postgresql",
+            ),
+        )
+
     @testing.requires.update_returning
     @testing.combinations("update", "delete", argnames="crud_type")
     def test_fetch_w_explicit_returning(self, crud_type):
index a37f51e8d3fb830d1d2bb445f62b9e12c912d05b..69b56423df6f92c82457dff1fd42e0a655981589 100644 (file)
@@ -493,6 +493,13 @@ class DefaultRequirements(SuiteRequirements):
             "Backend does not support UPDATE..FROM",
         )
 
+    @property
+    def update_from_returning(self):
+        """Target must support UPDATE..FROM syntax where RETURNING can
+        return columns from the non-primary FROM clause"""
+
+        return self.update_returning + self.update_from + skip_if("sqlite")
+
     @property
     def update_from_using_alias(self):
         """Target must support UPDATE..FROM syntax against an alias"""