]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
clean up the way we detect MSSQL's form of RETURNING
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Aug 2009 17:38:45 +0000 (17:38 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Aug 2009 17:38:45 +0000 (17:38 +0000)
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/sql/compiler.py
test/dialect/test_mssql.py

index cd031af4011db162be75e718e410c9ebbfa78b46..07ed37c356bbfbd2af921e23d101558e3cb6b017 100644 (file)
@@ -885,7 +885,8 @@ class MSExecutionContext(default.DefaultExecutionContext):
             return base.ResultProxy(self)
 
 class MSSQLCompiler(compiler.SQLCompiler):
-
+    returning_precedes_values = True
+    
     extract_map = compiler.SQLCompiler.extract_map.copy()
     extract_map.update ({
         'doy': 'dayofyear',
index a47922cc517a959639aa8904c8b142ad7e9e6976..d6187bcde822b0c1c961eb5e2a0eb94c1cf5da27 100644 (file)
@@ -161,8 +161,17 @@ class SQLCompiler(engine.Compiled):
     # level to define if this Compiled instance represents
     # INSERT/UPDATE/DELETE
     isdelete = isinsert = isupdate = False
+    
+    # holds the "returning" collection of columns if
+    # the statement is CRUD and defines returning columns
+    # either implicitly or explicitly
     returning = None
     
+    # set to True classwide to generate RETURNING
+    # clauses before the VALUES or WHERE clause (i.e. MSSQL)
+    returning_precedes_values = False
+    
+    
     def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
         """Construct a new ``DefaultCompiler`` object.
 
@@ -699,10 +708,8 @@ class SQLCompiler(engine.Compiled):
             self.returning = self.returning or insert_stmt._returning
             returning_clause = self.returning_clause(insert_stmt, self.returning)
             
-            # cheating
-            if returning_clause.startswith("OUTPUT"):
+            if self.returning_precedes_values:
                 text += " " + returning_clause
-                returning_clause = None
 
         if not colparams and supports_default_values:
             text += " DEFAULT VALUES"
@@ -710,7 +717,7 @@ class SQLCompiler(engine.Compiled):
             text += " VALUES (%s)" % \
                      ', '.join([c[1] for c in colparams])
         
-        if self.returning and returning_clause:
+        if self.returning and not self.returning_precedes_values:
             text += " " + returning_clause
         
         return text
@@ -732,14 +739,14 @@ class SQLCompiler(engine.Compiled):
         if update_stmt._returning:
             self.returning = update_stmt._returning
             returning_clause = self.returning_clause(update_stmt, update_stmt._returning)
-            if returning_clause.startswith("OUTPUT"):
+            
+            if self.returning_precedes_values:
                 text += " " + returning_clause
-                returning_clause = None
                 
         if update_stmt._whereclause:
             text += " WHERE " + self.process(update_stmt._whereclause)
 
-        if self.returning and returning_clause:
+        if self.returning and not self.returning_precedes_values:
             text += " " + returning_clause
             
         self.stack.pop(-1)
@@ -755,6 +762,11 @@ class SQLCompiler(engine.Compiled):
         """create a set of tuples representing column/string pairs for use
         in an INSERT or UPDATE statement.
 
+        Also generates the Compiled object's postfetch, prefetch, and returning
+        column collections, used for default handling and ultimately
+        populating the ResultProxy's prefetch_cols() and postfetch_cols()
+        collections.
+
         """
 
         self.postfetch = []
@@ -880,14 +892,14 @@ class SQLCompiler(engine.Compiled):
         if delete_stmt._returning:
             self.returning = delete_stmt._returning
             returning_clause = self.returning_clause(delete_stmt, delete_stmt._returning)
-            if returning_clause.startswith("OUTPUT"):
+            
+            if self.returning_precedes_values:
                 text += " " + returning_clause
-                returning_clause = None
                 
         if delete_stmt._whereclause:
             text += " WHERE " + self.process(delete_stmt._whereclause)
 
-        if self.returning and returning_clause:
+        if self.returning and not self.returning_precedes_values:
             text += " " + returning_clause
             
         self.stack.pop(-1)
index 423310db62c274d0515de219dc3eaef73b6c7a09..5b4978f9ef2b4969cd174751ec161128b77e4fa5 100644 (file)
@@ -179,6 +179,20 @@ class CompileTest(TestBase, AssertsCompiledSQL):
         u = update(table1, values=dict(name='foo')).returning(func.length(table1.c.name))
         self.assert_compile(u, "UPDATE mytable SET name=:name OUTPUT LEN(inserted.name) AS length_1")
 
+    def test_delete_returning(self):
+        table1 = table('mytable',
+            column('myid', Integer),
+            column('name', String(128)),
+            column('description', String(128)),
+        )
+
+        d = delete(table1).returning(table1.c.myid, table1.c.name)
+        self.assert_compile(d, "DELETE FROM mytable OUTPUT deleted.myid, deleted.name")
+
+        d = delete(table1).where(table1.c.name=='bar').returning(table1.c.myid, table1.c.name)
+        self.assert_compile(d, "DELETE FROM mytable OUTPUT deleted.myid, deleted.name WHERE mytable.name = :name_1")
+
+
     def test_insert_returning(self):
         table1 = table('mytable',
             column('myid', Integer),