]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure ORMInsert sets up bind state
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Apr 2022 16:01:16 +0000 (12:01 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Apr 2022 16:01:16 +0000 (12:01 -0400)
Fixed regression where the change in #7861, released in version 1.4.33,
that brought the :class:`.Insert` construct to be partially recognized as
an ORM-enabled statement did not properly transfer the correct mapper /
mapped table state to the :class:`.Session`, causing the
:meth:`.Session.get_bind` method to fail for a :class:`.Session` that was
bound to engines and/or connections using the :paramref:`.Session.binds`
parameter.

Fixes: #7936
Change-Id: If19edef8e2dd68335465429eb3d2f0bfdade4a4c

doc/build/changelog/unreleased_14/7936.rst [new file with mode: 0644]
lib/sqlalchemy/orm/persistence.py
test/orm/test_bind.py
test/orm/test_update_delete.py

diff --git a/doc/build/changelog/unreleased_14/7936.rst b/doc/build/changelog/unreleased_14/7936.rst
new file mode 100644 (file)
index 0000000..bcad142
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, orm, regression
+    :tickets: 7936
+
+    Fixed regression where the change in #7861, released in version 1.4.33,
+    that brought the :class:`.Insert` construct to be partially recognized as
+    an ORM-enabled statement did not properly transfer the correct mapper /
+    mapped table state to the :class:`.Session`, causing the
+    :meth:`.Session.get_bind` method to fail for a :class:`.Session` that was
+    bound to engines and/or connections using the :paramref:`.Session.binds`
+    parameter.
index 7298d3630ef4f75675f123a76723717968c14849..3229453e7aa88db6e9c516c0e3dc0f70d005e604 100644 (file)
@@ -2209,6 +2209,14 @@ class ORMInsert(ORMDMLState, InsertDMLState):
         bind_arguments,
         is_reentrant_invoke,
     ):
+        bind_arguments["clause"] = statement
+        try:
+            plugin_subject = statement._propagate_attrs["plugin_subject"]
+        except KeyError:
+            assert False, "statement had 'orm' plugin but no plugin_subject"
+        else:
+            bind_arguments["mapper"] = plugin_subject.mapper
+
         return (
             statement,
             util.immutabledict(execution_options),
index bf39ee44c5f6de5aef96078e3dfb84f872b37973..a6480365d0b4eecdec1d6c4f155fc8717e997eb0 100644 (file)
@@ -324,6 +324,21 @@ class BindIntegrationTest(_fixtures.FixtureTest):
             lambda User: {"clause": mock.ANY, "mapper": inspect(User)},
             "e1",
         ),
+        (
+            lambda User: update(User)
+            .values(name="not ed")
+            .where(User.name == "ed"),
+            lambda User: {"clause": mock.ANY, "mapper": inspect(User)},
+            "e1",
+        ),
+        (
+            lambda User: insert(User).values(name="not ed"),
+            lambda User: {
+                "clause": mock.ANY,
+                "mapper": inspect(User),
+            },
+            "e1",
+        ),
     )
     def test_bind_through_execute(
         self, statement, expected_get_bind_args, expected_engine_name
index b2743024ab4f62e880202dee30419441b9757939..427e49e5e695c658ccc7f9684a3ee4bdbd75af1b 100644 (file)
@@ -96,6 +96,39 @@ class UpdateDeleteTest(fixtures.MappedTest):
         )
         cls.mapper_registry.map_imperatively(Address, addresses)
 
+    @testing.combinations("table", "mapper", "both", argnames="bind_type")
+    @testing.combinations(
+        "update", "insert", "delete", argnames="statement_type"
+    )
+    def test_get_bind_scenarios(self, connection, bind_type, statement_type):
+        """test for #7936"""
+
+        User = self.classes.User
+
+        if statement_type == "insert":
+            stmt = insert(User).values(
+                {User.id: 5, User.age: 25, User.name: "spongebob"}
+            )
+        elif statement_type == "update":
+            stmt = (
+                update(User)
+                .where(User.id == 2)
+                .values({User.name: "spongebob"})
+            )
+        elif statement_type == "delete":
+            stmt = delete(User)
+
+        binds = {}
+        if bind_type == "both":
+            binds = {User: connection, User.__table__: connection}
+        elif bind_type == "mapper":
+            binds = {User: connection}
+        elif bind_type == "table":
+            binds = {User.__table__: connection}
+
+        with Session(binds=binds) as sess:
+            sess.execute(stmt)
+
     def test_illegal_eval(self):
         User = self.classes.User
         s = fixture_session()