]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Sep 2005 03:52:21 +0000 (03:52 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Sep 2005 03:52:21 +0000 (03:52 +0000)
lib/sqlalchemy/objectstore.py

index cd6cbb9bf88cfdb64ac44327e1d01ba024df31dc..cb12a3dfca906baf805fa9a9373265e853918983 100644 (file)
@@ -37,20 +37,6 @@ def get_id_key(ident, class_, table):
     return value: a tuple object which is used as an identity key.
     """
     return (class_, table, tuple(ident))
-def get_instance_key(object, class_, table, primary_keys, mapper):
-    """returns an identity-map key for use in storing/retrieving an item from the identity map, given
-    the object instance itself.
-    
-    object - the object to be stored.  it is assumed that the object's primary key attributes are
-    populated.
-    class_ - a reference to the object's class
-    table - a Table object where the object's primary fields are stored.
-    selectable - a Selectable object which represents all the object's column-based fields.  this Selectable
-    may be synonymous with the table argument or can be a larger construct containing that table.
-    return value: a tuple object which is used as an identity key.
-    """
-    # TODO: clean this up, too many args, too confusing
-    return (class_, table, tuple([mapper._getattrbycolumn(object, column) for column in primary_keys]))
 def get_row_key(row, class_, table, primary_keys):
     """returns an identity-map key for use in storing/retrieving an item from the identity map, given
     a result set row.
@@ -111,34 +97,31 @@ def has_key(key):
 
 class UOWListElement(util.HistoryArraySet):
     """overrides HistoryArraySet to mark the parent object as dirty when changes occur"""
-    class listpointer(object): pass
         
     def __init__(self, obj, items = None):
         util.HistoryArraySet.__init__(self, items)
         self.obj = weakref.ref(obj)
-        
-        # cant hash a UserList, so make a bullshit pointer to us
-        self.listpointer = UOWListElement.listpointer()
-        self.listpointer.list = self
+        self.list = UOWListElement        
         
     def _setrecord(self, item):
         res = util.HistoryArraySet._setrecord(self, item)
         if res:
-            uow().modified_lists.append(self.listpointer)
+            uow().modified_lists.append(self.list)
         return res
     def _delrecord(self, item):
         res = util.HistoryArraySet._delrecord(self, item)
         if res:
-            uow().modified_lists.append(self.listpointer)
+            uow().modified_lists.append(self.list)
         return res
     
 class UnitOfWork(object):
-    def __init__(self):
+    def __init__(self, parent = None):
         self.new = util.HashSet()
         self.dirty = util.HashSet()
         self.modified_lists = util.HashSet()
         self.deleted = util.HashSet()
         self.attribute_history = weakref.WeakKeyDictionary()
+        self.parent = parent
         
     def attribute_set_callable(self, obj, key, func):
         obj.__dict__[key] = func
@@ -223,31 +206,31 @@ class UnitOfWork(object):
     def register_deleted(self, obj):
         pass   
 
-    # TODO: add begin().  tie in register_new/register_dirty with table transaction begins ?
-    
-    # TODO: add optional args for "i only want to save/delete these objects, not the whole thing"
-    def commit(self):
-        import sqlalchemy.mapper
+    # TODO: tie in register_new/register_dirty with table transaction begins ?
+    def begin(self):
+        u = UnitOfWork(self)
+        uow.set(u)
         
-        self.dependencies = {}
-        self.tasks = {}
+    def commit(self, *objects):
+        import sqlalchemy.mapper
+
+        self.commit_context = UOWTransaction()
         
-        for obj in [n for n in self.new] + [d for d in self.dirty]:
-            mapper = sqlalchemy.mapper.object_mapper(obj)
-            task = self.get_task_by_mapper(mapper)
-            task.objects.append(obj)
+        if len(objects):
+            for obj in objects:
+                self.commit_context.append_task(obj)
+        else:
+            for obj in [n for n in self.new] + [d for d in self.dirty]:
+                self.commit_context.append_task(obj)
 
-        for item in self.modified_lists:
-            item = item.list
-            obj = item.obj()
-            mapper = sqlalchemy.mapper.object_mapper(obj)
-            task = self.get_task_by_mapper(mapper)
-            task.lists.append(obj)
+            for item in self.modified_lists:
+                obj = item.obj()
+                self.commit_context.append_task(obj)
             
-        for task in self.tasks.values():
+        for task in self.commit_context.tasks.values():
             task.mapper.register_dependencies(util.HashSet(task.objects + task.lists), self)
             
-        mapperlist = self.tasks.values()
+        mapperlist = self.commit_context.tasks.values()
         def compare(a, b):
             if self.dependencies.has_key((a.mapper, b.mapper)):
                 return -1
@@ -256,41 +239,68 @@ class UnitOfWork(object):
             else:
                 return 0
         mapperlist.sort(compare)
-        
-        for task in mapperlist:
-            obj_list = task.objects
-            task.mapper.save_obj(obj_list)
-            for dep in task.dependencies:
-                (processor, stuff_to_process) = dep
-                processor.process_dependencies(stuff_to_process, self)
 
-        for obj in self.new:
-            mapper = sqlalchemy.mapper.object_mapper(obj)
-            mapper.put(obj)
-        self.new.clear()
-        self.dirty.clear()
-        for item in self.modified_lists:
-            item = item.list
-            item.clear_history()
-        self.modified_lists.clear()
+        try:
+            # TODO: db tranasction boundary
+            for task in mapperlist:
+                obj_list = task.objects
+                task.mapper.save_obj(obj_list, self)
+                for dep in task.dependencies:
+                    (processor, stuff_to_process) = dep
+                    processor.process_dependencies(stuff_to_process, self)
+        except:
+            raise
+            
+        for obj in self.commit_context.saved_objects:
+            if self.new.contains(obj):
+                mapper = sqlalchemy.mapper.object_mapper(obj)
+                mapper.put(obj)
+                del self.new[obj]
+            elif self.dirty.contains(obj):
+                del self.dirty[obj]
+
+        for obj in self.commit_context.saved_lists:
+            del self.modified_lists[obj]
 
-        self.tasks.clear()
-        self.dependencies.clear()
+        self.commit_context = None
         # TODO: deleted stuff
+        
+        if self.parent:
+            uow.set(self.parent)
+            
+    def register_saved_object(self, obj):
+        self.commit_context.saved_objects.append(obj)
 
+    def register_saved_list(self, listobj):
+        self.commit_context.saved_lists.append(listobj)
+        
     # TODO: better interface for tasks with no object save, or multiple dependencies
     def register_dependency(self, mapper, dependency, processor, stuff_to_process):
-        self.dependencies[(mapper, dependency)] = True
-        task = self.get_task_by_mapper(mapper)
+        self.commit_context.dependencies[(mapper, dependency)] = True
+        task = self.commit_context.get_task_by_mapper(mapper)
         if processor is not None:
             task.dependencies.append((processor, stuff_to_process))
         
+            
+class UOWTransaction(object):
+    def __init__(self):
+        self.dependencies = {}
+        self.tasks = {}
+        self.saved_objects = util.HashSet()
+        self.saved_lists = util.HashSet()
+
+    def append_task(self, obj):
+        import sqlalchemy.mapper
+        mapper = sqlalchemy.mapper.object_mapper(obj)
+        task = self.get_task_by_mapper(mapper)
+        task.objects.append(obj)
+
     def get_task_by_mapper(self, mapper):
         try:
             return self.tasks[mapper]
         except KeyError:
             return self.tasks.setdefault(mapper, UOWTask(mapper))
-    
+        
 class UOWTask(object):
     def __init__(self, mapper):
         self.mapper = mapper