]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fixes to actually get tests to pass
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Nov 2011 23:05:05 +0000 (18:05 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Nov 2011 23:05:05 +0000 (18:05 -0500)
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
test/aaa_profiling/test_compiler.py
test/sql/test_update.py

index 24c3687e9bc12200b94f76e998b8e874f291f75d..4b1b9bd5d770914fe54cf6364f16a9bfc749ebf0 100644 (file)
@@ -1025,11 +1025,7 @@ class SQLCompiler(engine.Compiled):
 
         self.isupdate = True
 
-        if update_stmt._whereclause is not None:
-            extra_froms = set(update_stmt._whereclause._from_objects).\
-                            difference([update_stmt.table])
-        else:
-            extra_froms = None
+        extra_froms = update_stmt._extra_froms
 
         colparams = self._get_colparams(update_stmt, extra_froms)
 
@@ -1038,20 +1034,17 @@ class SQLCompiler(engine.Compiled):
                                         update_stmt.table, 
                                         extra_froms, **kw)
 
+        text += ' SET '
         if extra_froms and self.render_table_with_column_in_update_from:
-            text += ' SET ' + \
-                    ', '.join(
+            text += ', '.join(
                             self.visit_column(c[0]) + 
-                            '=' + c[1]
-                          for c in colparams
-                    )
+                            '=' + c[1] for c in colparams
+                            )
         else:
-            text += ' SET ' + \
-                ', '.join(
+            text += ', '.join(
                         self.preparer.quote(c[0].name, c[0].quote) + 
-                        '=' + c[1]
-                      for c in colparams
-                )
+                        '=' + c[1] for c in colparams
+                            )
 
         if update_stmt._returning:
             self.returning = update_stmt._returning
@@ -1144,6 +1137,8 @@ class SQLCompiler(engine.Compiled):
         postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid
 
         check_columns = {}
+        # special logic that only occurs for multi-table UPDATE 
+        # statements
         if extra_tables and stmt.parameters:
             for t in extra_tables:
                 for c in t.c:
@@ -1186,7 +1181,7 @@ class SQLCompiler(engine.Compiled):
                     (
                         implicit_returning or 
                         not postfetch_lastrowid or 
-                        c is not t._autoincrement_column
+                        c is not stmt.table._autoincrement_column
                     ):
 
                     if implicit_returning:
@@ -1213,7 +1208,7 @@ class SQLCompiler(engine.Compiled):
                             self.returning.append(c)
                     else:
                         if c.default is not None or \
-                            c is t._autoincrement_column and (
+                            c is stmt.table._autoincrement_column and (
                                 self.dialect.supports_sequences or
                                 self.dialect.preexecute_autoincrement_sequences
                             ):
index 6520be202de667bf2f92754b81200661ccc5d501..6eb4367b3b72c387024534b57b000767d3931426 100644 (file)
@@ -5292,6 +5292,20 @@ class Update(ValuesBase):
         else:
             self._whereclause = _literal_as_text(whereclause)
 
+    @property
+    def _extra_froms(self):
+        # TODO: this could be made memoized
+        # if the memoization is reset on each generative call.
+        froms = []
+        seen = set([self.table])
+
+        if self._whereclause is not None:
+            for item in _from_objects(self._whereclause):
+                if not seen.intersection(item._cloned_set):
+                    froms.append(item)
+                seen.update(item._cloned_set)
+
+        return froms
 
 class Delete(UpdateBase):
     """Represent a DELETE construct.
index f949ce6ead739f20e722467d9e2ea25e1e23f243..a7ce7a70b18f641695a70c07fd29e6c51db887ee 100644 (file)
@@ -39,11 +39,11 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults):
     def test_insert(self):
         t1.insert().compile(dialect=self.dialect)
 
-    @profiling.function_call_count(versions={'2.6':53, '2.7':53})
+    @profiling.function_call_count(versions={'2.6':56, '2.7':56})
     def test_update(self):
         t1.update().compile(dialect=self.dialect)
 
-    @profiling.function_call_count(versions={'2.6':110, '2.7':110, '3':115})
+    @profiling.function_call_count(versions={'2.6':117, '2.7':117, '3':118})
     def test_update_whereclause(self):
         t1.update().where(t1.c.c2==12).compile(dialect=self.dialect)
 
index 87fd6ffd5eb26eb9afbf7f1892ccd4cbf43aac4a..2ea3d92a4d18b090540af12e9d407a20768119db 100644 (file)
@@ -7,9 +7,7 @@ from test.lib import *
 from test.lib.schema import Table, Column
 from sqlalchemy.dialects import mysql
 
-class UpdateFromTest(fixtures.TablesTest, AssertsCompiledSQL):
-    __dialect__ = 'default'
-
+class _UpdateFromTestBase(object):
     @classmethod
     def define_tables(cls, metadata):
         Table('users', metadata,
@@ -65,6 +63,12 @@ class UpdateFromTest(fixtures.TablesTest, AssertsCompiledSQL):
             ),
         )
 
+
+class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
+    __dialect__ = 'default'
+
+    run_create_tables = run_inserts = run_deletes = None
+
     def test_render_table(self):
         users, addresses = self.tables.users, self.tables.addresses
         self.assert_compile(
@@ -134,6 +138,8 @@ class UpdateFromTest(fixtures.TablesTest, AssertsCompiledSQL):
                             u'id_1': 7, 'name': 'newname'}
         )
 
+class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
+
     @testing.requires.update_from
     def test_exec_two_table(self):
         users, addresses = self.tables.users, self.tables.addresses