From: Mike Bayer Date: Wed, 13 Jun 2018 19:59:35 +0000 (-0400) Subject: Support JOIN in UPDATE..FROM X-Git-Tag: rel_1_3_0b1~160^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=58540ae93db30fb12f331587c32bb2d76db79ab3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support JOIN in UPDATE..FROM The :class:`.Update` construct now accommodates a :class:`.Join` object as supported by MySQL for UPDATE..FROM. As the construct already accepted an alias object for a similar purpose, the feature of UPDATE against a non-table was already implied so this has been added. Change-Id: I7b2bca627849384d5377abb0c94626463e4fad04 Fixes: #3645 --- diff --git a/doc/build/changelog/unreleased_12/3645.rst b/doc/build/changelog/unreleased_12/3645.rst new file mode 100644 index 0000000000..e750744b15 --- /dev/null +++ b/doc/build/changelog/unreleased_12/3645.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, mysql + :tickets: 3645 + + The :class:`.Update` construct now accommodates a :class:`.Join` object + as supported by MySQL for UPDATE..FROM. As the construct already + accepted an alias object for a similar purpose, the feature of UPDATE + against a non-table was already implied so this has been added. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a442c65fd6..f6cdebb16f 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2178,6 +2178,16 @@ class SQLCompiler(Compiled): "selectable": update_stmt}) extra_froms = update_stmt._extra_froms + is_multitable = bool(extra_froms) + + if is_multitable: + # main table might be a JOIN + main_froms = set(selectable._from_objects(update_stmt.table)) + render_extra_froms = [ + f for f in extra_froms if f not in main_froms + ] + else: + render_extra_froms = [] text = "UPDATE " @@ -2186,8 +2196,7 @@ class SQLCompiler(Compiled): update_stmt._prefixes, **kw) table_text = self.update_tables_clause(update_stmt, update_stmt.table, - extra_froms, **kw) - + render_extra_froms, **kw) crud_params = crud._setup_crud_params( self, update_stmt, crud.ISUPDATE, **kw) @@ -2200,7 +2209,7 @@ class SQLCompiler(Compiled): text += table_text text += ' SET ' - include_table = extra_froms and \ + include_table = is_multitable and \ self.render_table_with_column_in_update_from text += ', '.join( c[0]._compiler_dispatch(self, @@ -2217,7 +2226,7 @@ class SQLCompiler(Compiled): extra_from_text = self.update_from_clause( update_stmt, update_stmt.table, - extra_froms, + render_extra_froms, dialect_hints, **kw) if extra_from_text: text += " " + extra_from_text diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 9ebaddffd1..cc5b4962bf 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -481,6 +481,23 @@ class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, dialect='mysql' ) + def test_update_from_join_mysql(self): + users, addresses = self.tables.users, self.tables.addresses + + j = users.join(addresses) + self.assert_compile( + update(j). + values(name='newname'). + where(addresses.c.email_address == 'e1'), + "" + 'UPDATE users ' + 'INNER JOIN addresses ON users.id = addresses.user_id ' + 'SET users.name=%s ' + 'WHERE ' + 'addresses.email_address = %s', + checkparams={'email_address_1': 'e1', 'name': 'newname'}, + dialect=mysql.dialect()) + def test_render_table(self): users, addresses = self.tables.users, self.tables.addresses @@ -669,6 +686,35 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): (10, 'chuck')] self._assert_users(users, expected) + @testing.only_on('mysql', 'Multi table update') + def test_exec_join_multitable(self): + users, addresses = self.tables.users, self.tables.addresses + + values = { + addresses.c.email_address: 'updated', + users.c.name: 'ed2' + } + + testing.db.execute( + update(users.join(addresses)). + values(values). + where(users.c.name == 'ed')) + + expected = [ + (1, 7, 'x', 'jack@bean.com'), + (2, 8, 'x', 'updated'), + (3, 8, 'x', 'updated'), + (4, 8, 'x', 'updated'), + (5, 9, 'x', 'fred@fred.com')] + self._assert_addresses(addresses, expected) + + expected = [ + (7, 'jack'), + (8, 'ed2'), + (9, 'fred'), + (10, 'chuck')] + self._assert_users(users, expected) + @testing.only_on('mysql', 'Multi table update') def test_exec_multitable_same_name(self): users, addresses = self.tables.users, self.tables.addresses