]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
further refinement of the polymorphic UOWTask idea. circular dependency sort has...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 30 May 2006 17:11:14 +0000 (17:11 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 30 May 2006 17:11:14 +0000 (17:11 +0000)
of any inheritance chain, as it now takes part in pretty much any two dependent classes who share the same inherited parent.

lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/unitofwork.py
test/inheritance2.py

index 195aff2b08b11dbe11b28e0ee0fbadb9a300d747..8002998fb378e9721b6efa391120a068b53e776c 100644 (file)
@@ -91,7 +91,7 @@ class OneToManyDP(DependencyProcessor):
             uowcommit.register_dependency(self.parent, self.mapper)
             uowcommit.register_processor(self.parent, self, self.parent)
     def process_dependencies(self, task, deplist, uowcommit, delete = False):
-        #print self.mapper.table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
+        #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
         if delete:
             # head object is being deleted, and we manage its list of child objects
             # the child objects have to have their foreign key to the parent set to NULL
@@ -121,7 +121,7 @@ class OneToManyDP(DependencyProcessor):
                         self._synchronize(obj, child, None, True)
 
     def preprocess_dependencies(self, task, deplist, uowcommit, delete = False):
-        #print self.mapper.table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
+        #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " preprocess_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
 
         if delete:
             # head object is being deleted, and we manage its list of child objects
index 3927b949e851a05c152f0f211aaf2e2e164e4c0a..d03840cf3af5813c0a53e8f205d5b236cb95e3a9 100644 (file)
@@ -324,6 +324,8 @@ class UOWTransaction(object):
         this method returns or creates the single per-transaction instance of
         UOWTask that exists for that mapper."""
         try:
+            if isinstance(mapper, UOWTask):
+                raise "wha"
             return self.tasks[mapper]
         except KeyError:
             if dontcreate:
@@ -363,6 +365,14 @@ class UOWTransaction(object):
         self._mark_modified()
 
     def execute(self, echo=False):
+        #print "\n------------------\nEXECUTE"
+        #for task in self.tasks.values():
+        #    print "\nTASK:", task
+        #    for obj in task.objects:
+        #        print "TASK OBJ:", obj
+        #    for elem in task.get_elements(polymorphic=True):
+        #        print "POLYMORPHIC TASK OBJ:", elem.obj
+        #print "-----------------------------"
         # pre-execute dependency processors.  this process may 
         # result in new tasks, objects and/or dependency processors being added,
         # particularly with 'delete-orphan' cascade rules.
@@ -426,30 +436,24 @@ class UOWTransaction(object):
                     task.childtasks.append(t)
             return task
             
-        mappers = sets.Set()
-        for task in self.tasks.values():
-            mappers.add(task.mapper)
-
-        def inheriting_tasks(task):
-            if task.mapper not in mappers:
-                return
-            for mapper in task.mapper._inheriting_mappers:
-                inherit_task = self.tasks.get(mapper, None)
-                if inherit_task is None:
-                    continue
-                inheriting_tasks(inherit_task)
-                task.inheriting_tasks.append(inherit_task)
-                mappers.remove(mapper)
-                
-        for task in self.tasks.values():
-            inheriting_tasks(task)
-                
+        mappers = self._get_noninheriting_mappers()
         head = DependencySorter(self.dependencies, list(mappers)).sort(allow_all_cycles=True)
+        #print "-------------------------"
         #print str(head)
+        #print "---------------------------"
         task = sort_hier(head)
         return task
 
-
+    def _get_noninheriting_mappers(self):
+        """returns a list of UOWTasks whose mappers are not inheriting from the mapper of another UOWTask.
+        i.e., this returns the root UOWTasks for all the inheritance hierarchies represented in this UOWTransaction."""
+        mappers = sets.Set()
+        for task in self.tasks.values():
+            base = task.mapper.base_mapper()
+            mappers.add(base)
+        return mappers
+        
+        
 class UOWTaskElement(object):
     """an element within a UOWTask.  corresponds to a single object instance
     to be saved, deleted, or just part of the transaction as a placeholder for 
@@ -499,7 +503,6 @@ class UOWDependencyProcessor(object):
     def __init__(self, processor, targettask):
         self.processor = processor
         self.targettask = targettask
-    
     def __eq__(self, other):
         return other.processor is self.processor and other.targettask is self.targettask
     def __hash__(self):
@@ -512,7 +515,7 @@ class UOWDependencyProcessor(object):
         def getobj(elem):
             elem.mark_preprocessed(self)
             return elem.obj
-            
+        
         ret = False
         elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.obj is not None and not elem.is_preprocessed(self)]
         if len(elements):
@@ -542,10 +545,10 @@ class UOWDependencyProcessor(object):
 
 class UOWTask(object):
     """represents the full list of objects that are to be saved/deleted by a specific Mapper."""
-    def __init__(self, uowtransaction, mapper):
-        if uowtransaction is not None:
+    def __init__(self, uowtransaction, mapper, circular_parent=None):
+        if not circular_parent:
             uowtransaction.tasks[mapper] = self
-
+        
         # the transaction owning this UOWTask
         self.uowtransaction = uowtransaction
         
@@ -570,12 +573,16 @@ class UOWTask(object):
         
         # a list of UOWTasks that correspond to Mappers which are inheriting
         # mappers of this UOWTask's Mapper
-        self.inheriting_tasks = []
+        #self.inheriting_tasks = sets.Set()
 
         # whether this UOWTask is circular, meaning it holds a second
-        # UOWTask that contains a special row-based dependency structure
+        # UOWTask that contains a special row-based dependency structure.
         self.circular = None
 
+        # for a task thats part of that row-based dependency structure, points
+        # back to the "public facing" task.
+        self.circular_parent = circular_parent
+        
         # a list of UOWDependencyProcessors are derived from the main
         # set of dependencies, referencing sub-UOWTasks attached to this
         # one which represent portions of the total list of objects.
@@ -678,6 +685,27 @@ class UOWTask(object):
         self._execute_per_element_childtasks(trans, True)
         self._delete_objects(trans)
 
+    def _inheriting_tasks(self):
+        """returns an iterator of UOWTasks whos mappers inherit from this UOWTask's mapper.  Only
+        goes one level deep; i.e. for each UOWTask returned, you would call _inheriting_tasks on those
+        to get their inheriting tasks.  For a multilevel-inheritance chain, i.e. A->B->C, and there are
+        UOWTasks for A and C but not B, C will be returned when this method is called on A, otherwise B."""
+        if self.circular_parent is not None:
+            return
+        def _tasks_by_mapper(mapper):
+            for m in mapper._inheriting_mappers:
+                inherit_task = self.uowtransaction.tasks.get(m, None)
+                if inherit_task is not None:
+                    yield inherit_task
+                    #for t in inherit_task._inheriting_tasks():
+                    #    yield t
+                else:
+                    for t in _tasks_by_mapper(m):
+                        yield t
+        for t in _tasks_by_mapper(self.mapper):
+            yield t
+    inheriting_tasks = property(_inheriting_tasks)
+    
     def polymorphic_tasks(self):
         yield self
         for task in self.inheriting_tasks:
@@ -696,10 +724,10 @@ class UOWTask(object):
     def get_elements(self, polymorphic=False):
         for rec in self.objects.values():
             yield rec
-            if polymorphic:
-                for task in self.inheriting_tasks:
-                    for rec in task.get_elements(polymorphic=True):
-                        yield rec
+        if polymorphic:
+            for task in self.inheriting_tasks:
+                for rec in task.get_elements(polymorphic=True):
+                    yield rec
     
     polymorphic_tosave_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=True) if not rec.isdelete])
     polymorphic_todelete_elements = property(lambda self: [rec for rec in self.get_elements(polymorphic=True) if rec.isdelete])
@@ -746,13 +774,14 @@ class UOWTask(object):
             try:
                 l = dp[depprocessor]
             except KeyError:
-                l = UOWTask(None, depprocessor.targettask.mapper)
+                l = UOWTask(self.uowtransaction, depprocessor.targettask.mapper, circular_parent=self)
                 dp[depprocessor] = l
             return l
 
         def dependency_in_cycles(dep):
-            proctask = trans.get_task_by_mapper(dep.processor.mapper.primary_mapper(), True)
-            return dep.targettask in cycles and (proctask is not None and proctask in cycles)
+            proctask = trans.get_task_by_mapper(dep.processor.mapper.primary_mapper().base_mapper(), True)
+            targettask = trans.get_task_by_mapper(dep.targettask.mapper.base_mapper(), True)
+            return targettask in cycles and (proctask is not None and proctask in cycles)
             
         # organize all original UOWDependencyProcessors by their target task
         deps_by_targettask = {}
@@ -775,8 +804,9 @@ class UOWTask(object):
                     
                     for dep in deps_by_targettask.get(task, []):
                         # is this dependency involved in one of the cycles ?
-                        #print "DEP iterate", dep.processor.key    
+                        #print "DEP iterate", dep.processor.key, dep.processor.parent, dep.processor.mapper
                         if not dependency_in_cycles(dep):
+                            #print "NOT IN CYCLE"
                             continue
                         #print "DEP", dep.processor.key    
                         (processor, targettask) = (dep.processor, dep.targettask)
@@ -820,7 +850,7 @@ class UOWTask(object):
                 return hierarchical_tasks[obj]
             except KeyError:
                 originating_task = object_to_original_task[obj]
-                return hierarchical_tasks.setdefault(obj, UOWTask(None, originating_task.mapper))
+                return hierarchical_tasks.setdefault(obj, UOWTask(self.uowtransaction, originating_task.mapper, circular_parent=self))
 
         def make_task_tree(node, parenttask):
             """takes a dependency-sorted tree of objects and creates a tree of UOWTasks"""
@@ -852,7 +882,7 @@ class UOWTask(object):
             return t
 
         # this is the new "circular" UOWTask which will execute in place of "self"
-        t = UOWTask(None, self.mapper)
+        t = UOWTask(self.uowtransaction, self.mapper, circular_parent=self)
 
         # stick the non-circular dependencies and child tasks onto the new
         # circular UOWTask
index f01189cb2ff308cbc7df20615cfef28c9543e559..c6b9f01985cd224c2bc3ac28113f6d27d7733bb5 100644 (file)
@@ -227,6 +227,57 @@ class InheritTest(testbase.AssertMixin):
         print orig
         print new
         assert orig == new  == '<Assembly a1> specification=[<SpecLine 1.0 <Detail d1>>] documents=[<Document doc1>, <RasterDocument doc2>]'
+
+    def testfour(self):
+        """this tests the RasterDocument being attached to the Assembly, but *not* the Document.  this means only
+        a "sub-class" task, i.e. corresponding to an inheriting mapper but not the base mapper, is created. """
+        product_mapper = mapper(Product, products_table,
+            polymorphic_on=products_table.c.product_type,
+            polymorphic_identity='product')
+        detail_mapper = mapper(Detail, inherits=product_mapper,
+            polymorphic_identity='detail')
+        assembly_mapper = mapper(Assembly, inherits=product_mapper,
+            polymorphic_identity='assembly')
+
+        document_mapper = mapper(Document, documents_table,
+            polymorphic_on=documents_table.c.document_type,
+            polymorphic_identity='document',
+            properties=dict(
+                name=documents_table.c.name,
+                data=deferred(documents_table.c.data),
+                product=relation(Product, lazy=True, backref='documents'),
+                ),
+            )
+        raster_document_mapper = mapper(RasterDocument, inherits=document_mapper,
+            polymorphic_identity='raster_document')
+
+        product_mapper.add_property('documents',
+            relation(Document, lazy=True,
+                backref='product', cascade='all, delete-orphan'),
+            )
+
+        session = create_session(echo_uow=False)
+
+        a1 = Assembly(name='a1')
+        a1.documents.append(RasterDocument('doc2'))
+        session.save(a1)
+        orig = repr(a1)
+        session.flush()
+        session.clear()
+
+        a1 = session.query(Product).get_by(name='a1')
+        new = repr(a1)
+        print orig
+        print new
+        assert orig == new  == '<Assembly a1> specification=None documents=[<RasterDocument doc2>]'
+
+        del a1.documents[0]
+        session.save(a1)
+        session.flush()
+        session.clear()
+
+        a1 = session.query(Product).get_by(name='a1')
+        assert len(session.query(Document).select()) == 0
         
 if __name__ == "__main__":    
     testbase.main()