def _assert_active(self, prepared_ok=False,
rollback_ok=False,
+ deactive_ok=False,
closed_msg="This transaction is closed"):
if self._state is COMMITTED:
raise sa_exc.InvalidRequestError(
"SQL can be emitted within this transaction."
)
elif self._state is DEACTIVE:
- if not rollback_ok:
+ if not deactive_ok and not rollback_ok:
if self._rollback_exception:
raise sa_exc.InvalidRequestError(
"This Session's transaction has been rolled back "
" Original exception was: %s"
% self._rollback_exception
)
- else:
+ elif not deactive_ok:
raise sa_exc.InvalidRequestError(
"This Session's transaction has been rolled back "
"by a nested rollback() call. To begin a new "
return self
def __exit__(self, type, value, traceback):
- self._assert_active(prepared_ok=True)
+ self._assert_active(deactive_ok=True, prepared_ok=True)
if self.session.transaction is None:
return
if type is None:
-
+from __future__ import with_statement
from sqlalchemy.testing import eq_, assert_raises, \
assert_raises_message, assert_warnings
from sqlalchemy import *
conn.close()
raise
+
+
@testing.requires.savepoints
def test_heavy_nesting(self):
users = self.tables.users
synchronize_session='fetch')
self._run_test(update_fn)
+class ContextManagerTest(FixtureTest):
+ run_inserts = None
+
+ @testing.requires.savepoints
+ @engines.close_open_connections
+ def test_contextmanager_nested_rollback(self):
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ sess = Session()
+ def go():
+ with sess.begin_nested():
+ sess.add(User()) # name can't be null
+ sess.flush()
+
+ # and not InvalidRequestError
+ assert_raises(
+ sa_exc.DBAPIError,
+ go
+ )
+
+ with sess.begin_nested():
+ sess.add(User(name='u1'))
+
+ eq_(sess.query(User).count(), 1)
+
+ def test_contextmanager_commit(self):
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ sess = Session(autocommit=True)
+ with sess.begin():
+ sess.add(User(name='u1'))
+
+ sess.rollback()
+ eq_(sess.query(User).count(), 1)
+
+ def test_contextmanager_rollback(self):
+ users, User = self.tables.users, self.classes.User
+
+ mapper(User, users)
+
+ sess = Session(autocommit=True)
+ def go():
+ with sess.begin():
+ sess.add(User()) # name can't be null
+ assert_raises(
+ sa_exc.DBAPIError,
+ go
+ )
+
+ eq_(sess.query(User).count(), 0)
+
+ with sess.begin():
+ sess.add(User(name='u1'))
+ eq_(sess.query(User).count(), 1)
+
+
class AutoExpireTest(_LocalFixture):
def test_expunge_pending_on_rollback(self):