From: Mike Bayer Date: Sat, 8 Aug 2009 17:38:45 +0000 (+0000) Subject: clean up the way we detect MSSQL's form of RETURNING X-Git-Tag: rel_0_6beta1~348 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=cbdccb7fd26da432ddf43ae1820656505acad37e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git clean up the way we detect MSSQL's form of RETURNING --- diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index cd031af401..07ed37c356 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -885,7 +885,8 @@ class MSExecutionContext(default.DefaultExecutionContext): return base.ResultProxy(self) class MSSQLCompiler(compiler.SQLCompiler): - + returning_precedes_values = True + extract_map = compiler.SQLCompiler.extract_map.copy() extract_map.update ({ 'doy': 'dayofyear', diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a47922cc51..d6187bcde8 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -161,8 +161,17 @@ class SQLCompiler(engine.Compiled): # level to define if this Compiled instance represents # INSERT/UPDATE/DELETE isdelete = isinsert = isupdate = False + + # holds the "returning" collection of columns if + # the statement is CRUD and defines returning columns + # either implicitly or explicitly returning = None + # set to True classwide to generate RETURNING + # clauses before the VALUES or WHERE clause (i.e. MSSQL) + returning_precedes_values = False + + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -699,10 +708,8 @@ class SQLCompiler(engine.Compiled): self.returning = self.returning or insert_stmt._returning returning_clause = self.returning_clause(insert_stmt, self.returning) - # cheating - if returning_clause.startswith("OUTPUT"): + if self.returning_precedes_values: text += " " + returning_clause - returning_clause = None if not colparams and supports_default_values: text += " DEFAULT VALUES" @@ -710,7 +717,7 @@ class SQLCompiler(engine.Compiled): text += " VALUES (%s)" % \ ', '.join([c[1] for c in colparams]) - if self.returning and returning_clause: + if self.returning and not self.returning_precedes_values: text += " " + returning_clause return text @@ -732,14 +739,14 @@ class SQLCompiler(engine.Compiled): if update_stmt._returning: self.returning = update_stmt._returning returning_clause = self.returning_clause(update_stmt, update_stmt._returning) - if returning_clause.startswith("OUTPUT"): + + if self.returning_precedes_values: text += " " + returning_clause - returning_clause = None if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) - if self.returning and returning_clause: + if self.returning and not self.returning_precedes_values: text += " " + returning_clause self.stack.pop(-1) @@ -755,6 +762,11 @@ class SQLCompiler(engine.Compiled): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. + Also generates the Compiled object's postfetch, prefetch, and returning + column collections, used for default handling and ultimately + populating the ResultProxy's prefetch_cols() and postfetch_cols() + collections. + """ self.postfetch = [] @@ -880,14 +892,14 @@ class SQLCompiler(engine.Compiled): if delete_stmt._returning: self.returning = delete_stmt._returning returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning) - if returning_clause.startswith("OUTPUT"): + + if self.returning_precedes_values: text += " " + returning_clause - returning_clause = None if delete_stmt._whereclause: text += " WHERE " + self.process(delete_stmt._whereclause) - if self.returning and returning_clause: + if self.returning and not self.returning_precedes_values: text += " " + returning_clause self.stack.pop(-1) diff --git a/test/dialect/test_mssql.py b/test/dialect/test_mssql.py index 423310db62..5b4978f9ef 100644 --- a/test/dialect/test_mssql.py +++ b/test/dialect/test_mssql.py @@ -179,6 +179,20 @@ class CompileTest(TestBase, AssertsCompiledSQL): u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name)) self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name) AS length_1") + def test_delete_returning(self): + table1 = table('mytable', + column('myid', Integer), + column('name', String(128)), + column('description', String(128)), + ) + + d = delete(table1).returning(table1.c.myid, table1.c.name) + self.assert_compile(d, "DELETE FROM mytable OUTPUT deleted.myid, deleted.name") + + d = delete(table1).where(table1.c.name=='bar').returning(table1.c.myid, table1.c.name) + self.assert_compile(d, "DELETE FROM mytable OUTPUT deleted.myid, deleted.name WHERE mytable.name = :name_1") + + def test_insert_returning(self): table1 = table('mytable', column('myid', Integer),