From ef79d1ae3b404780d17e8615426eeb39be1ac670 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 21 Nov 2011 22:00:50 -0500 Subject: [PATCH] passes for all three, includes multi col system with mysql --- lib/sqlalchemy/dialects/mysql/base.py | 4 +-- lib/sqlalchemy/sql/compiler.py | 38 ++++++++++++++++++++++----- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 72bc1d32f3..2433d24522 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1329,8 +1329,8 @@ class MySQLCompiler(compiler.SQLCompiler): def update_from_clause(self, update_stmt, from_table, extra_froms, **kw): return None - def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms): - return bool(extra_froms) + render_table_with_column_in_update = True + # ug. "InnoDB needs indexes on foreign keys and referenced keys [...]. # Starting with MySQL 4.1.2, these indexes are created automatically. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index b775919122..92c0c7b38d 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -994,8 +994,7 @@ class SQLCompiler(engine.Compiled): def update_from_clause(self, update_stmt, from_table, extra_froms, **kw): return "FROM " + ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in extra_froms) - def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms): - return False + render_table_with_column_in_update = False def visit_update(self, update_stmt, **kw): self.stack.append({'from': set([update_stmt.table])}) @@ -1014,9 +1013,12 @@ class SQLCompiler(engine.Compiled): # if hasattr(c[1], '_from_objects'): # extra_froms.update(c[1]._from_objects) - text = "UPDATE " + self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) + text = "UPDATE " + self.update_tables_clause( + update_stmt, + update_stmt.table, + extra_froms, **kw) - if self.should_render_table_with_col_in_update(update_stmt, update_stmt.table, extra_froms): + if extra_froms and self.render_table_with_column_in_update: text += ' SET ' + \ ', '.join( self.visit_column(c[0]) + @@ -1038,7 +1040,10 @@ class SQLCompiler(engine.Compiled): update_stmt, update_stmt._returning) if extra_froms: - extra_from_text = self.update_from_clause(update_stmt, update_stmt.table, extra_froms, **kw) + extra_from_text = self.update_from_clause( + update_stmt, + update_stmt.table, + extra_froms, **kw) if extra_from_text: text += " " + extra_from_text @@ -1104,6 +1109,7 @@ class SQLCompiler(engine.Compiled): for k, v in stmt.parameters.iteritems(): parameters.setdefault(sql._column_as_key(k), v) + # create a list of column assignment clauses as tuples values = [] @@ -1117,11 +1123,31 @@ class SQLCompiler(engine.Compiled): postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid + check_columns = {} + if extra_tables and stmt.parameters: + for t in extra_tables: + for c in t.c: + if c in stmt.parameters: + check_columns[c.key] = c + + for c in check_columns.values(): + value = stmt.parameters[c] + if sql._is_literal(value): + value = self._create_crud_bind_param( + c, value, required=value is required) + elif c.primary_key and implicit_returning: + self.returning.append(c) + value = self.process(value.self_group()) + else: + self.postfetch.append(c) + value = self.process(value.self_group()) + values.append((c, value)) + # iterating through columns at the top to maintain ordering. # otherwise we might iterate through individual sets of # "defaults", "primary key cols", etc. for c in stmt.table.columns: - if c.key in parameters: + if c.key in parameters and c.key not in check_columns: value = parameters[c.key] if sql._is_literal(value): value = self._create_crud_bind_param( -- 2.47.2