]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
passes for all three, includes multi col system with mysql
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Nov 2011 03:00:50 +0000 (22:00 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Nov 2011 03:00:50 +0000 (22:00 -0500)
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/sql/compiler.py

index 72bc1d32f36c3fe02867c563b05b9ecf6ab9d345..2433d24522e1a1a458799d1cedc25ea831235af3 100644 (file)
@@ -1329,8 +1329,8 @@ class MySQLCompiler(compiler.SQLCompiler):
     def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
         return None
 
-    def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms):
-        return bool(extra_froms)
+    render_table_with_column_in_update = True
+
 
 # ug.  "InnoDB needs indexes on foreign keys and referenced keys [...].
 #       Starting with MySQL 4.1.2, these indexes are created automatically.
index b775919122f6057c06815b7c4cf338d2fa9fb115..92c0c7b38de8771614e8e32ab814056f95681d87 100644 (file)
@@ -994,8 +994,7 @@ class SQLCompiler(engine.Compiled):
     def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
         return "FROM " + ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in extra_froms)
 
-    def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms):
-        return False
+    render_table_with_column_in_update = False
 
     def visit_update(self, update_stmt, **kw):
         self.stack.append({'from': set([update_stmt.table])})
@@ -1014,9 +1013,12 @@ class SQLCompiler(engine.Compiled):
         #    if hasattr(c[1], '_from_objects'):
         #        extra_froms.update(c[1]._from_objects)
 
-        text = "UPDATE " + self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw)
+        text = "UPDATE " + self.update_tables_clause(
+                                        update_stmt, 
+                                        update_stmt.table, 
+                                        extra_froms, **kw)
 
-        if self.should_render_table_with_col_in_update(update_stmt, update_stmt.table, extra_froms):
+        if extra_froms and self.render_table_with_column_in_update:
             text += ' SET ' + \
                     ', '.join(
                             self.visit_column(c[0]) + 
@@ -1038,7 +1040,10 @@ class SQLCompiler(engine.Compiled):
                                     update_stmt, update_stmt._returning)
 
         if extra_froms:
-            extra_from_text = self.update_from_clause(update_stmt, update_stmt.table, extra_froms, **kw)
+            extra_from_text = self.update_from_clause(
+                                        update_stmt, 
+                                        update_stmt.table, 
+                                        extra_froms, **kw)
             if extra_from_text:
                 text += " " + extra_from_text
 
@@ -1104,6 +1109,7 @@ class SQLCompiler(engine.Compiled):
             for k, v in stmt.parameters.iteritems():
                 parameters.setdefault(sql._column_as_key(k), v)
 
+
         # create a list of column assignment clauses as tuples
         values = []
 
@@ -1117,11 +1123,31 @@ class SQLCompiler(engine.Compiled):
 
         postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
 
+        check_columns = {}
+        if extra_tables and stmt.parameters:
+            for t in extra_tables:
+                for c in t.c:
+                    if c in stmt.parameters:
+                        check_columns[c.key] = c
+
+            for c in check_columns.values():
+                value = stmt.parameters[c]
+                if sql._is_literal(value):
+                    value = self._create_crud_bind_param(
+                                    c, value, required=value is required)
+                elif c.primary_key and implicit_returning:
+                    self.returning.append(c)
+                    value = self.process(value.self_group())
+                else:
+                    self.postfetch.append(c)
+                    value = self.process(value.self_group())
+                values.append((c, value))
+
         # iterating through columns at the top to maintain ordering.
         # otherwise we might iterate through individual sets of 
         # "defaults", "primary key cols", etc.
         for c in stmt.table.columns:
-            if c.key in parameters:
+            if c.key in parameters and c.key not in check_columns:
                 value = parameters[c.key]
                 if sql._is_literal(value):
                     value = self._create_crud_bind_param(