From 052d7f36433a0c29ab20d0ea37933c03a488e12d Mon Sep 17 00:00:00 2001 From: Michael Trier Date: Thu, 11 Dec 2008 19:24:22 +0000 Subject: [PATCH] Implemented experimental savepoint support in mssql. There are still some failing savepoint related tests. --- CHANGES | 4 ++++ lib/sqlalchemy/databases/mssql.py | 10 ++++++++++ test/orm/session.py | 2 +- test/testlib/requires.py | 5 +++-- test/testlib/testing.py | 24 ++++++++++++++++++++++++ 5 files changed, 42 insertions(+), 3 deletions(-) diff --git a/CHANGES b/CHANGES index f0b3b18f12..5f85c4fa84 100644 --- a/CHANGES +++ b/CHANGES @@ -161,6 +161,10 @@ CHANGES - Documented `comparator_factory` kwarg, added new doc section "Custom Comparators". +- mssql + - Added experimental support of savepoints. It + currently does not work fully with sessions. + - postgres - Calling alias.execute() in conjunction with server_side_cursors won't raise AttributeError. diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index ac803cfd77..23ad925f2b 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -1000,6 +1000,16 @@ class MSSQLCompiler(compiler.DefaultCompiler): kwargs['mssql_aliased'] = True return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) + def visit_savepoint(self, savepoint_stmt): + util.warn("Savepoint support in mssql is experimental and may lead to data loss.") + return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_release_savepoint(self, savepoint_stmt): + pass + def visit_column(self, column, result_map=None, **kwargs): if column.table is not None and \ (not self.isupdate and not self.isdelete) or self.is_subquery(): diff --git a/test/orm/session.py b/test/orm/session.py index d72514a334..57780e4e5b 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -199,7 +199,7 @@ class SessionTest(_fixtures.FixtureTest): u2 = sess.query(User).filter_by(name='ed').one() assert u2 is u eq_(conn1.execute("select count(1) from users").scalar(), 1) - eq_(conn2.execute("select count(1) from users").scalar(), 0) + eq_(conn2.execute("select count(1) from users").scalar(), 0) sess.commit() eq_(conn1.execute("select count(1) from users").scalar(), 1) eq_(bind.connect().execute("select count(1) from users").scalar(), 1) diff --git a/test/testlib/requires.py b/test/testlib/requires.py index 7b2d33beb5..13d4cdf11a 100644 --- a/test/testlib/requires.py +++ b/test/testlib/requires.py @@ -8,7 +8,8 @@ target database. from testlib.testing import \ _block_unconditionally as no_support, \ _chain_decorators_on, \ - exclude + exclude, \ + emits_warning_on def deferrable_constraints(fn): @@ -66,8 +67,8 @@ def savepoints(fn): """Target database must support savepoints.""" return _chain_decorators_on( fn, + emits_warning_on('mssql', 'Savepoint support in mssql is experimental and may lead to data loss.'), no_support('access', 'FIXME: guessing, needs confirmation'), - no_support('mssql', 'FIXME: guessing, needs confirmation'), no_support('sqlite', 'not supported by database'), no_support('sybase', 'FIXME: guessing, needs confirmation'), exclude('mysql', '<', (5, 0, 3), 'not supported by database'), diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 0bf083bbd0..ed7669be99 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -322,6 +322,30 @@ def emits_warning(*messages): return _function_named(safe, fn.__name__) return decorate +def emits_warning_on(db, *warnings): + """Mark a test as emitting a warning on a specific dialect. + + With no arguments, squelches all SAWarning failures. Or pass one or more + strings; these will be matched to the root of the warning description by + warnings.filterwarnings(). + """ + def decorate(fn): + def maybe(*args, **kw): + if isinstance(db, basestring): + if config.db.name != db: + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + else: + if not _is_excluded(*db): + return fn(*args, **kw) + else: + wrapped = emits_warning(*warnings)(fn) + return wrapped(*args, **kw) + return _function_named(maybe, fn.__name__) + return decorate + def uses_deprecated(*messages): """Mark a test as immune from fatal deprecation warnings. -- 2.47.3