]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Modified savepoint logic in mssql to ensure that it does not step on non-savepoint...
authorMichael Trier <mtrier@gmail.com>
Tue, 28 Apr 2009 03:35:35 +0000 (03:35 +0000)
committerMichael Trier <mtrier@gmail.com>
Tue, 28 Apr 2009 03:35:35 +0000 (03:35 +0000)
CHANGES
lib/sqlalchemy/databases/mssql.py

diff --git a/CHANGES b/CHANGES
index 22d4c98434719b59c84b959f3f18dd0fe0519e98..c3dc89916f4df464ce8521f9aecdfea972d7e8f1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -69,6 +69,10 @@ CHANGES
       construct (i.e. declarative columns).  [ticket:1353]
 
 - mssql
+    - Modified how savepoint logic works to prevent it from
+      stepping on non-savepoint oriented routines. Savepoint
+      support is still very experimental.
+
     - Added in reserved words for MSSQL that covers version 2008
       and all prior versions. [ticket:1310]
 
index 0442ddfcaf5f97ee21525f8e5686af6be57f3e51..ce39df94c727bf7fa6d2e5e21bb088373f011c14 100644 (file)
@@ -1147,10 +1147,10 @@ class MSSQLDialect(default.DefaultDialect):
             newobj.dialect = self
         return newobj
 
-    def do_begin(self, connection):
-        cursor = connection.cursor()
-        cursor.execute("SET IMPLICIT_TRANSACTIONS OFF")
-        cursor.execute("BEGIN TRANSACTION")
+    def do_savepoint(self, connection, name):
+        util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
+        connection.execute("IF @@TRANCOUNT = 0 BEGIN TRANSACTION")
+        connection.execute("SAVE TRANSACTION %s" % name)
 
     def do_release_savepoint(self, connection, name):
         pass
@@ -1627,10 +1627,6 @@ class MSSQLCompiler(compiler.DefaultCompiler):
         field = self.extract_map.get(extract.field, extract.field)
         return 'DATEPART("%s", %s)' % (field, self.process(extract.expr))
 
-    def visit_savepoint(self, savepoint_stmt):
-        util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
-        return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
-
     def visit_rollback_to_savepoint(self, savepoint_stmt):
         return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)