]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Sep 2005 07:49:31 +0000 (07:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 17 Sep 2005 07:49:31 +0000 (07:49 +0000)
lib/sqlalchemy/objectstore.py
test/mapper.py

index a78dafc4c9ec692c01ef58211e15d46d252bee67..dffa4afa217375e1ea7e60dd64c875f556590cca 100644 (file)
@@ -114,7 +114,8 @@ class UOWListElement(util.HistoryArraySet):
         return res
     
 class UnitOfWork(object):
-    def __init__(self, parent = None):
+    def __init__(self, parent = None, is_begun = False):
+        self.is_begun = is_begun
         self.new = util.HashSet()
         self.dirty = util.HashSet()
         self.modified_lists = util.HashSet()
@@ -135,7 +136,13 @@ class UnitOfWork(object):
             obj.__dict__[key] = v
             self.register_attribute(obj, key).setattr_clean(v)
         return v
-        
+    
+    def rollback_attribute(self, obj, key):
+        if self.attribute_history.has_key(obj):
+            h = self.attribute_history[obj][key]
+            h.rollback()
+            obj.__dict__[key] = h.current
+            
     def set_attribute(self, obj, key, value, usehistory = False):
         if usehistory:
             self.register_attribute(obj, key).setattr(value)
@@ -153,7 +160,18 @@ class UnitOfWork(object):
             self.register_dirty(obj)
         else:
             self.register_new(obj)
-        
+    
+    def rollback_obj(self, obj):
+        try:
+            attributes = self.attribute_history[obj]
+            for key, hist in attributes.iteritems():
+                hist.rollback()
+                obj.__dict__[key] = hist.current
+        except KeyError:
+            pass
+        for value in obj.__dict__.values():
+            if isinstance(value, util.HistoryArraySet):
+                value.rollback()
     def register_attribute(self, obj, key):
         try:
             attributes = self.attribute_history[obj]
@@ -166,24 +184,36 @@ class UnitOfWork(object):
 
     def register_list_attribute(self, obj, key, data = None):
         try:
-            childlist = obj.__dict__[key]
+            attributes = self.attribute_history[obj]
         except KeyError:
-            childlist = UOWListElement(obj)
-            obj.__dict__[key] = childlist
-        
-        if callable(childlist):
-            childlist = UOWListElement(obj, childlist())
-            obj.__dict__[key] = childlist                        
-        elif not isinstance(childlist, util.HistoryArraySet):
-            childlist = UOWListElement(obj, childlist)
-            obj.__dict__[key] = childlist
+            attributes = self.attribute_history.setdefault(obj, {})
+        try:
+            childlist = attributes[key]
+        except KeyError:
+            try:
+                list = obj.__dict__[key]
+                if callable(list):
+                    list = list()
+            except KeyError:
+                list = []
+                obj.__dict__[key] = list
+
+            childlist = UOWListElement(obj, list)
+            
         if data is not None and childlist.data != data:
             try:
                 childlist.set_data(data)
             except TypeError:
                 raise "object " + repr(data) + " is not an iterable object"
         return childlist
-        
+    
+    def rollback_list_attribute(self, obj, key):
+        try:
+            childlist = obj.__dict__[key]
+            if isinstance(childlist, util.HistoryArraySet):
+                childlist.rollback()
+        except KeyError:
+            pass    
     def register_clean(self, obj, scope="thread"):
         try:
             del self.dirty[obj]
@@ -202,7 +232,7 @@ class UnitOfWork(object):
         
     def register_dirty(self, obj):
         self.dirty.append(obj)
-
+            
     def is_dirty(self, obj):
         if not self.dirty.contains(obj):
             return False
@@ -214,14 +244,14 @@ class UnitOfWork(object):
 
     # TODO: tie in register_new/register_dirty with table transaction begins ?
     def begin(self):
-        u = UnitOfWork(self)
+        u = UnitOfWork(self, True)
         uow.set(u)
         
     def commit(self, *objects):
         import sqlalchemy.mapper
 
         commit_context = UOWTransaction(self)
-        
+
         if len(objects):
             for obj in objects:
                 commit_context.append_task(obj)
@@ -232,9 +262,25 @@ class UnitOfWork(object):
                 obj = item.obj()
                 commit_context.append_task(obj)
 
-        commit_context.execute()                   
-
-        # TODO: deleted stuff
+        engines = util.HashSet()
+        for mapper in commit_context.mappers.keys():
+            for e in mapper.engines:
+                engines.append(e)
+                
+        for e in engines:
+            e.begin()
+        try:
+            commit_context.execute()
+        except:
+            for e in engines:
+                e.rollback()
+                if self.parent:
+                    uow.set(self.parent)
+            raise
+        for e in engines:
+            e.commit()
+            
+        commit_context.post_exec()
         
         if self.parent:
             uow.set(self.parent)
@@ -243,6 +289,7 @@ class UOWTransaction(object):
     def __init__(self, uow):
         self.uow = uow
         self.mappers = {}
+        self.engines = util.HashSet()
         self.dependencies = {}
         self.tasks = {}
         self.saved_objects = util.HashSet()
@@ -295,17 +342,14 @@ class UOWTransaction(object):
                 return 0
         mapperlist.sort(compare)
 
-        try:
-            # TODO: db tranasction boundary
-            for task in mapperlist:
-                obj_list = task.objects
-                task.mapper.save_obj(obj_list, self)
-                for dep in task.dependencies:
-                    (processor, stuff_to_process) = dep
-                    processor.process_dependencies(stuff_to_process, self)
-        except:
-            raise
+        for task in mapperlist:
+            obj_list = task.objects
+            task.mapper.save_obj(obj_list, self)
+            for dep in task.dependencies:
+                (processor, stuff_to_process) = dep
+                processor.process_dependencies(stuff_to_process, self)
 
+    def post_exec(self):
         for obj in self.saved_objects:
             mapper = self.object_mapper(obj)
             obj._instance_key = mapper.identity_key(obj)
index 2d457542d28599d1025c1aea949628952c1b9256..e2db30ca8c3a86540412ee1678f2d0f850afcb19 100644 (file)
@@ -5,6 +5,7 @@ import sqlalchemy.objectstore as objectstore
 
 #ECHO = True
 ECHO = False
+DATA = True
 execfile("test/tables.py")
 db.echo = True