]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix: MSSQL set identity_insert and errors [ticket:538]
authorPaul Johnston <paj@pajhome.org.uk>
Sun, 25 Nov 2007 23:56:38 +0000 (23:56 +0000)
committerPaul Johnston <paj@pajhome.org.uk>
Sun, 25 Nov 2007 23:56:38 +0000 (23:56 +0000)
lib/sqlalchemy/databases/mssql.py

index d060458d4e9bc38d0e8e783c0dbb9b5074bcf148..098bd33c8972a4eb9bbd9510a8d447a2d5c1b667 100644 (file)
@@ -333,20 +333,15 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
         one column).
         """
         
-        if self.compiled.isinsert:
-            if self.IINSERT:
-                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
-                self.IINSERT = False
-            elif self.HASIDENT:
-                if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
-                    if self.dialect.use_scope_identity:
-                        self.cursor.execute("SELECT scope_identity() AS lastrowid")
-                    else:
-                        self.cursor.execute("SELECT @@identity AS lastrowid")
-                    row = self.cursor.fetchone()
-                    self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
-                    # print "LAST ROW ID", self._last_inserted_ids
-            self.HASIDENT = False
+        if self.compiled.isinsert and self.HASIDENT and not self.IINSERT:
+            if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
+                if self.dialect.use_scope_identity:
+                    self.cursor.execute("SELECT scope_identity() AS lastrowid")
+                else:
+                    self.cursor.execute("SELECT @@identity AS lastrowid")
+                row = self.cursor.fetchone()
+                self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
+                # print "LAST ROW ID", self._last_inserted_ids
         super(MSSQLExecutionContext, self).post_exec()
     
     _ms_is_select = re.compile(r'\s*(?:SELECT|sp_columns)',
@@ -367,7 +362,7 @@ class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
     def post_exec(self):
         if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity:
             # do nothing - id was fetched in dialect.do_execute()
-            self.HASIDENT = False
+            pass
         else:
             super(MSSQLExecutionContext_pyodbc, self).post_exec()
 
@@ -487,10 +482,21 @@ class MSSQLDialect(default.DefaultDialect):
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
             
-    def do_execute(self, cursor, statement, params, **kwargs):
+    def do_execute(self, cursor, statement, params, context=None, **kwargs):
         if params == {}:
             params = ()
-        super(MSSQLDialect, self).do_execute(cursor, statement, params, **kwargs)
+        try:
+            super(MSSQLDialect, self).do_execute(cursor, statement, params, context=context, **kwargs)
+        finally:        
+            if context.IINSERT:
+                cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table))
+         
+    def do_executemany(self, cursor, statement, params, context=None, **kwargs):
+        try:
+            super(MSSQLDialect, self).do_executemany(cursor, statement, params, context=context, **kwargs)
+        finally:        
+            if context.IINSERT:
+                cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table))
 
     def _execute(self, c, statement, parameters):
         try: