]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
mssql: preliminary support for using scope_identity() with pyodbc
authorPaul Johnston <paj@pajhome.org.uk>
Wed, 11 Jul 2007 18:51:44 +0000 (18:51 +0000)
committerPaul Johnston <paj@pajhome.org.uk>
Wed, 11 Jul 2007 18:51:44 +0000 (18:51 +0000)
CHANGES
lib/sqlalchemy/databases/mssql.py
test/testbase.py

diff --git a/CHANGES b/CHANGES
index 9fa41dfb2fd34ec8c41c4e08cf4a946c27d533db..9914f9295494e929360169aa3fc6ce4aa35851ee 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -97,6 +97,7 @@
 - mssql
     - fix port option handling for pyodbc [ticket:634]
     - now able to reflect start and increment values for identity columns
+    - preliminary support for using scope_identity() with pyodbc
 
 - extensions
     - added selectone_by() to assignmapper
index 934a644017fd56d013c8cfd75962df184a37c6dd..553e07a48aaa5da96bfbfe351984ad3e3d3e3a72 100644 (file)
@@ -291,6 +291,17 @@ class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
             
         super(MSSQLExecutionContext_pyodbc, self).pre_exec()
 
+        # where appropriate, issue "select scope_identity()" in the same statement
+        if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity:
+            self.statement += "; select scope_identity()"
+
+    def post_exec(self):
+        if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity:
+            # do nothing - id was fetched in dialect.do_execute()
+            self.HASIDENT = False
+        else:
+            super(MSSQLExecutionContext_pyodbc, self).post_exec()
+
 
 class MSSQLDialect(ansisql.ANSIDialect):
     colspecs = {
@@ -709,11 +720,24 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
         return [[";".join (connectors)], {}]
 
     def is_disconnect(self, e):
-        return isinstance(e, self.dbapi.Error) and '[08S01]' in e.args[1]
+        return isinstance(e, self.dbapi.Error) and '[08S01]' in str(e)
 
     def create_execution_context(self, *args, **kwargs):
         return MSSQLExecutionContext_pyodbc(self, *args, **kwargs)
 
+    def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
+        super(MSSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
+        if context and context.HASIDENT and (not context.IINSERT) and context.dialect.use_scope_identity:
+            import pyodbc
+            # fetch the last inserted id from the manipulated statement (pre_exec).
+            try:
+                row = cursor.fetchone()
+            except pyodbc.Error, e:
+                # if nocount OFF fetchone throws an exception and we have to jump over
+                # the rowcount to the resultset
+                cursor.nextset()
+                row = cursor.fetchone()
+            context._last_inserted_ids = [int(row[0])]
 
 class MSSQLDialect_adodbapi(MSSQLDialect):
     def import_dbapi(cls):
index 335a6953ddc0d6d856ccfcd649e789467598fa73..54fcd9db2361268b8218e2a59c30eeb38b8da30d 100644 (file)
@@ -344,6 +344,8 @@ class ExecutionContextWrapper(object):
                 parameters = [p.get_original_dict() for p in ctx.compiled_parameters]
                     
             query = self.convert_statement(query)
+            if db.engine.name == 'mssql' and statement.endswith('; select scope_identity()'):
+                statement = statement[:-25]
             testdata.unittest.assert_(statement == query and (params is None or params == parameters), "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
         testdata.sql_count += 1
         self.ctx.post_exec()