From 4f39c956da1a877e05b2d201dd8e756d9d39b7c2 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Wed, 22 Mar 2006 01:16:16 +0000 Subject: [PATCH] added "nest_on" option for Session, so nested transactions can occur mostly at the Session level, fixes [ticket:113] --- lib/sqlalchemy/engine.py | 8 ++++-- lib/sqlalchemy/mapping/mapper.py | 2 ++ lib/sqlalchemy/mapping/objectstore.py | 39 ++++++++++++++++++++++----- test/objectstore.py | 31 ++++++++++++++++++++- 4 files changed, 71 insertions(+), 9 deletions(-) diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index f176e5a180..5c14e7a3cd 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -180,6 +180,8 @@ class SQLSession(object): if parent is not None: self.__connection = self.engine._pool.unique_connection() self.__tcount = 0 + def pop(self): + self.engine.pop_session(self) def _connection(self): try: return self.__transaction @@ -448,11 +450,13 @@ class SQLEngine(schema.SchemaEngine): sess = SQLSession(self, self.context.session) self.context.session = sess return sess - def pop_session(self): + def pop_session(self, s = None): """restores the current thread's SQLSession to that before the last push_session. Returns the restored SQLSession object. Raises an exception if there is no SQLSession pushed onto the stack.""" sess = self.context.session.parent if sess is None: - raise InvalidRequestError("No SQLSession is pushed onto the stack.") + raise exceptions.InvalidRequestError("No SQLSession is pushed onto the stack.") + elif s is not None and s is not self.context.session: + raise exceptions.InvalidRequestError("Given SQLSession is not the current session on the stack") self.context.session = sess return sess diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index a46064e6f4..5e0f257386 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -350,6 +350,8 @@ class Mapper(object): """returns a proxying object to this mapper, which will execute methods on the mapper within the context of the given session. The session is placed as the "current" session via the push_session/pop_session methods in the objectstore module.""" + if objectstore.get_session() is session: + return self mapper = self class Proxy(object): def __getattr__(self, key): diff --git a/lib/sqlalchemy/mapping/objectstore.py b/lib/sqlalchemy/mapping/objectstore.py index 9b575ce105..f978d16f72 100644 --- a/lib/sqlalchemy/mapping/objectstore.py +++ b/lib/sqlalchemy/mapping/objectstore.py @@ -17,7 +17,7 @@ import sqlalchemy class Session(object): """Maintains a UnitOfWork instance, including transaction state.""" - def __init__(self, nest_transactions=False, hash_key=None): + def __init__(self, nest_on=None, hash_key=None): """Initialize the objectstore with a UnitOfWork registry. If called with no arguments, creates a single UnitOfWork for all operations. @@ -29,13 +29,28 @@ class Session(object): self.uow = unitofwork.UnitOfWork() self.parent_uow = None self.begin_count = 0 - self.nest_transactions = nest_transactions + self.nest_on = util.to_list(nest_on) + self.__pushed_count = 0 if hash_key is None: self.hash_key = id(self) else: self.hash_key = hash_key _sessions[self.hash_key] = self - + + def was_pushed(self): + if self.nest_on is None: + return + self.__pushed_count += 1 + if self.__pushed_count == 1: + for n in self.nest_on: + n.push_session() + def was_popped(self): + if self.nest_on is None or self.__pushed_count == 0: + return + self.__pushed_count -= 1 + if self.__pushed_count == 0: + for n in self.nest_on: + n.pop_session() def get_id_key(ident, class_): """returns an identity-map key for use in storing/retrieving an item from the identity map, given a tuple of the object's primary key values. @@ -108,7 +123,7 @@ class Session(object): def _trans_commit(self, trans): if trans.uow is self.uow and trans.isactive: try: - self.uow.commit() + self._commit_uow() finally: self.uow = self.parent_uow self.parent_uow = None @@ -116,6 +131,13 @@ class Session(object): if trans.uow is self.uow: self.uow = self.parent_uow self.parent_uow = None + + def _commit_uow(self, *obj): + self.was_pushed() + try: + self.uow.commit(*obj) + finally: + self.was_popped() def commit(self, *objects): """commits the current UnitOfWork transaction. called with @@ -126,11 +148,12 @@ class Session(object): # if an object list is given, commit just those but dont # change begin/commit status if len(objects): + self._commit_uow(*objects) self.uow.commit(*objects) return if self.parent_uow is None: - self.uow.commit() - + self._commit_uow() + def refresh(self, *obj): """reloads the attributes for the given objects from the database, clears any changes made.""" @@ -287,14 +310,18 @@ uow = get_session # deprecated def push_session(sess): old = get_session() + if getattr(sess, '_previous', None) is not None: + raise InvalidRequestError("Given Session is already pushed onto some thread's stack") sess._previous = old session_registry.set(sess) + sess.was_pushed() def pop_session(): sess = get_session() old = sess._previous sess._previous = None session_registry.set(old) + sess.was_popped() return old def using_session(sess, func): diff --git a/test/objectstore.py b/test/objectstore.py index 6a3a16f778..73bbc1da43 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -89,7 +89,8 @@ class SessionTest(AssertMixin): tables.delete_user_data() def test_nested_begin_commit(self): - """test nested session.begin/commit""" + """tests that nesting objectstore transactions with multiple commits + affects only the outermost transaction""" class User(object):pass m = mapper(User, users) def name_of(id): @@ -117,6 +118,8 @@ class SessionTest(AssertMixin): self.assert_(name_of(8) == name2, msg="user_name should be %s" % name2) def test_nested_rollback(self): + """tests that nesting objectstore transactions with a rollback inside + affects only the outermost transaction""" class User(object):pass m = mapper(User, users) def name_of(id): @@ -141,6 +144,32 @@ class SessionTest(AssertMixin): self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1) self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2) + def test_true_nested(self): + """tests creating a new Session inside a database transaction, in + conjunction with an engine-level nested transaction, which uses + a second connection in order to achieve a nested transaction that commits, inside + of another engine session that rolls back.""" +# testbase.db.echo='debug' + class User(object): + pass + testbase.db.begin() + try: + m = mapper(User, users) + name1 = "Oliver Twist" + name2 = 'Mr. Bumble' + m.get(7).user_name = name1 + s = objectstore.Session(nest_on=testbase.db) + m.using(s).get(8).user_name = name2 + s.commit() + objectstore.commit() + testbase.db.rollback() + except: + testbase.db.rollback() + raise + objectstore.clear() + self.assert_(m.get(8).user_name == name2) + self.assert_(m.get(7).user_name != name1) + class UnicodeTest(AssertMixin): def setUpAll(self): global uni_table -- 2.47.2