]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
added "nest_on" option for Session, so nested transactions can occur mostly at the...
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Mar 2006 01:16:16 +0000 (01:16 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Mar 2006 01:16:16 +0000 (01:16 +0000)
fixes [ticket:113]

lib/sqlalchemy/engine.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/objectstore.py
test/objectstore.py

index f176e5a18050f260ecd8b95d152665cd49725eca..5c14e7a3cd1af27280183a9e81c60777fb65f377 100644 (file)
@@ -180,6 +180,8 @@ class SQLSession(object):
         if parent is not None:
             self.__connection = self.engine._pool.unique_connection()
         self.__tcount = 0
+    def pop(self):
+        self.engine.pop_session(self)
     def _connection(self):
         try:
             return self.__transaction
@@ -448,11 +450,13 @@ class SQLEngine(schema.SchemaEngine):
         sess = SQLSession(self, self.context.session)
         self.context.session = sess
         return sess
-    def pop_session(self):
+    def pop_session(self, s = None):
         """restores the current thread's SQLSession to that before the last push_session.  Returns the restored SQLSession object.  Raises an exception if there is no SQLSession pushed onto the stack."""
         sess = self.context.session.parent
         if sess is None:
-            raise InvalidRequestError("No SQLSession is pushed onto the stack.")
+            raise exceptions.InvalidRequestError("No SQLSession is pushed onto the stack.")
+        elif s is not None and s is not self.context.session:
+            raise exceptions.InvalidRequestError("Given SQLSession is not the current session on the stack")
         self.context.session = sess
         return sess
         
index a46064e6f46b64aa31e4954c5f6f69f49402f03f..5e0f257386761d013ebc1d82e17571a623d78571 100644 (file)
@@ -350,6 +350,8 @@ class Mapper(object):
         """returns a proxying object to this mapper, which will execute methods on the mapper
         within the context of the given session.  The session is placed as the "current" session
         via the push_session/pop_session methods in the objectstore module."""
+        if objectstore.get_session() is session:
+            return self
         mapper = self
         class Proxy(object):
             def __getattr__(self, key):
index 9b575ce105905efeefe1ebcbe34e67637c774ac3..f978d16f727c6f68b8a703393e889221c88fc75a 100644 (file)
@@ -17,7 +17,7 @@ import sqlalchemy
 class Session(object):
     """Maintains a UnitOfWork instance, including transaction state."""
     
-    def __init__(self, nest_transactions=False, hash_key=None):
+    def __init__(self, nest_on=None, hash_key=None):
         """Initialize the objectstore with a UnitOfWork registry.  If called
         with no arguments, creates a single UnitOfWork for all operations.
         
@@ -29,13 +29,28 @@ class Session(object):
         self.uow = unitofwork.UnitOfWork()
         self.parent_uow = None
         self.begin_count = 0
-        self.nest_transactions = nest_transactions
+        self.nest_on = util.to_list(nest_on)
+        self.__pushed_count = 0
         if hash_key is None:
             self.hash_key = id(self)
         else:
             self.hash_key = hash_key
         _sessions[self.hash_key] = self
-        
+    
+    def was_pushed(self):
+        if self.nest_on is None:
+            return
+        self.__pushed_count += 1
+        if self.__pushed_count == 1:
+            for n in self.nest_on:
+                n.push_session()
+    def was_popped(self):
+        if self.nest_on is None or self.__pushed_count == 0:
+            return
+        self.__pushed_count -= 1
+        if self.__pushed_count == 0:
+            for n in self.nest_on:
+                n.pop_session()
     def get_id_key(ident, class_):
         """returns an identity-map key for use in storing/retrieving an item from the identity
         map, given a tuple of the object's primary key values.
@@ -108,7 +123,7 @@ class Session(object):
     def _trans_commit(self, trans):
         if trans.uow is self.uow and trans.isactive:
             try:
-                self.uow.commit()
+                self._commit_uow()
             finally:
                 self.uow = self.parent_uow
                 self.parent_uow = None
@@ -116,6 +131,13 @@ class Session(object):
         if trans.uow is self.uow:
             self.uow = self.parent_uow
             self.parent_uow = None
+
+    def _commit_uow(self, *obj):
+        self.was_pushed()
+        try:
+            self.uow.commit(*obj)
+        finally:
+            self.was_popped()
                         
     def commit(self, *objects):
         """commits the current UnitOfWork transaction.  called with
@@ -126,11 +148,12 @@ class Session(object):
         # if an object list is given, commit just those but dont
         # change begin/commit status
         if len(objects):
+            self._commit_uow(*objects)
             self.uow.commit(*objects)
             return
         if self.parent_uow is None:
-            self.uow.commit()
-
+            self._commit_uow()
+            
     def refresh(self, *obj):
         """reloads the attributes for the given objects from the database, clears
         any changes made."""
@@ -287,14 +310,18 @@ uow = get_session # deprecated
 
 def push_session(sess):
     old = get_session()
+    if getattr(sess, '_previous', None) is not None:
+        raise InvalidRequestError("Given Session is already pushed onto some thread's stack")
     sess._previous = old
     session_registry.set(sess)
+    sess.was_pushed()
     
 def pop_session():
     sess = get_session()
     old = sess._previous
     sess._previous = None
     session_registry.set(old)
+    sess.was_popped()
     return old
     
 def using_session(sess, func):
index 6a3a16f7786c522d4ffea2c098bba447663f9044..73bbc1da432f3cbdb399e6890b65038421a9d1a1 100644 (file)
@@ -89,7 +89,8 @@ class SessionTest(AssertMixin):
         tables.delete_user_data()
         
     def test_nested_begin_commit(self):
-        """test nested session.begin/commit"""
+        """tests that nesting objectstore transactions with multiple commits
+        affects only the outermost transaction"""
         class User(object):pass
         m = mapper(User, users)
         def name_of(id):
@@ -117,6 +118,8 @@ class SessionTest(AssertMixin):
         self.assert_(name_of(8) == name2, msg="user_name should be %s" % name2)
 
     def test_nested_rollback(self):
+        """tests that nesting objectstore transactions with a rollback inside
+        affects only the outermost transaction"""
         class User(object):pass
         m = mapper(User, users)
         def name_of(id):
@@ -141,6 +144,32 @@ class SessionTest(AssertMixin):
         self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
         self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
 
+    def test_true_nested(self):
+        """tests creating a new Session inside a database transaction, in 
+        conjunction with an engine-level nested transaction, which uses
+        a second connection in order to achieve a nested transaction that commits, inside
+        of another engine session that rolls back."""
+#        testbase.db.echo='debug'
+        class User(object):
+            pass
+        testbase.db.begin()
+        try:
+            m = mapper(User, users)
+            name1 = "Oliver Twist"
+            name2 = 'Mr. Bumble'
+            m.get(7).user_name = name1
+            s = objectstore.Session(nest_on=testbase.db)
+            m.using(s).get(8).user_name = name2
+            s.commit()
+            objectstore.commit()
+            testbase.db.rollback()
+        except:
+            testbase.db.rollback()
+            raise
+        objectstore.clear()
+        self.assert_(m.get(8).user_name == name2)
+        self.assert_(m.get(7).user_name != name1)
+        
 class UnicodeTest(AssertMixin):
     def setUpAll(self):
         global uni_table