From: Mike Bayer Date: Wed, 18 Apr 2012 23:52:58 +0000 (-0400) Subject: - [bug] UPDATE..FROM syntax with SQL Server X-Git-Tag: rel_0_7_7~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=541e59c3d7c141cfe532b26b5fbf4b8a8d30b841;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - [bug] UPDATE..FROM syntax with SQL Server requires that the updated table be present in the FROM clause when an alias of that table is also present in the FROM clause. The updated table is now always present in the FROM, when FROM is present in the first place. Courtesy sayap. [ticket:2468] --- diff --git a/CHANGES b/CHANGES index 83408f68d5..1fb2557de5 100644 --- a/CHANGES +++ b/CHANGES @@ -68,6 +68,15 @@ CHANGES INSERT to get at the last inserted ID, for those tables which have "implicit_returning" set to False. + + - [bug] UPDATE..FROM syntax with SQL Server + requires that the updated table be present + in the FROM clause when an alias of that + table is also present in the FROM clause. + The updated table is now always present + in the FROM, when FROM is present + in the first place. Courtesy sayap. + [ticket:2468] - postgresql - [feature] Added new for_update/with_lockmode() diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 78da18711a..3366d5fab7 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -985,6 +985,22 @@ class MSSQLCompiler(compiler.SQLCompiler): else: return "" + def update_from_clause(self, update_stmt, + from_table, extra_froms, + from_hints, + **kw): + """Render the UPDATE..FROM clause specific to MSSQL. + + In MSSQL, if the UPDATE statement involves an alias of the table to + be updated, then the table itself must be added to the FROM list as + well. Otherwise, it is optional. Here, we add it regardless. + + """ + return "FROM " + ', '.join( + t._compiler_dispatch(self, asfrom=True, + fromhints=from_hints, **kw) + for t in [from_table] + extra_froms) + class MSSQLStrictCompiler(MSSQLCompiler): """A subclass of MSSQLCompiler which disables the usage of bind parameters where not allowed natively by MS-SQL. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index fdff99fb12..bf234fe5cc 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1115,7 +1115,7 @@ class SQLCompiler(engine.Compiled): """Provide a hook to override the generation of an UPDATE..FROM clause. - MySQL overrides this. + MySQL and MSSQL override this. """ return "FROM " + ', '.join( diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index 2b35ff57fb..74e96c8efa 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -136,7 +136,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): selectable=t2, dialect_name=darg), "UPDATE sometable SET somecolumn=:somecolumn " - "FROM othertable WITH (PAGLOCK) " + "FROM sometable, othertable WITH (PAGLOCK) " "WHERE sometable.somecolumn = othertable.somecolumn" ) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index c3cf001fa6..feb7405db5 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -2926,6 +2926,19 @@ class CRUDTest(fixtures.TestBase, AssertsCompiledSQL): "UPDATE mytable SET name=:name " "FROM myothertable WHERE myothertable.otherid = mytable.myid") + self.assert_compile(u, + "UPDATE mytable SET name=:name " + "FROM mytable, myothertable WHERE " + "myothertable.otherid = mytable.myid", + dialect=mssql.dialect()) + + self.assert_compile(u.where(table2.c.othername == mt.c.name), + "UPDATE mytable SET name=:name " + "FROM mytable, myothertable, mytable AS mytable_1 " + "WHERE myothertable.otherid = mytable.myid " + "AND myothertable.othername = mytable_1.name", + dialect=mssql.dialect()) + def test_delete(self): self.assert_compile( delete(table1, table1.c.myid == 7), diff --git a/test/sql/test_update.py b/test/sql/test_update.py index 8eccde999b..f900a164cf 100644 --- a/test/sql/test_update.py +++ b/test/sql/test_update.py @@ -156,6 +156,31 @@ class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest): ] ) + @testing.requires.update_from + def test_exec_two_table_plus_alias(self): + users, addresses = self.tables.users, self.tables.addresses + a1 = addresses.alias() + + testing.db.execute( + addresses.update().\ + values(email_address=users.c.name).\ + where(users.c.id==a1.c.user_id).\ + where(users.c.name=='ed').\ + where(a1.c.id==addresses.c.id) + ) + eq_( + testing.db.execute( + addresses.select().\ + order_by(addresses.c.id)).fetchall(), + [ + (1, 7, 'x', "jack@bean.com"), + (2, 8, 'x', "ed"), + (3, 8, 'x', "ed"), + (4, 8, 'x', "ed"), + (5, 9, 'x', "fred@fred.com") + ] + ) + @testing.requires.update_from def test_exec_three_table(self): users, addresses, dingalings = \