]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support JOIN in UPDATE..FROM
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Jun 2018 19:59:35 +0000 (15:59 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 14 Jun 2018 01:38:20 +0000 (21:38 -0400)
The :class:`.Update` construct now accommodates a :class:`.Join` object
as supported by MySQL for UPDATE..FROM.  As the construct already
accepted an alias object for a similar purpose, the feature of UPDATE
against a non-table was already implied so this has been added.

Change-Id: I7b2bca627849384d5377abb0c94626463e4fad04
Fixes: #3645
doc/build/changelog/unreleased_12/3645.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
test/sql/test_update.py

diff --git a/doc/build/changelog/unreleased_12/3645.rst b/doc/build/changelog/unreleased_12/3645.rst
new file mode 100644 (file)
index 0000000..e750744
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, mysql
+    :tickets: 3645
+
+    The :class:`.Update` construct now accommodates a :class:`.Join` object
+    as supported by MySQL for UPDATE..FROM.  As the construct already
+    accepted an alias object for a similar purpose, the feature of UPDATE
+    against a non-table was already implied so this has been added.
index a442c65fd6d0fffe6eb878c167cdc55b7f502282..f6cdebb16f846a8e9f9463d36e4846bf76540262 100644 (file)
@@ -2178,6 +2178,16 @@ class SQLCompiler(Compiled):
              "selectable": update_stmt})
 
         extra_froms = update_stmt._extra_froms
+        is_multitable = bool(extra_froms)
+
+        if is_multitable:
+            # main table might be a JOIN
+            main_froms = set(selectable._from_objects(update_stmt.table))
+            render_extra_froms = [
+                f for f in extra_froms if f not in main_froms
+            ]
+        else:
+            render_extra_froms = []
 
         text = "UPDATE "
 
@@ -2186,8 +2196,7 @@ class SQLCompiler(Compiled):
                                             update_stmt._prefixes, **kw)
 
         table_text = self.update_tables_clause(update_stmt, update_stmt.table,
-                                               extra_froms, **kw)
-
+                                               render_extra_froms, **kw)
         crud_params = crud._setup_crud_params(
             self, update_stmt, crud.ISUPDATE, **kw)
 
@@ -2200,7 +2209,7 @@ class SQLCompiler(Compiled):
         text += table_text
 
         text += ' SET '
-        include_table = extra_froms and \
+        include_table = is_multitable and \
             self.render_table_with_column_in_update_from
         text += ', '.join(
             c[0]._compiler_dispatch(self,
@@ -2217,7 +2226,7 @@ class SQLCompiler(Compiled):
             extra_from_text = self.update_from_clause(
                 update_stmt,
                 update_stmt.table,
-                extra_froms,
+                render_extra_froms,
                 dialect_hints, **kw)
             if extra_from_text:
                 text += " " + extra_from_text
index 9ebaddffd1d96382317cabf7bd62cb11e13eaf16..cc5b4962bf00f48e8c652c7f714650a5df60e48d 100644 (file)
@@ -481,6 +481,23 @@ class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest,
             dialect='mysql'
         )
 
+    def test_update_from_join_mysql(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        j = users.join(addresses)
+        self.assert_compile(
+            update(j).
+            values(name='newname').
+            where(addresses.c.email_address == 'e1'),
+            ""
+            'UPDATE users '
+            'INNER JOIN addresses ON users.id = addresses.user_id '
+            'SET users.name=%s '
+            'WHERE '
+            'addresses.email_address = %s',
+            checkparams={'email_address_1': 'e1', 'name': 'newname'},
+            dialect=mysql.dialect())
+
     def test_render_table(self):
         users, addresses = self.tables.users, self.tables.addresses
 
@@ -669,6 +686,35 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
             (10, 'chuck')]
         self._assert_users(users, expected)
 
+    @testing.only_on('mysql', 'Multi table update')
+    def test_exec_join_multitable(self):
+        users, addresses = self.tables.users, self.tables.addresses
+
+        values = {
+            addresses.c.email_address: 'updated',
+            users.c.name: 'ed2'
+        }
+
+        testing.db.execute(
+            update(users.join(addresses)).
+            values(values).
+            where(users.c.name == 'ed'))
+
+        expected = [
+            (1, 7, 'x', 'jack@bean.com'),
+            (2, 8, 'x', 'updated'),
+            (3, 8, 'x', 'updated'),
+            (4, 8, 'x', 'updated'),
+            (5, 9, 'x', 'fred@fred.com')]
+        self._assert_addresses(addresses, expected)
+
+        expected = [
+            (7, 'jack'),
+            (8, 'ed2'),
+            (9, 'fred'),
+            (10, 'chuck')]
+        self._assert_users(users, expected)
+
     @testing.only_on('mysql', 'Multi table update')
     def test_exec_multitable_same_name(self):
         users, addresses = self.tables.users, self.tables.addresses