]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
dont mutate bind_arguments incoming dictionary
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 7 Oct 2022 15:25:08 +0000 (11:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 7 Oct 2022 15:32:24 +0000 (11:32 -0400)
The :paramref:`_orm.Session.execute.bind_arguments` dictionary is no longer
mutated when passed to :meth:`_orm.Session.execute` and similar; instead,
it's copied to an internal dictionary for state changes. Among other
things, this fixes and issue where the "clause" passed to the
:meth:`_orm.Session.get_bind` method would be incorrectly referring to the
:class:`_sql.Select` construct used for the "fetch" synchronization
strategy, when the actual query being emitted was a :class:`_dml.Delete` or
:class:`_dml.Update`. This would interfere with recipes for "routing
sessions".

Fixes: #8614
Change-Id: I8d237449485c9bbf41db2b29a34b6136aa43b7bc

doc/build/changelog/unreleased_14/8614.rst [new file with mode: 0644]
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/session.py
test/orm/dml/test_update_delete_where.py
test/orm/test_bind.py

diff --git a/doc/build/changelog/unreleased_14/8614.rst b/doc/build/changelog/unreleased_14/8614.rst
new file mode 100644 (file)
index 0000000..b975dbc
--- /dev/null
@@ -0,0 +1,13 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 8614
+
+    The :paramref:`_orm.Session.execute.bind_arguments` dictionary is no longer
+    mutated when passed to :meth:`_orm.Session.execute` and similar; instead,
+    it's copied to an internal dictionary for state changes. Among other
+    things, this fixes and issue where the "clause" passed to the
+    :meth:`_orm.Session.get_bind` method would be incorrectly referring to the
+    :class:`_sql.Select` construct used for the "fetch" synchronization
+    strategy, when the actual query being emitted was a :class:`_dml.Delete` or
+    :class:`_dml.Update`. This would interfere with recipes for "routing
+    sessions".
index b407fcdca1feb05081153edf77f03c0c6bc903c4..af5bf6b6a193a1a7898e38ec77bb6eb9b6fafa0e 100644 (file)
@@ -599,7 +599,6 @@ class BulkUDCompileState(ORMDMLState):
             execution_options,
             statement._execution_options,
         )
-
         bind_arguments["clause"] = statement
         try:
             plugin_subject = statement._propagate_attrs["plugin_subject"]
index 9577b4d260d092aec787c7613347463091e198fb..324ab7b257406941603c683cd7fd9b327168f4e1 100644 (file)
@@ -1823,6 +1823,8 @@ class Session(_SessionClassMethods, EventTarget):
 
         if not bind_arguments:
             bind_arguments = {}
+        else:
+            bind_arguments = dict(bind_arguments)
 
         if (
             statement._propagate_attrs.get("compile_state_plugin", None)
index 3250cb3f92464d0dddc6f0aff23602cf692b505e..c8e56e3c117fa18d0cb6fcd6cd4fa18064e01879 100644 (file)
@@ -26,6 +26,9 @@ from sqlalchemy.orm import Session
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm import synonym
 from sqlalchemy.orm import with_loader_criteria
+from sqlalchemy.sql.dml import Delete
+from sqlalchemy.sql.dml import Update
+from sqlalchemy.sql.selectable import Select
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
@@ -1886,6 +1889,50 @@ class UpdateDeleteTest(fixtures.MappedTest):
         ):
             session.execute(stmt)
 
+    @testing.combinations(("update",), ("delete",), argnames="stmt_type")
+    @testing.combinations(
+        ("evaluate",), ("fetch",), (None,), argnames="sync_type"
+    )
+    def test_routing_session(self, stmt_type, sync_type, connection):
+        User = self.classes.User
+
+        if stmt_type == "update":
+            stmt = update(User).values(age=123)
+            expected = [Update]
+        elif stmt_type == "delete":
+            stmt = delete(User)
+            expected = [Delete]
+        else:
+            assert False
+
+        received = []
+
+        class RoutingSession(Session):
+            def get_bind(self, **kw):
+                received.append(type(kw["clause"]))
+                return super(RoutingSession, self).get_bind(**kw)
+
+        stmt = stmt.execution_options(synchronize_session=sync_type)
+
+        if sync_type == "fetch":
+            expected.insert(0, Select)
+
+            if (
+                stmt_type == "update"
+                and not connection.dialect.update_returning
+            ):
+                expected.insert(0, Select)
+            elif (
+                stmt_type == "delete"
+                and not connection.dialect.delete_returning
+            ):
+                expected.insert(0, Select)
+
+        with RoutingSession(bind=connection) as sess:
+            sess.execute(stmt)
+
+        eq_(received, expected)
+
 
 class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest):
     @classmethod
index 2f392cf6e5ffeeb148905ccdf80103a0492c1f20..409c6244f00608cf4369b716bf0f003b72ec18b8 100644 (file)
@@ -291,6 +291,28 @@ class BindIntegrationTest(_fixtures.FixtureTest):
 
         sess.close()
 
+    @testing.combinations(True, False)
+    def test_dont_mutate_binds(self, empty_dict):
+        users, User = (
+            self.tables.users,
+            self.classes.User,
+        )
+
+        mp = self.mapper_registry.map_imperatively(User, users)
+
+        sess = fixture_session()
+
+        if empty_dict:
+            bind_arguments = {}
+        else:
+            bind_arguments = {"mapper": mp}
+        sess.execute(select(1), bind_arguments=bind_arguments)
+
+        if empty_dict:
+            eq_(bind_arguments, {})
+        else:
+            eq_(bind_arguments, {"mapper": mp})
+
     @testing.combinations(
         (
             lambda session, Address: session.query(Address).statement,