From: Mike Bayer Date: Tue, 22 Nov 2011 23:05:05 +0000 (-0500) Subject: fixes to actually get tests to pass X-Git-Tag: rel_0_7_4~56 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=4de3b28abce67a09dfde1cffd8a244b6542ae8c1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fixes to actually get tests to pass --- diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 24c3687e9b..4b1b9bd5d7 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1025,11 +1025,7 @@ class SQLCompiler(engine.Compiled): self.isupdate = True - if update_stmt._whereclause is not None: - extra_froms = set(update_stmt._whereclause._from_objects).\ - difference([update_stmt.table]) - else: - extra_froms = None + extra_froms = update_stmt._extra_froms colparams = self._get_colparams(update_stmt, extra_froms) @@ -1038,20 +1034,17 @@ class SQLCompiler(engine.Compiled): update_stmt.table, extra_froms, **kw) + text += ' SET ' if extra_froms and self.render_table_with_column_in_update_from: - text += ' SET ' + \ - ', '.join( + text += ', '.join( self.visit_column(c[0]) + - '=' + c[1] - for c in colparams - ) + '=' + c[1] for c in colparams + ) else: - text += ' SET ' + \ - ', '.join( + text += ', '.join( self.preparer.quote(c[0].name, c[0].quote) + - '=' + c[1] - for c in colparams - ) + '=' + c[1] for c in colparams + ) if update_stmt._returning: self.returning = update_stmt._returning @@ -1144,6 +1137,8 @@ class SQLCompiler(engine.Compiled): postfetch_lastrowid = need_pks and self.dialect.postfetch_lastrowid check_columns = {} + # special logic that only occurs for multi-table UPDATE + # statements if extra_tables and stmt.parameters: for t in extra_tables: for c in t.c: @@ -1186,7 +1181,7 @@ class SQLCompiler(engine.Compiled): ( implicit_returning or not postfetch_lastrowid or - c is not t._autoincrement_column + c is not stmt.table._autoincrement_column ): if implicit_returning: @@ -1213,7 +1208,7 @@ class SQLCompiler(engine.Compiled): self.returning.append(c) else: if c.default is not None or \ - c is t._autoincrement_column and ( + c is stmt.table._autoincrement_column and ( self.dialect.supports_sequences or self.dialect.preexecute_autoincrement_sequences ): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 6520be202d..6eb4367b3b 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -5292,6 +5292,20 @@ class Update(ValuesBase): else: self._whereclause = _literal_as_text(whereclause) + @property + def _extra_froms(self): + # TODO: this could be made memoized + # if the memoization is reset on each generative call. + froms = [] + seen = set([self.table]) + + if self._whereclause is not None: + for item in _from_objects(self._whereclause): + if not seen.intersection(item._cloned_set): + froms.append(item) + seen.update(item._cloned_set) + + return froms class Delete(UpdateBase): """Represent a DELETE construct. diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py index f949ce6ead..a7ce7a70b1 100644 --- a/test/aaa_profiling/test_compiler.py +++ b/test/aaa_profiling/test_compiler.py @@ -39,11 +39,11 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults): def test_insert(self): t1.insert().compile(dialect=self.dialect) - @profiling.function_call_count(versions={'2.6':53, '2.7':53}) + @profiling.function_call_count(versions={'2.6':56, '2.7':56}) def test_update(self): t1.update().compile(dialect=self.dialect) - @profiling.function_call_count(versions={'2.6':110, '2.7':110, '3':115}) + @profiling.function_call_count(versions={'2.6':117, '2.7':117, '3':118}) def test_update_whereclause(self): t1.update().where(t1.c.c2==12).compile(dialect=self.dialect) diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 87fd6ffd5e..2ea3d92a4d 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -7,9 +7,7 @@ from test.lib import * from test.lib.schema import Table, Column from sqlalchemy.dialects import mysql -class UpdateFromTest(fixtures.TablesTest, AssertsCompiledSQL): - __dialect__ = 'default' - +class _UpdateFromTestBase(object): @classmethod def define_tables(cls, metadata): Table('users', metadata, @@ -65,6 +63,12 @@ class UpdateFromTest(fixtures.TablesTest, AssertsCompiledSQL): ), ) + +class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL): + __dialect__ = 'default' + + run_create_tables = run_inserts = run_deletes = None + def test_render_table(self): users, addresses = self.tables.users, self.tables.addresses self.assert_compile( @@ -134,6 +138,8 @@ class UpdateFromTest(fixtures.TablesTest, AssertsCompiledSQL): u'id_1': 7, 'name': 'newname'} ) +class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): + @testing.requires.update_from def test_exec_two_table(self): users, addresses = self.tables.users, self.tables.addresses