From e89f31e0df9c59fcd04e81642410b3d2e21b1520 Mon Sep 17 00:00:00 2001 From: Paul Johnston Date: Sun, 25 Nov 2007 23:56:38 +0000 Subject: [PATCH] Fix: MSSQL set identity_insert and errors [ticket:538] --- lib/sqlalchemy/databases/mssql.py | 40 ++++++++++++++++++------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index d060458d4e..098bd33c89 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -333,20 +333,15 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): one column). """ - if self.compiled.isinsert: - if self.IINSERT: - self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) - self.IINSERT = False - elif self.HASIDENT: - if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: - if self.dialect.use_scope_identity: - self.cursor.execute("SELECT scope_identity() AS lastrowid") - else: - self.cursor.execute("SELECT @@identity AS lastrowid") - row = self.cursor.fetchone() - self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] - # print "LAST ROW ID", self._last_inserted_ids - self.HASIDENT = False + if self.compiled.isinsert and self.HASIDENT and not self.IINSERT: + if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None: + if self.dialect.use_scope_identity: + self.cursor.execute("SELECT scope_identity() AS lastrowid") + else: + self.cursor.execute("SELECT @@identity AS lastrowid") + row = self.cursor.fetchone() + self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:] + # print "LAST ROW ID", self._last_inserted_ids super(MSSQLExecutionContext, self).post_exec() _ms_is_select = re.compile(r'\s*(?:SELECT|sp_columns)', @@ -367,7 +362,7 @@ class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext): def post_exec(self): if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity: # do nothing - id was fetched in dialect.do_execute() - self.HASIDENT = False + pass else: super(MSSQLExecutionContext_pyodbc, self).post_exec() @@ -487,10 +482,21 @@ class MSSQLDialect(default.DefaultDialect): def last_inserted_ids(self): return self.context.last_inserted_ids - def do_execute(self, cursor, statement, params, **kwargs): + def do_execute(self, cursor, statement, params, context=None, **kwargs): if params == {}: params = () - super(MSSQLDialect, self).do_execute(cursor, statement, params, **kwargs) + try: + super(MSSQLDialect, self).do_execute(cursor, statement, params, context=context, **kwargs) + finally: + if context.IINSERT: + cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table)) + + def do_executemany(self, cursor, statement, params, context=None, **kwargs): + try: + super(MSSQLDialect, self).do_executemany(cursor, statement, params, context=context, **kwargs) + finally: + if context.IINSERT: + cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table)) def _execute(self, c, statement, parameters): try: -- 2.47.2