]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- MySQL's update does work. add some logic to compiler to convert from ORM column...
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Aug 2012 22:28:32 +0000 (18:28 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 20 Aug 2012 22:28:32 +0000 (18:28 -0400)
CHANGES
lib/sqlalchemy/sql/compiler.py
test/orm/test_update_delete.py

diff --git a/CHANGES b/CHANGES
index 53705a527e7603fcdf085488aab47cd994c6ba9a..ac489bc8e53bc6d2daf88d8bbd79d9823c65f2cc 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -142,10 +142,9 @@ underneath "0.7.xx".
     and if the parent table is referenced in the
     WHERE clause, the compiler will call upon
     UPDATE..FROM syntax as allowed by the dialect
-    to satisfy the WHERE clause.  Target columns
-    must still be in the target table i.e.
-    does not support MySQL's multi-table update
-    feature (even though this is in Core).
+    to satisfy the WHERE clause.  MySQL's multi-table
+    update feature is also supported if columns
+    are specified by object in the "values" dicitionary.
     PG's DELETE..USING is also not available
     in Core yet.
 
index c56b7fc37755f4527a267de9a631f985bac27ef4..fd9718f1f2c18ddcca44669b7e772a161a3b60e5 100644 (file)
@@ -1446,14 +1446,18 @@ class SQLCompiler(engine.Compiled):
         # special logic that only occurs for multi-table UPDATE
         # statements
         if extra_tables and stmt.parameters:
+            normalized_params = dict(
+                (sql._clause_element_as_expr(c), param)
+                for c, param in stmt.parameters.items()
+            )
             assert self.isupdate
             affected_tables = set()
             for t in extra_tables:
                 for c in t.c:
-                    if c in stmt.parameters:
+                    if c in normalized_params:
                         affected_tables.add(t)
                         check_columns[c.key] = c
-                        value = stmt.parameters[c]
+                        value = normalized_params[c]
                         if sql._is_literal(value):
                             value = self._create_crud_bind_param(
                                             c, value, required=value is required)
@@ -1466,7 +1470,7 @@ class SQLCompiler(engine.Compiled):
             # server_onupdate for these
             for t in affected_tables:
                 for c in t.c:
-                    if c in stmt.parameters:
+                    if c in normalized_params:
                         continue
                     elif c.onupdate is not None and not c.onupdate.is_sequence:
                         if c.onupdate.is_clause_element:
index e6a429c90b4a5427e351288afad4edc150e0aec2..e259c52295e3eb4dabffc7e9e19d59d38e2380cf 100644 (file)
@@ -642,3 +642,16 @@ class InheritTest(fixtures.DeclarativeMappedTest):
             set([('e1', 'e1', ), ('e2', 'e5')])
         )
 
+    @testing.only_on('mysql', 'Multi table update')
+    def test_update_from_multitable(self):
+        Engineer = self.classes.Engineer
+        Person = self.classes.Person
+        s = Session(testing.db)
+        s.query(Engineer).filter(Engineer.id == Person.id).\
+            filter(Person.name == 'e2').update({Person.name: 'e22',
+                                Engineer.engineer_name: 'e55'})
+
+        eq_(
+            set(s.query(Person.name, Engineer.engineer_name)),
+            set([('e1', 'e1', ), ('e22', 'e55')])
+        )