]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix parameter mutation in orm_pre_session_exec()
authorShamil <ashm.tech@proton.me>
Thu, 20 Nov 2025 12:44:34 +0000 (07:44 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 20 Nov 2025 19:18:35 +0000 (14:18 -0500)
The :meth:`_events.SessionEvents.do_orm_execute` event now allows direct
mutation or replacement of the :attr:`.ORMExecuteState.parameters`
dictionary or list, which will take effect when the the statement is
executed.  Previously, changes to this collection were not accommodated by
the event hook.  Pull request courtesy Shamil.

Fixes: #12921
Closes: #12989
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12989
Pull-request-sha: 86b64e06178f8722f1c1700fd9fcca53ca572e78

Change-Id: I04874b6ca720eb2be1470067ced4afd79896e267

doc/build/changelog/unreleased_21/12921.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/session.py
test/orm/test_events.py

diff --git a/doc/build/changelog/unreleased_21/12921.rst b/doc/build/changelog/unreleased_21/12921.rst
new file mode 100644 (file)
index 0000000..1d66588
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 12921
+
+    The :meth:`_events.SessionEvents.do_orm_execute` event now allows direct
+    mutation or replacement of the :attr:`.ORMExecuteState.parameters`
+    dictionary or list, which will take effect when the the statement is
+    executed.  Previously, changes to this collection were not accommodated by
+    the event hook.  Pull request courtesy Shamil.
+
index 99b97ccf4ca3c1e977f34198d35f06dc6d75385f..d661a99f7845420459e71ac9b42345f7eea6e3b1 100644 (file)
@@ -809,6 +809,7 @@ class _BulkUDCompileState(_ORMDMLState):
             util.immutabledict(execution_options).union(
                 {"_sa_orm_update_options": update_options}
             ),
+            params,
         )
 
     @classmethod
@@ -1256,6 +1257,7 @@ class _BulkORMInsert(_ORMDMLState, InsertDMLState):
             util.immutabledict(execution_options).union(
                 {"_sa_orm_insert_options": insert_options}
             ),
+            params,
         )
 
     @classmethod
index 8f26eb2c5d4160371b536a89bbdc121a94afff1e..3ac216babfceead2e81591a94f64624b126a1504 100644 (file)
@@ -368,7 +368,7 @@ class _AutoflushOnlyORMCompileState(_AbstractORMCompileState):
         if not is_pre_event and load_options._autoflush:
             session._autoflush()
 
-        return statement, execution_options
+        return statement, execution_options, params
 
     @classmethod
     def orm_setup_cursor_result(
@@ -590,7 +590,7 @@ class _ORMCompileState(_AbstractORMCompileState):
         if not is_pre_event and load_options._autoflush:
             session._autoflush()
 
-        return statement, execution_options
+        return statement, execution_options, params
 
     @classmethod
     def orm_setup_cursor_result(
index 56b690aa4be6c7bcb5d1f879238231ba21ee5b3a..658adbe9230cac730419c213b5575d89ca4e891c 100644 (file)
@@ -293,8 +293,16 @@ class ORMExecuteState(util.MemoizedSlots):
     """
 
     parameters: Optional[_CoreAnyExecuteParams]
-    """Dictionary of parameters that was passed to
-    :meth:`_orm.Session.execute`."""
+    """Optional mapping or list of mappings of parameters that was passed to
+    :meth:`_orm.Session.execute`.
+
+    May be mutated or re-assigned in place, which will take effect as the
+    effective parameters passed to the method.
+
+    .. versionchanged:: 2.1 :attr:`.ORMExecuteState.parameters` may now be
+       mutated or replaced.
+
+    """
 
     execution_options: _ExecuteOptions
     """The complete dictionary of current execution options.
@@ -2212,6 +2220,7 @@ class Session(_SessionClassMethods, EventTarget):
                 (
                     statement,
                     combined_execution_options,
+                    params,
                 ) = compile_state_cls.orm_pre_session_exec(
                     self,
                     statement,
@@ -2243,6 +2252,7 @@ class Session(_SessionClassMethods, EventTarget):
 
             statement = orm_exec_state.statement
             combined_execution_options = orm_exec_state.local_execution_options
+            params = orm_exec_state.parameters
 
         if compile_state_cls is not None:
             # now run orm_pre_session_exec() "for real".   if there were
@@ -2253,6 +2263,7 @@ class Session(_SessionClassMethods, EventTarget):
             (
                 statement,
                 combined_execution_options,
+                params,
             ) = compile_state_cls.orm_pre_session_exec(
                 self,
                 statement,
index 536aa654f7e8118b866ae091a2482fcd0da8441d..437f7e8c58a44adc38793d161274e262e9557c67 100644 (file)
@@ -898,6 +898,179 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
         else:
             eq_(m1.mock_calls, [])
 
+    @testing.combinations(
+        (
+            lambda User: select(User).where(User.id == bindparam("id")),
+            {"id": 18},
+            {"id": 7},
+            "SELECT users.id, users.name FROM users WHERE users.id = :id",
+        ),
+        (
+            lambda User: select(User.__table__).where(
+                User.__table__.c.id == bindparam("id")
+            ),
+            {"id": 18},
+            {"id": 7},
+            "SELECT users.id, users.name FROM users WHERE users.id = :id",
+        ),
+        (
+            lambda User: update(User).where(User.id == 7),
+            {"name": "original_name"},
+            {"name": "mutated_name"},
+            "UPDATE users SET name=:name WHERE users.id = :id_1",
+        ),
+        (
+            lambda User: update(User.__table__).where(
+                User.__table__.c.id == 7
+            ),
+            {"name": "original_name"},
+            {"name": "mutated_name"},
+            "UPDATE users SET name=:name WHERE users.id = :id_1",
+        ),
+        (
+            lambda User: delete(User).where(User.id == bindparam("id_param")),
+            {"id_param": 18},
+            {"id_param": 10},  # row 10 does not have a related item
+            "DELETE FROM users WHERE users.id = :id_param",
+        ),
+        (
+            lambda User: insert(User),
+            {"id": 99, "name": "original_name"},
+            {"name": "mutated_name"},
+            "INSERT INTO users (id, name) VALUES (:id, :name)",
+        ),
+        (
+            lambda User: insert(User),
+            [
+                {"id": 100, "name": "name1"},
+                {"id": 101, "name": "name2"},
+                {"id": 102, "name": "name3"},
+            ],
+            [
+                {"id": 100, "name": "mutated_name1"},
+                {"id": 101, "name": "mutated_name2"},
+                {"id": 102, "name": "mutated_name3"},
+            ],
+            "INSERT INTO users (id, name) VALUES (:id, :name)",
+        ),
+        (
+            lambda User: insert(User.__table__),
+            [
+                {"id": 100, "name": "name1"},
+                {"id": 101, "name": "name2"},
+                {"id": 102, "name": "name3"},
+            ],
+            [
+                {"id": 100, "name": "mutated_name1"},
+                {"id": 101, "name": "mutated_name2"},
+                {"id": 102, "name": "mutated_name3"},
+            ],
+            "INSERT INTO users (id, name) VALUES (:id, :name)",
+        ),
+        argnames="stmt_callable,params,new_params,compiled_sql",
+    )
+    @testing.variation("param_op", ["mutate", "replace"])
+    def test_mutate_parameters(
+        self, stmt_callable, params, new_params, compiled_sql, param_op
+    ):
+        """test for #12921"""
+
+        User = self.classes.User
+
+        sess = Session(testing.db)
+        if param_op.mutate:
+
+            @event.listens_for(sess, "do_orm_execute")
+            def mutate_params(ctx):
+                # ensure change in place works
+                if isinstance(new_params, dict):
+                    ctx.parameters.update(new_params)
+                elif isinstance(new_params, list):
+                    ctx.parameters[:] = new_params
+
+        elif param_op.replace:
+
+            @event.listens_for(sess, "do_orm_execute")
+            def replace_params(ctx):
+                # ensure replace works
+                if isinstance(params, dict):
+                    replaced_params = dict(params)
+                    replaced_params.update(new_params)
+                    ctx.parameters = replaced_params
+                else:
+                    ctx.parameters = new_params
+
+        stmt = testing.resolve_lambda(stmt_callable, User=User)
+
+        # since we are doing mutate in place changes,
+        # uniquify the params dict so the combinations fixtures are
+        # not polluted
+        if isinstance(params, dict):
+            our_local_params = dict(params)
+            assert isinstance(new_params, dict)
+            expected_params = dict(params)
+            expected_params.update(new_params)
+        else:
+            assert isinstance(params, list)
+            our_local_params = [dict(p) for p in params]
+            expected_params = new_params
+
+        with self.sql_execution_asserter() as asserter:
+            sess.execute(
+                stmt,
+                our_local_params,
+            )
+
+        asserter.assert_(
+            CompiledSQL(
+                compiled_sql,
+                expected_params,
+            )
+        )
+
+    def test_mutate_parameters_selectinload(self, decl_base):
+        """test #12921 where we modify params for a relationship load"""
+
+        class A(decl_base):
+            __tablename__ = "a"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            bs: Mapped[list["B"]] = relationship(
+                primaryjoin=lambda: (A.id == B.a_id)
+                & (B.status == bindparam("b_status"))
+            )
+
+        class B(decl_base):
+            __tablename__ = "b"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
+            status: Mapped[str]
+
+        decl_base.metadata.create_all(testing.db)
+
+        sess = Session(testing.db)
+
+        sess.add_all([A(id=1, bs=[B(status="x"), B(status="y")])])
+        sess.commit()
+
+        with Session(testing.db) as sess:
+
+            @event.listens_for(sess, "do_orm_execute")
+            def do_orm_execute(ctx):
+                if ctx.is_relationship_load:
+                    ctx.parameters["b_status"] = SELECT_STATUS
+
+            SELECT_STATUS = "x"
+            a1 = sess.scalars(select(A).options(selectinload(A.bs))).one()
+            eq_([b.status for b in a1.bs], ["x"])
+
+            SELECT_STATUS = "y"
+            a1 = sess.scalars(
+                select(A)
+                .execution_options(populate_existing=True)
+                .options(selectinload(A.bs))
+            ).one()
+            eq_([b.status for b in a1.bs], ["y"])
+
 
 class MapperEventsTest(RemoveORMEventsGlobally, _fixtures.FixtureTest):
     run_inserts = None