]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
sort of muscling this out, mysql a PITA
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Nov 2011 01:40:31 +0000 (20:40 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Nov 2011 01:40:31 +0000 (20:40 -0500)
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/sql/compiler.py
test/lib/requires.py
test/sql/test_compiler.py

index 1a30e15fd830d788ccaf315f2e4a4181ed69e408..72bc1d32f36c3fe02867c563b05b9ecf6ab9d345 100644 (file)
@@ -1315,25 +1315,22 @@ class MySQLCompiler(compiler.SQLCompiler):
             # No offset provided, so just use the limit
             return ' \n LIMIT %s' % (self.process(sql.literal(limit)),)
 
-    def visit_update(self, update_stmt):
-        self.stack.append({'from': set([update_stmt.table])})
-
-        self.isupdate = True
-        colparams = self._get_colparams(update_stmt)
-
-        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + \
-                " SET " + ', '.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams])
-
-        if update_stmt._whereclause is not None:
-            text += " WHERE " + self.process(update_stmt._whereclause)
-
+    def update_limit_clause(self, update_stmt):
         limit = update_stmt.kwargs.get('%s_limit' % self.dialect.name, None)
         if limit:
-            text += " LIMIT %s" % limit
+            return "LIMIT %s" % limit
+        else:
+            return None
 
-        self.stack.pop(-1)
+    def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+        return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) 
+                    for t in [from_table] + list(extra_froms))
 
-        return text
+    def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+        return None
+
+    def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms):
+        return bool(extra_froms)
 
 # ug.  "InnoDB needs indexes on foreign keys and referenced keys [...].
 #       Starting with MySQL 4.1.2, these indexes are created automatically.
index 8d7f2aab937fa5f9c7a2981654dacd84db960667..b775919122f6057c06815b7c4cf338d2fa9fb115 100644 (file)
@@ -985,15 +985,46 @@ class SQLCompiler(engine.Compiled):
 
         return text
 
-    def visit_update(self, update_stmt):
+    def update_limit_clause(self, update_stmt):
+        return None
+
+    def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
+        return self.preparer.format_table(from_table)
+
+    def update_from_clause(self, update_stmt, from_table, extra_froms, **kw):
+        return "FROM " + ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) for t in extra_froms)
+
+    def should_render_table_with_col_in_update(self, update_stmt, from_table, extra_froms):
+        return False
+
+    def visit_update(self, update_stmt, **kw):
         self.stack.append({'from': set([update_stmt.table])})
 
         self.isupdate = True
-        colparams = self._get_colparams(update_stmt)
 
-        text = "UPDATE " + self.preparer.format_table(update_stmt.table)
+        if update_stmt._whereclause is not None:
+            extra_froms = set(update_stmt._whereclause._from_objects).\
+                            difference([update_stmt.table])
+        else:
+            extra_froms = set()
+
+        colparams = self._get_colparams(update_stmt, extra_froms)
+
+        #for c in colparams:
+        #    if hasattr(c[1], '_from_objects'):
+        #        extra_froms.update(c[1]._from_objects)
 
-        text += ' SET ' + \
+        text = "UPDATE " + self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw)
+
+        if self.should_render_table_with_col_in_update(update_stmt, update_stmt.table, extra_froms):
+            text += ' SET ' + \
+                    ', '.join(
+                            self.visit_column(c[0]) + 
+                            '=' + c[1]
+                          for c in colparams
+                    )
+        else:
+            text += ' SET ' + \
                 ', '.join(
                         self.preparer.quote(c[0].name, c[0].quote) + 
                         '=' + c[1]
@@ -1006,9 +1037,18 @@ class SQLCompiler(engine.Compiled):
                 text += " " + self.returning_clause(
                                     update_stmt, update_stmt._returning)
 
+        if extra_froms:
+            extra_from_text = self.update_from_clause(update_stmt, update_stmt.table, extra_froms, **kw)
+            if extra_from_text:
+                text += " " + extra_from_text
+
         if update_stmt._whereclause is not None:
             text += " WHERE " + self.process(update_stmt._whereclause)
 
+        limit_clause = self.update_limit_clause(update_stmt)
+        if limit_clause:
+            text += " " + limit_clause
+
         if self.returning and not self.returning_precedes_values:
             text += " " + self.returning_clause(
                                     update_stmt, update_stmt._returning)
@@ -1024,7 +1064,7 @@ class SQLCompiler(engine.Compiled):
         return bindparam._compiler_dispatch(self)
 
 
-    def _get_colparams(self, stmt):
+    def _get_colparams(self, stmt, extra_tables=None):
         """create a set of tuples representing column/string pairs for use
         in an INSERT or UPDATE statement.
 
@@ -1100,7 +1140,7 @@ class SQLCompiler(engine.Compiled):
                     (
                         implicit_returning or 
                         not postfetch_lastrowid or 
-                        c is not stmt.table._autoincrement_column
+                        c is not t._autoincrement_column
                     ):
 
                     if implicit_returning:
@@ -1127,7 +1167,7 @@ class SQLCompiler(engine.Compiled):
                             self.returning.append(c)
                     else:
                         if c.default is not None or \
-                            c is stmt.table._autoincrement_column and (
+                            c is t._autoincrement_column and (
                                 self.dialect.supports_sequences or
                                 self.dialect.preexecute_autoincrement_sequences
                             ):
index e27d0193c6aedf6abc1e1112b85f64067d716d32..9a117b6b13dba38e827ad74f95e38438a100a22e 100644 (file)
@@ -124,6 +124,14 @@ def correlated_outer_joins(fn):
         no_support('oracle', 'Raises "ORA-01799: a column may not be outer-joined to a subquery"')
     )
 
+def update_from(fn):
+    """Target must support UPDATE..FROM syntax"""
+    return _chain_decorators_on(
+        fn,
+        only_on(('postgresql', 'mssql', 'mysql'), 
+            "Backend does not support UPDATE..FROM")
+    )
+
 def savepoints(fn):
     """Target database must support savepoints."""
     return _chain_decorators_on(
index 4e086c8cda666302490f4522b09dc8e2e70f4bd2..9a53dd89ccaab1c6009524acbf65ff5c217145a0 100644 (file)
@@ -2663,6 +2663,18 @@ class CRUDTest(fixtures.TestBase, AssertsCompiledSQL):
                 "(SELECT myothertable.othername FROM myothertable "
                 "WHERE myothertable.otherid = mytable.myid)")
 
+        # test correlated FROM implicit in WHERE and SET clauses
+        u = table1.update().values(name=table2.c.othername)\
+                  .where(table2.c.otherid == table1.c.myid)
+        self.assert_compile(u,
+                "UPDATE mytable SET name=myothertable.othername "
+                "FROM myothertable WHERE myothertable.otherid = mytable.myid")
+        u = table1.update().values(name='foo')\
+                  .where(table2.c.otherid == table1.c.myid)
+        self.assert_compile(u,
+                "UPDATE mytable SET name=:name "
+                "FROM myothertable WHERE myothertable.otherid = mytable.myid")
+
     def test_delete(self):
         self.assert_compile(
                         delete(table1, table1.c.myid == 7),