]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Sep 2005 19:27:48 +0000 (19:27 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Sep 2005 19:27:48 +0000 (19:27 +0000)
lib/sqlalchemy/attributes.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/objectstore.py
lib/sqlalchemy/util.py
test/attributes.py
test/objectstore.py

index 1e1cab893c8b12068d647c5de312e4d83fe4cc1a..bf28111df7f0e9aef1f3310f7b578961584341ef 100644 (file)
@@ -33,7 +33,7 @@ class ListElement(util.HistoryArraySet):
         self.obj = obj
         self.key = key
         util.HistoryArraySet.__init__(self, items)
-        print "listelement init"
+        obj.__dict__[key] = self.data
 
     def list_value_changed(self, obj, key, listval):
         pass    
@@ -73,7 +73,7 @@ class PropHistory(object):
         if self.orig is not PropHistory.NONE:
             self.obj.__dict__[self.key] = self.orig
             self.orig = PropHistory.NONE
-    def clear_history(self):
+    def commit(self):
         self.orig = PropHistory.NONE
     def added_items(self):
         if self.orig is not PropHistory.NONE:
@@ -134,21 +134,31 @@ class AttributeManager(object):
     def delete_list_attribute(self, obj, key):
         pass
         
-    def rollback(self, obj):
-        try:
-            attributes = self.attribute_history[obj]
-            for hist in attributes.values():
-                hist.rollback()
-        except KeyError:
-            pass
+    def rollback(self, obj = None):
+        if obj is None:
+            for attr in self.attribute_history.values():
+                for hist in attr.values():
+                    hist.rollback()
+        else:
+            try:
+                attributes = self.attribute_history[obj]
+                for hist in attributes.values():
+                    hist.rollback()
+            except KeyError:
+                pass
 
-    def clear_history(self, obj):
-        try:
-            attributes = self.attribute_history[obj]
-            for hist in attributes.values():
-                hist.clear_history()
-        except KeyError:
-            pass
+    def commit(self, obj = None):
+        if obj is None:
+            for attr in self.attribute_history.values():
+                for hist in attr.values():
+                    hist.commit()
+        else:
+            try:
+                attributes = self.attribute_history[obj]
+                for hist in attributes.values():
+                    hist.commit()
+            except KeyError:
+                pass
 
     def get_history(self, obj, key):
         try:
index b7f5fd53e6d9b28b55b2040d6e6dff68f27a6b86..98f225403cee0121e50c3c18b23d0815e5be7b16 100644 (file)
@@ -155,7 +155,7 @@ class Mapper(object):
 
     def init(self):
         [prop.init(key, self) for key, prop in self.props.iteritems()]
-        print "well hi!"
+        # TODO: get some notion of "primary mapper" going so multiple mappers dont collide
         self.class_._mapper = self.hashkey
 
     def instances(self, cursor, db = None):
@@ -713,7 +713,7 @@ class EagerLoader(PropertyLoader):
             result_list = []
             setattr(instance, self.key, result_list)
             result_list = getattr(instance, self.key)
-            result_list.clear_history()
+            result_list.commit()
         else:
             result_list = getattr(instance, self.key)
             
index d18dca90082e8710f37d778eeed4f82755ef2d16..8d46fd6b6b2607dc98abe4fb27962c80e557a01c 100644 (file)
@@ -184,7 +184,6 @@ class UnitOfWork(object):
         else:
             for obj in [n for n in self.new] + [d for d in self.dirty]:
                 commit_context.append_task(obj)
-                print "COMMIT append  " + obj.__class__.__name__ + " " + repr(id(obj))
             for item in self.modified_lists:
                 obj = item.obj
                 commit_context.append_task(obj)
@@ -201,16 +200,23 @@ class UnitOfWork(object):
         except:
             for e in engines:
                 e.rollback()
-                if self.parent:
-                    uow.set(self.parent)
+            if self.parent:
+                self.rollback()
             raise
         for e in engines:
             e.commit()
             
         commit_context.post_exec()
+        self.attributes.commit()
         
         if self.parent:
             uow.set(self.parent)
+
+    def rollback(self):
+        if not self.is_begun:
+            raise "UOW transaction is not begun"
+        self.attributes.rollback()
+        uow.set(self.parent)
             
 class UOWTransaction(object):
     def __init__(self, uow):
index 1575c2858282795cc974ce18091a35760e335190..ae0442c16a44451eed7670b44fd8b6f46721c031 100644 (file)
@@ -175,7 +175,7 @@ class HistoryArraySet(UserList.UserList):
                 del self.records[item]
         except KeyError:
             pass
-    def clear_history(self):
+    def commit(self):
         for key in self.records.keys():
             value = self.records[key]
             if value is False:
index 718f30214392707f8b45375453770fa54ad6e09c..7b3c82d7b60e14928c7d7af7468ffa3c20db7e07 100644 (file)
@@ -22,7 +22,7 @@ class AttributesTest(PersistTest):
         
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
-        manager.clear_history(u)
+        manager.commit(u)
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
 
@@ -58,7 +58,7 @@ class AttributesTest(PersistTest):
 
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
-        manager.clear_history(u)
+        manager.commit()
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
 
@@ -70,8 +70,9 @@ class AttributesTest(PersistTest):
         print repr(u.__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com')
 
-        manager.rollback(u)
+        manager.rollback()
         print repr(u.__dict__)
+        print repr(u.addresses[0].__dict__)
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
 
 if __name__ == "__main__":
index 0684e302ee9d2603e2da4600e171a1f5dd74821d..96839acccdb2a31d6ca7d9b03ac9b6ba0de6f3d3 100644 (file)
@@ -32,6 +32,7 @@ class HistoryTest(AssertMixin):
         u.addresses.append(Address())
         u.addresses[1].email_address = 'there'
         print repr(u.__dict__)
+        print repr(u.addresses)
         m.rollback(u)
         print repr(u.__dict__)