]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
unitofwork more Set oriented now
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 27 May 2006 05:45:21 +0000 (05:45 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 27 May 2006 05:45:21 +0000 (05:45 +0000)
MapperProperty now has "localparent" and "parent" attributes, which in the case of
inheritance represent the mapper the property is attached to, and the original mapper it was created on.
the unitofwork now keeps the dependency processors derived from those properties unique so inheritance
structures dont register redundant dependency processors.

lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/unitofwork.py

index 9b636f244e54f5f2b2b233f756367a5a47cd0fd6..3e7e941f9691d1a2739d19171eba5e9b91e533d1 100644 (file)
@@ -932,6 +932,7 @@ class MapperProperty(object):
         """called when the MapperProperty is first attached to a new parent Mapper."""
         self.key = key
         self.parent = parent
+        self.localparent = parent
         self.do_init(key, parent)
     def adapt(self, newparent):
         """adapts this MapperProperty to a new parent, assuming the new parent is an inheriting
@@ -939,7 +940,8 @@ class MapperProperty(object):
         False if this MapperProperty cannot be adapted to the new parent (the case for this is,
         the parent mapper has a polymorphic select, and this property represents a column that is not
         represented in the new mapper's mapped table)"""
-        self.parent = newparent
+        #self.parent = newparent
+        self.localparent = newparent
         return True
     def do_init(self, key, parent):
         """template method for subclasses"""
index 00e8e2a67eb37abf539403bfb1845c81c16126c5..59b85c0578b3841e1e50bb4c2fc04cff285e1f9f 100644 (file)
@@ -38,8 +38,6 @@ class ColumnProperty(mapper.MapperProperty):
             else:
                 statement.append_column(c)
     def do_init(self, key, parent):
-        self.key = key
-        self.parent = parent
         # establish a SmartProperty property manager on the object for this key
         if parent._is_primary_mapper():
             #print "regiser col on class %s key %s" % (parent.class_.__name__, key)
@@ -60,14 +58,12 @@ class DeferredColumnProperty(ColumnProperty):
     def copy(self):
         return DeferredColumnProperty(*self.columns)
     def do_init(self, key, parent):
-        self.key = key
-        self.parent = parent
         # establish a SmartProperty property manager on the object for this key, 
         # containing a callable to load in the attribute
         if self.is_primary():
             sessionlib.global_attributes.register_attribute(parent.class_, key, uselist=False, callable_=lambda i:self.setup_loader(i))
     def setup_loader(self, instance):
-        if not self.parent.is_assigned(instance):
+        if not self.localparent.is_assigned(instance):
             return mapper.object_mapper(instance).props[self.key].setup_loader(instance)
         def lazyload():
             session = sessionlib.object_session(instance)
@@ -85,7 +81,7 @@ class DeferredColumnProperty(ColumnProperty):
             
             try:
                 if self.group is not None:
-                    groupcols = [p for p in self.parent.props.values() if isinstance(p, DeferredColumnProperty) and p.group==self.group]
+                    groupcols = [p for p in self.localparent.props.values() if isinstance(p, DeferredColumnProperty) and p.group==self.group]
                     row = connection.execute(sql.select([g.columns[0] for g in groupcols], clause, use_labels=True), None).fetchone()
                     for prop in groupcols:
                         if prop is self:
@@ -193,8 +189,6 @@ class PropertyLoader(mapper.MapperProperty):
                 self.association = mapper.class_mapper(self.association)
         
         self.target = self.mapper.mapped_table
-        self.key = key
-        self.parent = parent
 
         if self.secondaryjoin is not None and self.secondary is None:
             raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument")
@@ -253,7 +247,6 @@ class PropertyLoader(mapper.MapperProperty):
         
     def _get_direction(self):
         """determines our 'direction', i.e. do we represent one to many, many to many, etc."""
-        #print self.key, repr(self.parent.mapped_table.name), repr(self.parent.primarytable.name), repr(self.foreignkey.table.name), repr(self.target), repr(self.foreigntable.name)
         
         if self.secondaryjoin is not None:
             return sync.MANYTOMANY
@@ -323,7 +316,6 @@ class PropertyLoader(mapper.MapperProperty):
 
         self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction)
         if self.direction == sync.MANYTOMANY:
-            #print "COMPILING p/c", self.parent, self.mapper
             self.syncrules.compile(self.primaryjoin, parent_tables, [self.secondary], False)
             self.syncrules.compile(self.secondaryjoin, target_tables, [self.secondary], True)
         else:
@@ -342,7 +334,7 @@ class LazyLoader(PropertyLoader):
         sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, callable_=lambda i: self.setup_loader(i), extension=self.attributeext, cascade=self.cascade, trackparent=True)
 
     def setup_loader(self, instance):
-        if not self.parent.is_assigned(instance):
+        if not self.localparent.is_assigned(instance):
             return mapper.object_mapper(instance).props[self.key].setup_loader(instance)
         def lazyload():
             params = {}
@@ -476,7 +468,7 @@ class EagerLoader(LazyLoader):
             if isinstance(prop, EagerLoader):
                 eagerprops.append(prop)
         if len(eagerprops):
-            recursion_stack[self.parent.mapped_table] = True
+            recursion_stack[self.localparent.mapped_table] = True
             self.mapper = self.mapper.copy()
             try:
                 for prop in eagerprops:
@@ -495,7 +487,7 @@ class EagerLoader(LazyLoader):
                     p.eagerprimary.accept_visitor(self.aliasizer)
                     #print "new eagertqarget", p.eagertarget.name, (p.secondary and p.secondary.name or "none"), p.parent.mapped_table.name
             finally:
-                del recursion_stack[self.parent.mapped_table]
+                del recursion_stack[self.localparent.mapped_table]
 
         self._row_decorator = self._create_decorator_row()
         
@@ -522,7 +514,7 @@ class EagerLoader(LazyLoader):
         if hasattr(statement, '_outerjoin'):
             towrap = statement._outerjoin
         else:
-            towrap = self.parent.mapped_table
+            towrap = self.localparent.mapped_table
 
  #       print "hello, towrap", str(towrap)
         if self.secondaryjoin is not None:
index 710326473986152a6cc9f00fe95cf8a92d30ba46..d7e04b31e3a89c1fd9fc2551e2945c59f6ef35d1 100644 (file)
@@ -21,7 +21,7 @@ from sqlalchemy.exceptions import *
 import StringIO
 import weakref
 import topological
-from sets import *
+import sets
 
 # a global indicating if all flush() operations should have their plan
 # printed to standard output.  also can be affected by creating an engine
@@ -218,19 +218,19 @@ class UnitOfWork(object):
         flush_context = UOWTransaction(self, session)
 
         if objects is not None:
-            objset = util.HashSet(iter=objects)
+            objset = sets.Set(objects)
         else:
             objset = None
 
         for obj in [n for n in self.new] + [d for d in self.dirty]:
-            if objset is not None and not objset.contains(obj):
+            if objset is not None and not obj in objset:
                 continue
             if self.deleted.contains(obj):
                 continue
             flush_context.register_object(obj)
             
         for obj in self.deleted:
-            if objset is not None and not objset.contains(obj):
+            if objset is not None and not obj in objset:
                 continue
             flush_context.register_object(obj, isdelete=True)
         
@@ -265,7 +265,7 @@ class UOWTransaction(object):
         self.uow = uow
         self.session = session
         #  unique list of all the mappers we come across
-        self.mappers = util.HashSet()
+        self.mappers = sets.Set()
         self.dependencies = {}
         self.tasks = {}
         self.__modified = False
@@ -287,7 +287,7 @@ class UOWTransaction(object):
         self.uow._validate_obj(obj)
             
         mapper = object_mapper(obj)
-        self.mappers.append(mapper)
+        self.mappers.add(mapper)
         task = self.get_task_by_mapper(mapper)
 
         if postupdate:
@@ -356,7 +356,7 @@ class UOWTransaction(object):
         task = self.get_task_by_mapper(mapper)
         targettask = self.get_task_by_mapper(mapperfrom)
         up = UOWDependencyProcessor(processor, targettask)
-        task.dependencies.append(up)
+        task.dependencies.add(up)
         self._mark_modified()
 
     def execute(self, echo=False):
@@ -368,7 +368,7 @@ class UOWTransaction(object):
         while True:
             ret = False
             for task in self.tasks.values():
-                for up in task.dependencies:
+                for up in list(task.dependencies):
                     if up.preexecute(self):
                         ret = True
             if not ret:
@@ -423,9 +423,9 @@ class UOWTransaction(object):
                     task.childtasks.append(t)
             return task
             
-        mappers = util.HashSet()
+        mappers = sets.Set()
         for task in self.tasks.values():
-            mappers.append(task.mapper)
+            mappers.add(task.mapper)
 
         def inheriting_tasks(task):
             if task.mapper not in mappers:
@@ -436,12 +436,12 @@ class UOWTransaction(object):
                     continue
                 inheriting_tasks(inherit_task)
                 task.inheriting_tasks.append(inherit_task)
-                del mappers[mapper]
+                mappers.remove(mapper)
                 
         for task in self.tasks.values():
             inheriting_tasks(task)
                 
-        head = DependencySorter(self.dependencies, mappers).sort(allow_all_cycles=True)
+        head = DependencySorter(self.dependencies, list(mappers)).sort(allow_all_cycles=True)
         #print str(head)
         task = sort_hier(head)
         return task
@@ -496,6 +496,11 @@ class UOWDependencyProcessor(object):
     def __init__(self, processor, targettask):
         self.processor = processor
         self.targettask = targettask
+    
+    def __eq__(self, other):
+        return other.processor is self.processor and other.targettask is self.targettask
+    def __hash__(self):
+        return hash((self.processor, self.targettask))
         
     def preexecute(self, trans):
         """traverses all objects handled by this dependency processor and locates additional objects which should be 
@@ -537,14 +542,42 @@ class UOWTask(object):
     def __init__(self, uowtransaction, mapper):
         if uowtransaction is not None:
             uowtransaction.tasks[mapper] = self
+
+        # the transaction owning this UOWTask
         self.uowtransaction = uowtransaction
+        
+        # the Mapper which this UOWTask corresponds to
         self.mapper = mapper
+        
+        # a dictionary mapping object instances to a corresponding UOWTaskElement.
+        # Each UOWTaskElement represents one instance which is to be saved or 
+        # deleted by this UOWTask's Mapper.
+        # in the case of the row-based "circular sort", the UOWTaskElement may
+        # also reference further UOWTasks which are dependent on that UOWTaskElement.
         self.objects = util.OrderedDict()
-        self.dependencies = []
-        self.cyclical_dependencies = []
-        self.circular = None
+        
+        # a list of UOWDependencyProcessors which are executed after saves and
+        # before deletes, to synchronize data to dependent objects
+        self.dependencies = sets.Set()
+
+        # a list of UOWTasks that are dependent on this UOWTask, which 
+        # are to be executed after this UOWTask performs saves and post-save
+        # dependency processing, and before pre-delete processing and deletes
         self.childtasks = []
+        
+        # a list of UOWTasks that correspond to Mappers which are inheriting
+        # mappers of this UOWTask's Mapper
         self.inheriting_tasks = []
+
+        # whether this UOWTask is circular, meaning it holds a second
+        # UOWTask that contains a special row-based dependency structure
+        self.circular = None
+
+        # a list of UOWDependencyProcessors are derived from the main
+        # set of dependencies, referencing sub-UOWTasks attached to this
+        # one which represent portions of the total list of objects.
+        # this is used for the row-based "circular sort"
+        self.cyclical_dependencies = sets.Set()
         
     def is_empty(self):
         return len(self.objects) == 0 and len(self.dependencies) == 0 and len(self.childtasks) == 0
@@ -569,8 +602,6 @@ class UOWTask(object):
             rec.childtasks.append(childtask)
         if isdelete:
             rec.isdelete = True
-        #if not childtask:
-        #    rec.preprocessed = False
         return retval
     
     def append_postupdate(self, obj):
@@ -666,16 +697,13 @@ class UOWTask(object):
         of its object list contain dependencies on each other.
         
         this is not the normal case; this logic only kicks in when something like 
-        a hierarchical tree is being represented.
-
-        """
-
+        a hierarchical tree is being represented."""
         allobjects = []
         for task in cycles:
             allobjects += task.objects.keys()
         tuples = []
         
-        cycles = Set(cycles)
+        cycles = sets.Set(cycles)
         
         #print "BEGIN CIRC SORT-------"
         #print "PRE-CIRC:"
@@ -733,11 +761,6 @@ class UOWTask(object):
                     # the task corresponding to the processor's objects
                     childtask = trans.get_task_by_mapper(processor.mapper)
                     
-#                    if isdelete:
-#                        childlist = childlist.unchanged_items() + childlist.deleted_items()
-#                    else:
-#                        childlist = childlist.added_items() 
-                        
                     childlist = childlist.added_items() + childlist.unchanged_items() + childlist.deleted_items()
                         
                     for o in childlist:
@@ -785,14 +808,12 @@ class UOWTask(object):
                 else:
                     t.append(node.item, original_task.objects[node.item].listonly, isdelete=original_task.objects[node.item].isdelete)
                     parenttask.append(None, listonly=False, isdelete=original_task.objects[node.item].isdelete, childtask=t)
-            #else:
-            #    parenttask.append(None, listonly=False, isdelete=original_task.objects[node.item].isdelete, childtask=t)
             if dependencies.has_key(node.item):
                 for depprocessor, deptask in dependencies[node.item].iteritems():
                     if can_add_to_parent:
-                        parenttask.cyclical_dependencies.append(depprocessor.branch(deptask))
+                        parenttask.cyclical_dependencies.add(depprocessor.branch(deptask))
                     else:
-                        t.cyclical_dependencies.append(depprocessor.branch(deptask))
+                        t.cyclical_dependencies.add(depprocessor.branch(deptask))
             return t
 
         # this is the new "circular" UOWTask which will execute in place of "self"
@@ -800,7 +821,7 @@ class UOWTask(object):
 
         # stick the non-circular dependencies and child tasks onto the new
         # circular UOWTask
-        t.dependencies += [d for d in extradeplist]
+        [t.dependencies.add(d) for d in extradeplist]
         t.childtasks = self.childtasks
         make_task_tree(head, t)
         #print t.dump()