git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- cleanup, factoring, had some heisenbugs. more test coverage
author Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Apr 2010 05:23:54 +0000 (01:23 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Apr 2010 05:23:54 +0000 (01:23 -0400)
 will be needed overall as missing dependency rules lead
to subtle bugs pretty easily

lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/test_unitofwork.py
test/orm/test_unitofworkv2.py

index ecea094fd10ece9dba8b282e1fcea5e6fbe0610c..aef297ee6c1ac27102660f36f8bdd89c872f91e8 100644 (file)
@@ -74,11 +74,23 @@ class DependencyProcessor(object):
         after_save = unitofwork.ProcessAll(uow, self, False, True)
         before_delete = unitofwork.ProcessAll(uow, self, True, True)
 
-        parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.primary_mapper().base_mapper)
-        child_saves = unitofwork.SaveUpdateAll(uow, self.mapper.primary_mapper().base_mapper)
-
-        parent_deletes = unitofwork.DeleteAll(uow, self.parent.primary_mapper().base_mapper)
-        child_deletes = unitofwork.DeleteAll(uow, self.mapper.primary_mapper().base_mapper)
+        parent_saves = unitofwork.SaveUpdateAll(
+                                        uow, 
+                                        self.parent.primary_base_mapper
+                                        )
+        child_saves = unitofwork.SaveUpdateAll(
+                                        uow, 
+                                        self.mapper.primary_base_mapper
+                                        )
+
+        parent_deletes = unitofwork.DeleteAll(
+                                        uow, 
+                                        self.parent.primary_base_mapper
+                                        )
+        child_deletes = unitofwork.DeleteAll(
+                                        uow, 
+                                        self.mapper.primary_base_mapper
+                                        )
         
         self.per_property_dependencies(uow, 
                                         parent_saves, 
@@ -109,8 +121,11 @@ class DependencyProcessor(object):
         after_save.disabled = True
 
         # check if the "child" side is part of the cycle
-        child_saves = unitofwork.SaveUpdateAll(uow, self.mapper.base_mapper)
-        child_deletes = unitofwork.DeleteAll(uow, self.mapper.base_mapper)
+        
+        parent_base_mapper = self.parent.primary_base_mapper
+        child_base_mapper = self.mapper.primary_base_mapper
+        child_saves = unitofwork.SaveUpdateAll(uow, child_base_mapper)
+        child_deletes = unitofwork.DeleteAll(uow, child_base_mapper)
         
         if child_saves not in uow.cycles:
             # based on the current dependencies we use, the saves/
@@ -130,12 +145,16 @@ class DependencyProcessor(object):
         
         # check if the "parent" side is part of the cycle
         if not isdelete:
-            parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper)
+            parent_saves = unitofwork.SaveUpdateAll(
+                                                uow, 
+                                                self.parent.base_mapper)
             parent_deletes = before_delete = None
             if parent_saves in uow.cycles:
                 parent_in_cycles = True
         else:
-            parent_deletes = unitofwork.DeleteAll(uow, self.parent.base_mapper)
+            parent_deletes = unitofwork.DeleteAll(
+                                                uow, 
+                                                self.parent.base_mapper)
             parent_saves = after_save = None
             if parent_deletes in uow.cycles:
                 parent_in_cycles = True
@@ -148,17 +167,26 @@ class DependencyProcessor(object):
             if isdelete:
                 before_delete = unitofwork.ProcessState(uow, self, True, state)
                 if parent_in_cycles:
-                    parent_deletes = unitofwork.DeleteState(uow, state)
+                    parent_deletes = unitofwork.DeleteState(
+                                                uow, 
+                                                state, 
+                                                parent_base_mapper)
             else:
                 after_save = unitofwork.ProcessState(uow, self, False, state)
                 if parent_in_cycles:
-                    parent_saves = unitofwork.SaveUpdateState(uow, state)
+                    parent_saves = unitofwork.SaveUpdateState(
+                                                uow, 
+                                                state, 
+                                                parent_base_mapper)
                 
             if child_in_cycles:
                 # locate each child state associated with the parent action,
                 # create dependencies for each.
                 child_actions = []
-                sum_ = uow.get_attribute_history(state, self.key, passive=True).sum()
+                sum_ = uow.get_attribute_history(
+                                                state, 
+                                                self.key, 
+                                                passive=True).sum()
                 if not sum_:
                     continue
                 for child_state in sum_:
@@ -169,9 +197,17 @@ class DependencyProcessor(object):
                     else:
                         (deleted, listonly) = uow.states[child_state]
                         if deleted:
-                            child_action = (unitofwork.DeleteState(uow, child_state), True)
+                            child_action = (
+                                            unitofwork.DeleteState(
+                                                        uow, child_state, 
+                                                        child_base_mapper), 
+                                            True)
                         else:
-                            child_action = (unitofwork.SaveUpdateState(uow, child_state), False)
+                            child_action = (
+                                            unitofwork.SaveUpdateState(
+                                                        uow, child_state, 
+                                                        child_base_mapper), 
+                                            False)
                     child_actions.append(child_action)
                     
             # establish dependencies between our possibly per-state
@@ -204,25 +240,28 @@ class DependencyProcessor(object):
             not self.mapper._canload(state, allow_subtypes=not self.enable_typechecks):
             if self.mapper._canload(state, allow_subtypes=True):
                 raise exc.FlushError(
-                    "Attempting to flush an item of type %s on collection '%s', "
-                    "which is not the expected type %s.  Configure mapper '%s' to "
-                    "load this subtype polymorphically, or set "
-                    "enable_typechecks=False to allow subtypes. "
-                    "Mismatched typeloading may cause bi-directional relationships "
-                    "(backrefs) to not function properly." % 
-                    (state.class_, self.prop, self.mapper.class_, self.mapper))
+                "Attempting to flush an item of type %s on collection '%s', "
+                "which is not the expected type %s.  Configure mapper '%s' to "
+                "load this subtype polymorphically, or set "
+                "enable_typechecks=False to allow subtypes. "
+                "Mismatched typeloading may cause bi-directional relationships "
+                "(backrefs) to not function properly." % 
+                (state.class_, self.prop, self.mapper.class_, self.mapper))
             else:
                 raise exc.FlushError(
-                    "Attempting to flush an item of type %s on collection '%s', "
-                    "whose mapper does not inherit from that of %s." % 
-                    (state.class_, self.prop, self.mapper.class_))
+                "Attempting to flush an item of type %s on collection '%s', "
+                "whose mapper does not inherit from that of %s." % 
+                (state.class_, self.prop, self.mapper.class_))
             
-    def _synchronize(self, state, 
-                            child, associationrow, 
-                            clearkeys, uowcommit):
+    def _synchronize(self, state, child, associationrow,
+                                            clearkeys, uowcommit):
         raise NotImplementedError()
 
     def _check_reverse(self, uow):
+        """return True if a comparable dependency processor has
+        already been set up on the "reverse" side of a relationship.
+        
+        """
         for p in self.prop._reverse_property:
             if not p.viewonly and p._dependency_processor and \
                 (unitofwork.ProcessAll, 
@@ -365,11 +404,15 @@ class OneToManyDP(DependencyProcessor):
             if self._pks_changed(uowcommit, state):
                 if not history:
                     history = uowcommit.get_attribute_history(
-                                        state, self.key, passive=self.passive_updates)
+                                        state, self.key, 
+                                        passive=self.passive_updates)
                 if history:
                     for child in history.unchanged:
                         if child is not None:
-                            uowcommit.register_object(child, False, self.passive_updates)
+                            uowcommit.register_object(
+                                        child, 
+                                        False, 
+                                        self.passive_updates)
     
     def process_deletes(self, uowcommit, states):
         # head object is being deleted, and we manage its list of 
@@ -385,14 +428,27 @@ class OneToManyDP(DependencyProcessor):
                                             passive=self.passive_deletes)
                 if history:
                     for child in history.deleted:
-                        if child is not None and self.hasparent(child) is False:
-                            self._synchronize(state, child, None, True, uowcommit)
-                            self._conditional_post_update(child, uowcommit, [state])
+                        if child is not None and \
+                            self.hasparent(child) is False:
+                            self._synchronize(
+                                            state, 
+                                            child, 
+                                            None, True, uowcommit)
+                            self._conditional_post_update(
+                                            child, 
+                                            uowcommit, 
+                                            [state])
                     if self.post_update or not self.cascade.delete:
                         for child in history.unchanged:
                             if child is not None:
-                                self._synchronize(state, child, None, True, uowcommit)
-                                self._conditional_post_update(child, uowcommit, [state])
+                                self._synchronize(
+                                            state, 
+                                            child, 
+                                            None, True, uowcommit)
+                                self._conditional_post_update(
+                                            child, 
+                                            uowcommit, 
+                                            [state])
     
     def process_saves(self, uowcommit, states):
         for state in states:
@@ -401,20 +457,26 @@ class OneToManyDP(DependencyProcessor):
                 for child in history.added:
                     self._synchronize(state, child, None, False, uowcommit)
                     if child is not None:
-                        self._conditional_post_update(child, uowcommit, [state])
+                        self._conditional_post_update(
+                                            child, 
+                                            uowcommit, 
+                                            [state])
 
                 for child in history.deleted:
-                    if not self.cascade.delete_orphan and not self.hasparent(child):
+                    if not self.cascade.delete_orphan and \
+                        not self.hasparent(child):
                         self._synchronize(state, child, None, True, uowcommit)
 
                 if self._pks_changed(uowcommit, state):
                     for child in history.unchanged:
                         self._synchronize(state, child, None, False, uowcommit)
         
-    def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
+    def _synchronize(self, state, child, associationrow, 
+                                                clearkeys, uowcommit):
         source = state
         dest = child
-        if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
+        if dest is None or \
+                (not self.post_update and uowcommit.is_deleted(dest)):
             return
         self._verify_canload(child)
         if clearkeys:
@@ -517,7 +579,8 @@ class ManyToOneDP(DependencyProcessor):
                         if child is None:
                             continue
                         uowcommit.register_object(child, isdelete=True)
-                        for c, m in self.mapper.cascade_iterator('delete', child):
+                        for c, m in self.mapper.cascade_iterator(
+                                                            'delete', child):
                             uowcommit.register_object(
                                 attributes.instance_state(c), isdelete=True)
 
@@ -533,7 +596,8 @@ class ManyToOneDP(DependencyProcessor):
                     for child in history.deleted:
                         if self.hasparent(child) is False:
                             uowcommit.register_object(child, isdelete=True)
-                            for c, m in self.mapper.cascade_iterator('delete', child):
+                            for c, m in self.mapper.cascade_iterator(
+                                                            'delete', child):
                                 uowcommit.register_object(
                                     attributes.instance_state(c),
                                     isdelete=True)
@@ -552,7 +616,10 @@ class ManyToOneDP(DependencyProcessor):
                                             self.key, 
                                             passive=self.passive_deletes)
                 if history:
-                    self._conditional_post_update(state, uowcommit, history.sum())
+                    self._conditional_post_update(
+                                            state, 
+                                            uowcommit, 
+                                            history.sum())
 
     def process_saves(self, uowcommit, states):
         for state in states:
@@ -561,7 +628,9 @@ class ManyToOneDP(DependencyProcessor):
                 for child in history.added:
                     self._synchronize(state, child, None, False, uowcommit)
                 
-                self._conditional_post_update(state, uowcommit, history.sum())
+                self._conditional_post_update(
+                                            state, 
+                                            uowcommit, history.sum())
 
     def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
         if state is None or (not self.post_update and uowcommit.is_deleted(state)):
@@ -603,7 +672,9 @@ class DetectKeySwitch(DependencyProcessor):
             # so that we avoid ManyToOneDP's registering the object without
             # the listonly flag in its own preprocess stage (results in UPDATE)
             # statements being emitted
-            parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper)
+            parent_saves = unitofwork.SaveUpdateAll(
+                                                uow, 
+                                                self.parent.base_mapper)
             after_save = unitofwork.ProcessAll(uow, self, False, False)
             uow.dependencies.update([
                 (parent_saves, after_save)
@@ -631,8 +702,8 @@ class DetectKeySwitch(DependencyProcessor):
         if switchers:
             # if primary key values have actually changed somewhere, perform
             # a linear search through the UOW in search of a parent.
-            # note that this handler isn't used if the many-to-one relationship
-            # has a backref.
+            # note that this handler isn't used if the many-to-one 
+            # relationship has a backref.
             for state in uowcommit.session.identity_map.all_states():
                 if not issubclass(state.class_, self.parent.class_):
                     continue
@@ -661,7 +732,7 @@ class ManyToManyDP(DependencyProcessor):
         
     def per_property_flush_actions(self, uow):
         if self._check_reverse(uow):
-            return
+            unitofwork.GetDependentObjects(uow, self, False, True)
         else:
             DependencyProcessor.per_property_flush_actions(self, uow)
             
@@ -682,9 +753,18 @@ class ManyToManyDP(DependencyProcessor):
         uow.dependencies.update([
             (parent_saves, after_save),
             (child_saves, after_save),
+            (after_save, child_deletes),
+            
+            # a rowswitch on the parent from deleted to saved 
+            # can make this one occur, as the "save" may remove 
+            # an element from the 
+            # "deleted" list before we have a chance to
+            # process its child rows
+            (before_delete, parent_saves),
             
             (before_delete, parent_deletes),
             (before_delete, child_deletes),
+            (before_delete, child_saves),
         ])
 
     def per_state_dependencies(self, uow, 
@@ -709,13 +789,21 @@ class ManyToManyDP(DependencyProcessor):
         pass
 
     def presort_saves(self, uowcommit, states):
+        if not self.cascade.delete_orphan:
+            return
+            
         for state in states:
-            history = uowcommit.get_attribute_history(state, self.key, passive=True)
+            history = uowcommit.get_attribute_history(
+                                        state, 
+                                        self.key, 
+                                        passive=True)
             if history:
                 for child in history.deleted:
-                    if self.cascade.delete_orphan and self.hasparent(child) is False:
+                    if self.hasparent(child) is False:
                         uowcommit.register_object(child, isdelete=True)
-                        for c, m in self.mapper.cascade_iterator('delete', child):
+                        for c, m in self.mapper.cascade_iterator(
+                                                    'delete', 
+                                                    child):
                             uowcommit.register_object(
                                 attributes.instance_state(c), isdelete=True)
 
@@ -733,7 +821,11 @@ class ManyToManyDP(DependencyProcessor):
                     if child is None:
                         continue
                     associationrow = {}
-                    self._synchronize(state, child, associationrow, False, uowcommit)
+                    self._synchronize(
+                                        state, 
+                                        child, 
+                                        associationrow, 
+                                        False, uowcommit)
                     secondary_delete.append(associationrow)
 
         self._run_crud(uowcommit, secondary_insert, 
@@ -751,29 +843,37 @@ class ManyToManyDP(DependencyProcessor):
                     if child is None:
                         continue
                     associationrow = {}
-                    self._synchronize(state, child, associationrow, False, uowcommit)
+                    self._synchronize(state, 
+                                        child, 
+                                        associationrow, 
+                                        False, uowcommit)
                     secondary_insert.append(associationrow)
                 for child in history.deleted:
                     if child is None:
                         continue
                     associationrow = {}
-                    self._synchronize(state, child, associationrow, False, uowcommit)
+                    self._synchronize(state, 
+                                        child, 
+                                        associationrow, 
+                                        False, uowcommit)
                     secondary_delete.append(associationrow)
 
-            if not self.passive_updates and self._pks_changed(uowcommit, state):
+            if not self.passive_updates and \
+                    self._pks_changed(uowcommit, state):
                 if not history:
-                    history = uowcommit.get_attribute_history(state, self.key, passive=False)
+                    history = uowcommit.get_attribute_history(
+                                        state, 
+                                        self.key, 
+                                        passive=False)
                 
                 for child in history.unchanged:
                     associationrow = {}
-                    sync.update(
-                                state, 
+                    sync.update(state, 
                                 self.parent, 
                                 associationrow, 
                                 "old_", 
                                 self.prop.synchronize_pairs)
-                    sync.update(
-                                child, 
+                    sync.update(child, 
                                 self.mapper, 
                                 associationrow, 
                                 "old_", 
@@ -784,35 +884,48 @@ class ManyToManyDP(DependencyProcessor):
         self._run_crud(uowcommit, secondary_insert, 
                         secondary_update, secondary_delete)
         
-    def _run_crud(self, uowcommit, secondary_insert, secondary_update, secondary_delete):
+    def _run_crud(self, uowcommit, secondary_insert, 
+                                        secondary_update, secondary_delete):
         connection = uowcommit.transaction.connection(self.mapper)
 
         if secondary_delete:
             associationrow = secondary_delete[0]
             statement = self.secondary.delete(sql.and_(*[
                                 c == sql.bindparam(c.key, type_=c.type) 
-                                for c in self.secondary.c if c.key in associationrow
+                                for c in self.secondary.c 
+                                if c.key in associationrow
                             ]))
             result = connection.execute(statement, secondary_delete)
+            
             if result.supports_sane_multi_rowcount() and \
                         result.rowcount != len(secondary_delete):
                 raise exc.ConcurrentModificationError(
                         "Deleted rowcount %d does not match number of "
                         "secondary table rows deleted from table '%s': %d" % 
-                        (result.rowcount, self.secondary.description, len(secondary_delete)))
+                        (
+                            result.rowcount, 
+                            self.secondary.description, 
+                            len(secondary_delete))
+                        )
 
         if secondary_update:
             associationrow = secondary_update[0]
             statement = self.secondary.update(sql.and_(*[
-                                c == sql.bindparam("old_" + c.key, type_=c.type) 
-                                for c in self.secondary.c if c.key in associationrow
-                            ]))
+                            c == sql.bindparam("old_" + c.key, type_=c.type) 
+                            for c in self.secondary.c 
+                            if c.key in associationrow
+                        ]))
             result = connection.execute(statement, secondary_update)
-            if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update):
+            if result.supports_sane_multi_rowcount() and \
+                        result.rowcount != len(secondary_update):
                 raise exc.ConcurrentModificationError(
                             "Updated rowcount %d does not match number of "
                             "secondary table rows updated from table '%s': %d" % 
-                            (result.rowcount, self.secondary.description, len(secondary_update)))
+                            (
+                                result.rowcount, 
+                                self.secondary.description, 
+                                len(secondary_update))
+                            )
 
         if secondary_insert:
             statement = self.secondary.insert()
index 64b02d8412c0850cb45a6cfb96af27e81fc8405c..c6b867f3580774835ebeb64ef10f567a67e4f036 100644 (file)
@@ -1069,6 +1069,10 @@ class Mapper(object):
         
         return self.class_manager.mapper
 
+    @property
+    def primary_base_mapper(self):
+        return self.class_manager.mapper.base_mapper
+        
     def identity_key_from_row(self, row, adapter=None):
         """Return an identity-map key for use in storing/retrieving an
         item from the identity map.
@@ -1248,7 +1252,7 @@ class Mapper(object):
             ret[t] = table_to_mapper[t]
         return ret
 
-    def per_mapper_flush_actions(self, uow):
+    def _per_mapper_flush_actions(self, uow):
         saves = unitofwork.SaveUpdateAll(uow, self.base_mapper)
         deletes = unitofwork.DeleteAll(uow, self.base_mapper)
         uow.dependencies.add((saves, deletes))
@@ -1277,22 +1281,24 @@ class Mapper(object):
             for prop in mapper._props.values():
                 if prop not in props:
                     props.add(prop)
-                    yield prop, [m for m in mappers if m._props.get(prop.key) is prop]
+                    yield prop, [m for m in mappers 
+                                    if m._props.get(prop.key) is prop]
 
-    def per_state_flush_actions(self, uow, states, isdelete):
+    def _per_state_flush_actions(self, uow, states, isdelete):
         
         mappers_to_states = util.defaultdict(set)
         
-        save_all = unitofwork.SaveUpdateAll(uow, self.base_mapper)
-        delete_all = unitofwork.DeleteAll(uow, self.base_mapper)
+        base_mapper = self.base_mapper
+        save_all = unitofwork.SaveUpdateAll(uow, base_mapper)
+        delete_all = unitofwork.DeleteAll(uow, base_mapper)
         for state in states:
             # keep saves before deletes -
             # this ensures 'row switch' operations work
             if isdelete:
-                action = unitofwork.DeleteState(uow, state)
+                action = unitofwork.DeleteState(uow, state, base_mapper)
                 uow.dependencies.add((save_all, action))
             else:
-                action = unitofwork.SaveUpdateState(uow, state)
+                action = unitofwork.SaveUpdateState(uow, state, base_mapper)
                 uow.dependencies.add((action, delete_all))
             
             mappers_to_states[state.manager.mapper].add(state)
@@ -1362,10 +1368,12 @@ class Mapper(object):
                     if 'before_update' in mapper.extension:
                         mapper.extension.before_update(mapper, conn, state.obj())
 
-                # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
-                # and another instance with the same identity key already exists as persistent. 
+                # detect if we have a "pending" instance (i.e. has 
+                # no instance_key attached to it), and another instance 
+                # with the same identity key already exists as persistent. 
                 # convert to an UPDATE if so.
-                if not has_identity and instance_key in uowtransaction.session.identity_map:
+                if not has_identity and \
+                    instance_key in uowtransaction.session.identity_map:
                     instance = uowtransaction.session.identity_map[instance_key]
                     existing = attributes.instance_state(instance)
                     if not uowtransaction.is_deleted(existing):
index 0028202c57bd3b6409c2fdfe4001e56ba7ad4011..088c71b6e53576b09c54cc18ec6bcc0ec04c0d65 100644 (file)
@@ -60,15 +60,19 @@ class UOWEventHandler(interfaces.AttributeExtension):
                     sess.expunge(item)
 
     def set(self, state, newvalue, oldvalue, initiator):
-        # process "save_update" cascade rules for when an instance is attached to another instance
+        # process "save_update" cascade rules for when an instance 
+        # is attached to another instance
         if oldvalue is newvalue:
             return newvalue
         sess = _state_session(state)
         if sess:
             prop = _state_mapper(state).get_property(self.key)
-            if newvalue is not None and prop.cascade.save_update and newvalue not in sess:
+            if newvalue is not None and \
+                prop.cascade.save_update and \
+                newvalue not in sess:
                 sess.add(newvalue)
-            if prop.cascade.delete_orphan and oldvalue in sess.new and \
+            if prop.cascade.delete_orphan and \
+                oldvalue in sess.new and \
                 prop.mapper._is_orphan(attributes.instance_state(oldvalue)):
                 sess.expunge(oldvalue)
         return newvalue
@@ -99,7 +103,8 @@ class UOWTransaction(object):
         return bool(self.states)
 
     def is_deleted(self, state):
-        """return true if the given state is marked as deleted within this UOWTransaction."""
+        """return true if the given state is marked as deleted 
+        within this UOWTransaction."""
         return state in self.states and self.states[state][0]
     
     def remove_state_actions(self, state):
@@ -130,8 +135,6 @@ class UOWTransaction(object):
             return history.as_state()
 
     def register_object(self, state, isdelete=False, listonly=False):
-        
-        # if object is not in the overall session, do nothing
         if not self.session._contains_state(state):
             return
 
@@ -139,7 +142,7 @@ class UOWTransaction(object):
             mapper = _state_mapper(state)
             
             if mapper not in self.mappers:
-                mapper.per_mapper_flush_actions(self)
+                mapper._per_mapper_flush_actions(self)
             
             self.mappers[mapper].add(state)
             self.states[state] = (isdelete, listonly)
@@ -199,11 +202,9 @@ class UOWTransaction(object):
             # the per-state actions for those per-mapper actions
             # that were broken up.
             for edge in list(self.dependencies):
-                if None in edge:
-                    self.dependencies.remove(edge)
-                elif cycles.issuperset(edge):
-                    self.dependencies.remove(edge)
-                elif edge[0].disabled or edge[1].disabled:
+                if None in edge or\
+                    cycles.issuperset(edge) or \
+                    edge[0].disabled or edge[1].disabled:
                     self.dependencies.remove(edge)
                 elif edge[0] in cycles:
                     self.dependencies.remove(edge)
@@ -220,14 +221,18 @@ class UOWTransaction(object):
                                 ]
                             ).difference(cycles)
         
-        # execute actions
+        # execute
         if cycles:
-            for set_ in topological.sort_as_subsets(self.dependencies, postsort_actions):
+            for set_ in topological.sort_as_subsets(
+                                            self.dependencies, 
+                                            postsort_actions):
                 while set_:
                     n = set_.pop()
                     n.execute_aggregate(self, set_)
         else:
-            for rec in topological.sort(self.dependencies, postsort_actions):
+            for rec in topological.sort(
+                                    self.dependencies, 
+                                    postsort_actions):
                 rec.execute(self)
             
 
@@ -254,7 +265,9 @@ class PreSortRec(object):
         if key in uow.presort_actions:
             return uow.presort_actions[key]
         else:
-            uow.presort_actions[key] = ret = object.__new__(cls)
+            uow.presort_actions[key] = \
+                                    ret = \
+                                    object.__new__(cls)
             return ret
 
 class PostSortRec(object):
@@ -265,7 +278,9 @@ class PostSortRec(object):
         if key in uow.postsort_actions:
             return uow.postsort_actions[key]
         else:
-            uow.postsort_actions[key] = ret = object.__new__(cls)
+            uow.postsort_actions[key] = \
+                                    ret = \
+                                    object.__new__(cls)
             return ret
     
     def execute_aggregate(self, uow, recs):
@@ -351,7 +366,7 @@ class SaveUpdateAll(PostSortRec):
         )
     
     def per_state_flush_actions(self, uow):
-        for rec in self.mapper.per_state_flush_actions(
+        for rec in self.mapper._per_state_flush_actions(
                             uow, 
                             uow.states_for_mapper_hierarchy(self.mapper, False, False), 
                             False):
@@ -369,7 +384,7 @@ class DeleteAll(PostSortRec):
         )
 
     def per_state_flush_actions(self, uow):
-        for rec in self.mapper.per_state_flush_actions(
+        for rec in self.mapper._per_state_flush_actions(
                             uow, 
                             uow.states_for_mapper_hierarchy(self.mapper, True, False), 
                             True):
@@ -396,26 +411,27 @@ class ProcessState(PostSortRec):
         )
         
 class SaveUpdateState(PostSortRec):
-    def __init__(self, uow, state):
+    def __init__(self, uow, state, mapper):
         self.state = state
-
+        self.mapper = mapper
+        
     def execute(self, uow):
-        mapper = self.state.manager.mapper.base_mapper
-        mapper._save_obj(
+        self.mapper._save_obj(
             [self.state],
             uow
         )
 
     def execute_aggregate(self, uow, recs):
         cls_ = self.__class__
-        # TODO: have 'mapper' be present on SaveUpdateState already
-        mapper = self.state.manager.mapper.base_mapper
-        
+        mapper = self.mapper
         our_recs = [r for r in recs 
                         if r.__class__ is cls_ and 
-                        r.state.manager.mapper.base_mapper is mapper]
+                        r.mapper is mapper]
         recs.difference_update(our_recs)
-        mapper._save_obj([self.state] + [r.state for r in our_recs], uow)
+        mapper._save_obj(
+                        [self.state] + 
+                        [r.state for r in our_recs], 
+                        uow)
 
     def __repr__(self):
         return "%s(%s)" % (
@@ -424,17 +440,29 @@ class SaveUpdateState(PostSortRec):
         )
 
 class DeleteState(PostSortRec):
-    def __init__(self, uow, state):
+    def __init__(self, uow, state, mapper):
         self.state = state
-
+        self.mapper = mapper
+        
     def execute(self, uow):
-        mapper = self.state.manager.mapper.base_mapper
         if uow.states[self.state][0]:
-            mapper._delete_obj(
+            self.mapper._delete_obj(
                 [self.state],
                 uow
             )
 
+    def execute_aggregate(self, uow, recs):
+        cls_ = self.__class__
+        mapper = self.mapper
+        our_recs = [r for r in recs 
+                        if r.__class__ is cls_ and 
+                        r.mapper is mapper]
+        recs.difference_update(our_recs)
+        states = [self.state] + [r.state for r in our_recs]
+        mapper._delete_obj(
+                        [s for s in states if uow.states[s][0]], 
+                        uow)
+
     def __repr__(self):
         return "%s(%s)" % (
             self.__class__.__name__,
index 9cc12cffe568ce131001f4ee17b6d8e835b31174..ba62ce07d182c748f42081a2d40f8c820010a0a7 100644 (file)
@@ -2198,7 +2198,12 @@ class RowSwitchTest(_base.MappedTest):
             T7(data='third t7', id=3),
             T7(data='fourth t7', id=4),
             ])
+
         sess.delete(o5)
+        assert o5 in sess.deleted
+        assert o5.t7s[0] in sess.deleted
+        assert o5.t7s[1] in sess.deleted
+        
         sess.add(o6)
         sess.flush()
 
index e532b14fe6bc63c655873960336cd5690830b615..42d9fd90f3a5dbd1bf53fc114bc1ccb89a98fce5 100644 (file)
@@ -244,12 +244,8 @@ class SingleCycleTest(UOWTest):
         self.assert_sql_execution(
                 testing.db,
                 sess.flush,
-                AllOf(
-                    CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", 
-                            lambda ctx:{'id':n3.id}),
-                    CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", 
-                            lambda ctx: {'id':n2.id}),
-                ),
+                CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", 
+                        lambda ctx:[{'id':n2.id}, {'id':n3.id}]),
                 CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", 
                         lambda ctx: {'id':n1.id})
         )
@@ -329,12 +325,8 @@ class SingleCycleTest(UOWTest):
         self.assert_sql_execution(
                 testing.db,
                 sess.flush,
-                AllOf(
-                    CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", 
-                            lambda ctx:{'id':n3.id}),
-                    CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", 
-                            lambda ctx: {'id':n2.id}),
-                ),
+                CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", 
+                        lambda ctx:[{'id':n2.id},{'id':n3.id}]),
                 CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", 
                         lambda ctx: {'id':n1.id})
         )