From: Mike Bayer Date: Tue, 22 Nov 2011 01:40:31 +0000 (-0500) Subject: sort of muscling this out, mysql a PITA X-Git-Tag: rel_0_7_4~60 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=0c3a53d433d0adddfd16831380f8aea5d1fad176;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git sort of muscling this out, mysql a PITA --- diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 1a30e15fd8..72bc1d32f3 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1315,25 +1315,22 @@ class MySQLCompiler(compiler.SQLCompiler): # No offset provided, so just use the limit return ' \n LIMIT %s' % (self.process(sql.literal(limit)),) - def visit_update(self, update_stmt): - self.stack.append({'from': set([update_stmt.table])}) - - self.isupdate = True - colparams = self._get_colparams(update_stmt) - - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \ - " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams]) - - if update_stmt._whereclause is not None: - text += " WHERE " + self.process(update_stmt._whereclause) - + def update_limit_clause(self, update_stmt): limit = update_stmt.kwargs.get('%s_limit' % self.dialect.name, None) if limit: - text += " LIMIT %s" % limit + return "LIMIT %s" % limit + else: + return None - self.stack.pop(-1) + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) + for t in [from_table] + list(extra_froms)) - return text + def update_from_clause(self, update_stmt, from_table, extra_froms, **kw): + return None + + def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms): + return bool(extra_froms) # ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. # Starting with MySQL 4.1.2, these indexes are created automatically. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8d7f2aab93..b775919122 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -985,15 +985,46 @@ class SQLCompiler(engine.Compiled): return text - def visit_update(self, update_stmt): + def update_limit_clause(self, update_stmt): + return None + + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + return self.preparer.format_table(from_table) + + def update_from_clause(self, update_stmt, from_table, extra_froms, **kw): + return "FROM " + ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in extra_froms) + + def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms): + return False + + def visit_update(self, update_stmt, **kw): self.stack.append({'from': set([update_stmt.table])}) self.isupdate = True - colparams = self._get_colparams(update_stmt) - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + if update_stmt._whereclause is not None: + extra_froms = set(update_stmt._whereclause._from_objects).\ + difference([update_stmt.table]) + else: + extra_froms = set() + + colparams = self._get_colparams(update_stmt, extra_froms) + + #for c in colparams: + # if hasattr(c[1], '_from_objects'): + # extra_froms.update(c[1]._from_objects) - text += ' SET ' + \ + text = "UPDATE " + self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) + + if self.should_render_table_with_col_in_update(update_stmt, update_stmt.table, extra_froms): + text += ' SET ' + \ + ', '.join( + self.visit_column(c[0]) + + '=' + c[1] + for c in colparams + ) + else: + text += ' SET ' + \ ', '.join( self.preparer.quote(c[0].name, c[0].quote) + '=' + c[1] @@ -1006,9 +1037,18 @@ class SQLCompiler(engine.Compiled): text += " " + self.returning_clause( update_stmt, update_stmt._returning) + if extra_froms: + extra_from_text = self.update_from_clause(update_stmt, update_stmt.table, extra_froms, **kw) + if extra_from_text: + text += " " + extra_from_text + if update_stmt._whereclause is not None: text += " WHERE " + self.process(update_stmt._whereclause) + limit_clause = self.update_limit_clause(update_stmt) + if limit_clause: + text += " " + limit_clause + if self.returning and not self.returning_precedes_values: text += " " + self.returning_clause( update_stmt, update_stmt._returning) @@ -1024,7 +1064,7 @@ class SQLCompiler(engine.Compiled): return bindparam._compiler_dispatch(self) - def _get_colparams(self, stmt): + def _get_colparams(self, stmt, extra_tables=None): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -1100,7 +1140,7 @@ class SQLCompiler(engine.Compiled): ( implicit_returning or not postfetch_lastrowid or - c is not stmt.table._autoincrement_column + c is not t._autoincrement_column ): if implicit_returning: @@ -1127,7 +1167,7 @@ class SQLCompiler(engine.Compiled): self.returning.append(c) else: if c.default is not None or \ - c is stmt.table._autoincrement_column and ( + c is t._autoincrement_column and ( self.dialect.supports_sequences or self.dialect.preexecute_autoincrement_sequences ): diff --git a/test/lib/requires.py b/test/lib/requires.py index e27d0193c6..9a117b6b13 100644 --- a/test/lib/requires.py +++ b/test/lib/requires.py @@ -124,6 +124,14 @@ def correlated_outer_joins(fn): no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"') ) +def update_from(fn): + """Target must support UPDATE..FROM syntax""" + return _chain_decorators_on( + fn, + only_on(('postgresql', 'mssql', 'mysql'), + "Backend does not support UPDATE..FROM") + ) + def savepoints(fn): """Target database must support savepoints.""" return _chain_decorators_on( diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index 4e086c8cda..9a53dd89cc 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -2663,6 +2663,18 @@ class CRUDTest(fixtures.TestBase, AssertsCompiledSQL): "(SELECT myothertable.othername FROM myothertable " "WHERE myothertable.otherid = mytable.myid)") + # test correlated FROM implicit in WHERE and SET clauses + u = table1.update().values(name=table2.c.othername)\ + .where(table2.c.otherid == table1.c.myid) + self.assert_compile(u, + "UPDATE mytable SET name=myothertable.othername " + "FROM myothertable WHERE myothertable.otherid = mytable.myid") + u = table1.update().values(name='foo')\ + .where(table2.c.otherid == table1.c.myid) + self.assert_compile(u, + "UPDATE mytable SET name=:name " + "FROM myothertable WHERE myothertable.otherid = mytable.myid") + def test_delete(self): self.assert_compile( delete(table1, table1.c.myid == 7),