]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implemented experimental savepoint support in mssql. There are still some failing...
authorMichael Trier <mtrier@gmail.com>
Thu, 11 Dec 2008 19:24:22 +0000 (19:24 +0000)
committerMichael Trier <mtrier@gmail.com>
Thu, 11 Dec 2008 19:24:22 +0000 (19:24 +0000)
CHANGES
lib/sqlalchemy/databases/mssql.py
test/orm/session.py
test/testlib/requires.py
test/testlib/testing.py

diff --git a/CHANGES b/CHANGES
index f0b3b18f12d93a1eb541a657f0a0b66c7f651719..5f85c4fa8457eeddd894a9f78004917895a6c9e1 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -161,6 +161,10 @@ CHANGES
     - Documented `comparator_factory` kwarg, added
       new doc section "Custom Comparators".
     
+- mssql
+    - Added experimental support of savepoints. It
+      currently does not work fully with sessions.
+
 - postgres
     - Calling alias.execute() in conjunction with
       server_side_cursors won't raise AttributeError.
index ac803cfd778606d9d012c1d8669642587cae822d..23ad925f2bcc04f46c22b3ad252881ce0f2dbc2e 100644 (file)
@@ -1000,6 +1000,16 @@ class MSSQLCompiler(compiler.DefaultCompiler):
         kwargs['mssql_aliased'] = True
         return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
 
+    def visit_savepoint(self, savepoint_stmt):
+        util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
+        return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
+
+    def visit_rollback_to_savepoint(self, savepoint_stmt):
+        return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
+
+    def visit_release_savepoint(self, savepoint_stmt):
+        pass
+
     def visit_column(self, column, result_map=None, **kwargs):
         if column.table is not None and \
             (not self.isupdate and not self.isdelete) or self.is_subquery():
index d72514a33446a07d2515e4166810b26e40a15ab0..57780e4e5b0ac0eb0cd7934cb1319e0ad3731139 100644 (file)
@@ -199,7 +199,7 @@ class SessionTest(_fixtures.FixtureTest):
         u2 = sess.query(User).filter_by(name='ed').one()
         assert u2 is u
         eq_(conn1.execute("select count(1) from users").scalar(), 1)
-        eq_(conn2.execute("select count(1) from users").scalar(),  0)
+        eq_(conn2.execute("select count(1) from users").scalar(), 0)
         sess.commit()
         eq_(conn1.execute("select count(1) from users").scalar(), 1)
         eq_(bind.connect().execute("select count(1) from users").scalar(), 1)
index 7b2d33beb501175c1d69501052090da5edccb96d..13d4cdf11ab5242603dddd3a1e21f9a0a37e2b78 100644 (file)
@@ -8,7 +8,8 @@ target database.
 from testlib.testing import \
      _block_unconditionally as no_support, \
      _chain_decorators_on, \
-     exclude
+     exclude, \
+     emits_warning_on
 
 
 def deferrable_constraints(fn):
@@ -66,8 +67,8 @@ def savepoints(fn):
     """Target database must support savepoints."""
     return _chain_decorators_on(
         fn,
+        emits_warning_on('mssql', 'Savepoint support in mssql is experimental and may lead to data loss.'),
         no_support('access', 'FIXME: guessing, needs confirmation'),
-        no_support('mssql', 'FIXME: guessing, needs confirmation'),
         no_support('sqlite', 'not supported by database'),
         no_support('sybase', 'FIXME: guessing, needs confirmation'),
         exclude('mysql', '<', (5, 0, 3), 'not supported by database'),
index 0bf083bbd0a524c161f4ea3f78635ee4fc5da94b..ed7669be99a03d629dc3cee19262984fc0956f3d 100644 (file)
@@ -322,6 +322,30 @@ def emits_warning(*messages):
         return _function_named(safe, fn.__name__)
     return decorate
 
+def emits_warning_on(db, *warnings):
+    """Mark a test as emitting a warning on a specific dialect.
+
+    With no arguments, squelches all SAWarning failures.  Or pass one or more
+    strings; these will be matched to the root of the warning description by
+    warnings.filterwarnings().
+    """
+    def decorate(fn):
+        def maybe(*args, **kw):
+            if isinstance(db, basestring):
+                if config.db.name != db:
+                    return fn(*args, **kw)
+                else:
+                    wrapped = emits_warning(*warnings)(fn)
+                    return wrapped(*args, **kw)
+            else:
+                if not _is_excluded(*db):
+                    return fn(*args, **kw)
+                else:
+                    wrapped = emits_warning(*warnings)(fn)
+                    return wrapped(*args, **kw)
+        return _function_named(maybe, fn.__name__)
+    return decorate
+
 def uses_deprecated(*messages):
     """Mark a test as immune from fatal deprecation warnings.