From 9df8afc600cd69a87ece009beefa0108bb49b256 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 6 Apr 2010 01:23:54 -0400 Subject: [PATCH] - cleanup, factoring, had some heisenbugs. more test coverage will be needed overall as missing dependency rules lead to subtle bugs pretty easily --- lib/sqlalchemy/orm/dependency.py | 245 ++++++++++++++++++++++--------- lib/sqlalchemy/orm/mapper.py | 28 ++-- lib/sqlalchemy/orm/unitofwork.py | 92 ++++++++---- test/orm/test_unitofwork.py | 5 + test/orm/test_unitofworkv2.py | 16 +- 5 files changed, 266 insertions(+), 120 deletions(-) diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index ecea094fd1..aef297ee6c 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -74,11 +74,23 @@ class DependencyProcessor(object): after_save = unitofwork.ProcessAll(uow, self, False, True) before_delete = unitofwork.ProcessAll(uow, self, True, True) - parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.primary_mapper().base_mapper) - child_saves = unitofwork.SaveUpdateAll(uow, self.mapper.primary_mapper().base_mapper) - - parent_deletes = unitofwork.DeleteAll(uow, self.parent.primary_mapper().base_mapper) - child_deletes = unitofwork.DeleteAll(uow, self.mapper.primary_mapper().base_mapper) + parent_saves = unitofwork.SaveUpdateAll( + uow, + self.parent.primary_base_mapper + ) + child_saves = unitofwork.SaveUpdateAll( + uow, + self.mapper.primary_base_mapper + ) + + parent_deletes = unitofwork.DeleteAll( + uow, + self.parent.primary_base_mapper + ) + child_deletes = unitofwork.DeleteAll( + uow, + self.mapper.primary_base_mapper + ) self.per_property_dependencies(uow, parent_saves, @@ -109,8 +121,11 @@ class DependencyProcessor(object): after_save.disabled = True # check if the "child" side is part of the cycle - child_saves = unitofwork.SaveUpdateAll(uow, self.mapper.base_mapper) - child_deletes = unitofwork.DeleteAll(uow, self.mapper.base_mapper) + + parent_base_mapper = self.parent.primary_base_mapper + child_base_mapper = self.mapper.primary_base_mapper + child_saves = unitofwork.SaveUpdateAll(uow, child_base_mapper) + child_deletes = unitofwork.DeleteAll(uow, child_base_mapper) if child_saves not in uow.cycles: # based on the current dependencies we use, the saves/ @@ -130,12 +145,16 @@ class DependencyProcessor(object): # check if the "parent" side is part of the cycle if not isdelete: - parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper) + parent_saves = unitofwork.SaveUpdateAll( + uow, + self.parent.base_mapper) parent_deletes = before_delete = None if parent_saves in uow.cycles: parent_in_cycles = True else: - parent_deletes = unitofwork.DeleteAll(uow, self.parent.base_mapper) + parent_deletes = unitofwork.DeleteAll( + uow, + self.parent.base_mapper) parent_saves = after_save = None if parent_deletes in uow.cycles: parent_in_cycles = True @@ -148,17 +167,26 @@ class DependencyProcessor(object): if isdelete: before_delete = unitofwork.ProcessState(uow, self, True, state) if parent_in_cycles: - parent_deletes = unitofwork.DeleteState(uow, state) + parent_deletes = unitofwork.DeleteState( + uow, + state, + parent_base_mapper) else: after_save = unitofwork.ProcessState(uow, self, False, state) if parent_in_cycles: - parent_saves = unitofwork.SaveUpdateState(uow, state) + parent_saves = unitofwork.SaveUpdateState( + uow, + state, + parent_base_mapper) if child_in_cycles: # locate each child state associated with the parent action, # create dependencies for each. child_actions = [] - sum_ = uow.get_attribute_history(state, self.key, passive=True).sum() + sum_ = uow.get_attribute_history( + state, + self.key, + passive=True).sum() if not sum_: continue for child_state in sum_: @@ -169,9 +197,17 @@ class DependencyProcessor(object): else: (deleted, listonly) = uow.states[child_state] if deleted: - child_action = (unitofwork.DeleteState(uow, child_state), True) + child_action = ( + unitofwork.DeleteState( + uow, child_state, + child_base_mapper), + True) else: - child_action = (unitofwork.SaveUpdateState(uow, child_state), False) + child_action = ( + unitofwork.SaveUpdateState( + uow, child_state, + child_base_mapper), + False) child_actions.append(child_action) # establish dependencies between our possibly per-state @@ -204,25 +240,28 @@ class DependencyProcessor(object): not self.mapper._canload(state, allow_subtypes=not self.enable_typechecks): if self.mapper._canload(state, allow_subtypes=True): raise exc.FlushError( - "Attempting to flush an item of type %s on collection '%s', " - "which is not the expected type %s. Configure mapper '%s' to " - "load this subtype polymorphically, or set " - "enable_typechecks=False to allow subtypes. " - "Mismatched typeloading may cause bi-directional relationships " - "(backrefs) to not function properly." % - (state.class_, self.prop, self.mapper.class_, self.mapper)) + "Attempting to flush an item of type %s on collection '%s', " + "which is not the expected type %s. Configure mapper '%s' to " + "load this subtype polymorphically, or set " + "enable_typechecks=False to allow subtypes. " + "Mismatched typeloading may cause bi-directional relationships " + "(backrefs) to not function properly." % + (state.class_, self.prop, self.mapper.class_, self.mapper)) else: raise exc.FlushError( - "Attempting to flush an item of type %s on collection '%s', " - "whose mapper does not inherit from that of %s." % - (state.class_, self.prop, self.mapper.class_)) + "Attempting to flush an item of type %s on collection '%s', " + "whose mapper does not inherit from that of %s." % + (state.class_, self.prop, self.mapper.class_)) - def _synchronize(self, state, - child, associationrow, - clearkeys, uowcommit): + def _synchronize(self, state, child, associationrow, + clearkeys, uowcommit): raise NotImplementedError() def _check_reverse(self, uow): + """return True if a comparable dependency processor has + already set up on the "reverse" side of a relationship. + + """ for p in self.prop._reverse_property: if not p.viewonly and p._dependency_processor and \ (unitofwork.ProcessAll, @@ -365,11 +404,15 @@ class OneToManyDP(DependencyProcessor): if self._pks_changed(uowcommit, state): if not history: history = uowcommit.get_attribute_history( - state, self.key, passive=self.passive_updates) + state, self.key, + passive=self.passive_updates) if history: for child in history.unchanged: if child is not None: - uowcommit.register_object(child, False, self.passive_updates) + uowcommit.register_object( + child, + False, + self.passive_updates) def process_deletes(self, uowcommit, states): # head object is being deleted, and we manage its list of @@ -385,14 +428,27 @@ class OneToManyDP(DependencyProcessor): passive=self.passive_deletes) if history: for child in history.deleted: - if child is not None and self.hasparent(child) is False: - self._synchronize(state, child, None, True, uowcommit) - self._conditional_post_update(child, uowcommit, [state]) + if child is not None and \ + self.hasparent(child) is False: + self._synchronize( + state, + child, + None, True, uowcommit) + self._conditional_post_update( + child, + uowcommit, + [state]) if self.post_update or not self.cascade.delete: for child in history.unchanged: if child is not None: - self._synchronize(state, child, None, True, uowcommit) - self._conditional_post_update(child, uowcommit, [state]) + self._synchronize( + state, + child, + None, True, uowcommit) + self._conditional_post_update( + child, + uowcommit, + [state]) def process_saves(self, uowcommit, states): for state in states: @@ -401,20 +457,26 @@ class OneToManyDP(DependencyProcessor): for child in history.added: self._synchronize(state, child, None, False, uowcommit) if child is not None: - self._conditional_post_update(child, uowcommit, [state]) + self._conditional_post_update( + child, + uowcommit, + [state]) for child in history.deleted: - if not self.cascade.delete_orphan and not self.hasparent(child): + if not self.cascade.delete_orphan and \ + not self.hasparent(child): self._synchronize(state, child, None, True, uowcommit) if self._pks_changed(uowcommit, state): for child in history.unchanged: self._synchronize(state, child, None, False, uowcommit) - def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + def _synchronize(self, state, child, associationrow, + clearkeys, uowcommit): source = state dest = child - if dest is None or (not self.post_update and uowcommit.is_deleted(dest)): + if dest is None or \ + (not self.post_update and uowcommit.is_deleted(dest)): return self._verify_canload(child) if clearkeys: @@ -517,7 +579,8 @@ class ManyToOneDP(DependencyProcessor): if child is None: continue uowcommit.register_object(child, isdelete=True) - for c, m in self.mapper.cascade_iterator('delete', child): + for c, m in self.mapper.cascade_iterator( + 'delete', child): uowcommit.register_object( attributes.instance_state(c), isdelete=True) @@ -533,7 +596,8 @@ class ManyToOneDP(DependencyProcessor): for child in history.deleted: if self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) - for c, m in self.mapper.cascade_iterator('delete', child): + for c, m in self.mapper.cascade_iterator( + 'delete', child): uowcommit.register_object( attributes.instance_state(c), isdelete=True) @@ -552,7 +616,10 @@ class ManyToOneDP(DependencyProcessor): self.key, passive=self.passive_deletes) if history: - self._conditional_post_update(state, uowcommit, history.sum()) + self._conditional_post_update( + state, + uowcommit, + history.sum()) def process_saves(self, uowcommit, states): for state in states: @@ -561,7 +628,9 @@ class ManyToOneDP(DependencyProcessor): for child in history.added: self._synchronize(state, child, None, False, uowcommit) - self._conditional_post_update(state, uowcommit, history.sum()) + self._conditional_post_update( + state, + uowcommit, history.sum()) def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): if state is None or (not self.post_update and uowcommit.is_deleted(state)): @@ -603,7 +672,9 @@ class DetectKeySwitch(DependencyProcessor): # so that we avoid ManyToOneDP's registering the object without # the listonly flag in its own preprocess stage (results in UPDATE) # statements being emitted - parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper) + parent_saves = unitofwork.SaveUpdateAll( + uow, + self.parent.base_mapper) after_save = unitofwork.ProcessAll(uow, self, False, False) uow.dependencies.update([ (parent_saves, after_save) @@ -631,8 +702,8 @@ class DetectKeySwitch(DependencyProcessor): if switchers: # if primary key values have actually changed somewhere, perform # a linear search through the UOW in search of a parent. - # note that this handler isn't used if the many-to-one relationship - # has a backref. + # note that this handler isn't used if the many-to-one + # relationship has a backref. for state in uowcommit.session.identity_map.all_states(): if not issubclass(state.class_, self.parent.class_): continue @@ -661,7 +732,7 @@ class ManyToManyDP(DependencyProcessor): def per_property_flush_actions(self, uow): if self._check_reverse(uow): - return + unitofwork.GetDependentObjects(uow, self, False, True) else: DependencyProcessor.per_property_flush_actions(self, uow) @@ -682,9 +753,18 @@ class ManyToManyDP(DependencyProcessor): uow.dependencies.update([ (parent_saves, after_save), (child_saves, after_save), + (after_save, child_deletes), + + # a rowswitch on the parent from deleted to saved + # can make this one occur, as the "save" may remove + # an element from the + # "deleted" list before we have a chance to + # process its child rows + (before_delete, parent_saves), (before_delete, parent_deletes), (before_delete, child_deletes), + (before_delete, child_saves), ]) def per_state_dependencies(self, uow, @@ -709,13 +789,21 @@ class ManyToManyDP(DependencyProcessor): pass def presort_saves(self, uowcommit, states): + if not self.cascade.delete_orphan: + return + for state in states: - history = uowcommit.get_attribute_history(state, self.key, passive=True) + history = uowcommit.get_attribute_history( + state, + self.key, + passive=True) if history: for child in history.deleted: - if self.cascade.delete_orphan and self.hasparent(child) is False: + if self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) - for c, m in self.mapper.cascade_iterator('delete', child): + for c, m in self.mapper.cascade_iterator( + 'delete', + child): uowcommit.register_object( attributes.instance_state(c), isdelete=True) @@ -733,7 +821,11 @@ class ManyToManyDP(DependencyProcessor): if child is None: continue associationrow = {} - self._synchronize(state, child, associationrow, False, uowcommit) + self._synchronize( + state, + child, + associationrow, + False, uowcommit) secondary_delete.append(associationrow) self._run_crud(uowcommit, secondary_insert, @@ -751,29 +843,37 @@ class ManyToManyDP(DependencyProcessor): if child is None: continue associationrow = {} - self._synchronize(state, child, associationrow, False, uowcommit) + self._synchronize(state, + child, + associationrow, + False, uowcommit) secondary_insert.append(associationrow) for child in history.deleted: if child is None: continue associationrow = {} - self._synchronize(state, child, associationrow, False, uowcommit) + self._synchronize(state, + child, + associationrow, + False, uowcommit) secondary_delete.append(associationrow) - if not self.passive_updates and self._pks_changed(uowcommit, state): + if not self.passive_updates and \ + self._pks_changed(uowcommit, state): if not history: - history = uowcommit.get_attribute_history(state, self.key, passive=False) + history = uowcommit.get_attribute_history( + state, + self.key, + passive=False) for child in history.unchanged: associationrow = {} - sync.update( - state, + sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs) - sync.update( - child, + sync.update(child, self.mapper, associationrow, "old_", @@ -784,35 +884,48 @@ class ManyToManyDP(DependencyProcessor): self._run_crud(uowcommit, secondary_insert, secondary_update, secondary_delete) - def _run_crud(self, uowcommit, secondary_insert, secondary_update, secondary_delete): + def _run_crud(self, uowcommit, secondary_insert, + secondary_update, secondary_delete): connection = uowcommit.transaction.connection(self.mapper) if secondary_delete: associationrow = secondary_delete[0] statement = self.secondary.delete(sql.and_(*[ c == sql.bindparam(c.key, type_=c.type) - for c in self.secondary.c if c.key in associationrow + for c in self.secondary.c + if c.key in associationrow ])) result = connection.execute(statement, secondary_delete) + if result.supports_sane_multi_rowcount() and \ result.rowcount != len(secondary_delete): raise exc.ConcurrentModificationError( "Deleted rowcount %d does not match number of " "secondary table rows deleted from table '%s': %d" % - (result.rowcount, self.secondary.description, len(secondary_delete))) + ( + result.rowcount, + self.secondary.description, + len(secondary_delete)) + ) if secondary_update: associationrow = secondary_update[0] statement = self.secondary.update(sql.and_(*[ - c == sql.bindparam("old_" + c.key, type_=c.type) - for c in self.secondary.c if c.key in associationrow - ])) + c == sql.bindparam("old_" + c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ])) result = connection.execute(statement, secondary_update) - if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update): + if result.supports_sane_multi_rowcount() and \ + result.rowcount != len(secondary_update): raise exc.ConcurrentModificationError( "Updated rowcount %d does not match number of " "secondary table rows updated from table '%s': %d" % - (result.rowcount, self.secondary.description, len(secondary_update))) + ( + result.rowcount, + self.secondary.description, + len(secondary_update)) + ) if secondary_insert: statement = self.secondary.insert() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 64b02d8412..c6b867f358 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1069,6 +1069,10 @@ class Mapper(object): return self.class_manager.mapper + @property + def primary_base_mapper(self): + return self.class_manager.mapper.base_mapper + def identity_key_from_row(self, row, adapter=None): """Return an identity-map key for use in storing/retrieving an item from the identity map. @@ -1248,7 +1252,7 @@ class Mapper(object): ret[t] = table_to_mapper[t] return ret - def per_mapper_flush_actions(self, uow): + def _per_mapper_flush_actions(self, uow): saves = unitofwork.SaveUpdateAll(uow, self.base_mapper) deletes = unitofwork.DeleteAll(uow, self.base_mapper) uow.dependencies.add((saves, deletes)) @@ -1277,22 +1281,24 @@ class Mapper(object): for prop in mapper._props.values(): if prop not in props: props.add(prop) - yield prop, [m for m in mappers if m._props.get(prop.key) is prop] + yield prop, [m for m in mappers + if m._props.get(prop.key) is prop] - def per_state_flush_actions(self, uow, states, isdelete): + def _per_state_flush_actions(self, uow, states, isdelete): mappers_to_states = util.defaultdict(set) - save_all = unitofwork.SaveUpdateAll(uow, self.base_mapper) - delete_all = unitofwork.DeleteAll(uow, self.base_mapper) + base_mapper = self.base_mapper + save_all = unitofwork.SaveUpdateAll(uow, base_mapper) + delete_all = unitofwork.DeleteAll(uow, base_mapper) for state in states: # keep saves before deletes - # this ensures 'row switch' operations work if isdelete: - action = unitofwork.DeleteState(uow, state) + action = unitofwork.DeleteState(uow, state, base_mapper) uow.dependencies.add((save_all, action)) else: - action = unitofwork.SaveUpdateState(uow, state) + action = unitofwork.SaveUpdateState(uow, state, base_mapper) uow.dependencies.add((action, delete_all)) mappers_to_states[state.manager.mapper].add(state) @@ -1362,10 +1368,12 @@ class Mapper(object): if 'before_update' in mapper.extension: mapper.extension.before_update(mapper, conn, state.obj()) - # detect if we have a "pending" instance (i.e. has no instance_key attached to it), - # and another instance with the same identity key already exists as persistent. + # detect if we have a "pending" instance (i.e. has + # no instance_key attached to it), and another instance + # with the same identity key already exists as persistent. # convert to an UPDATE if so. - if not has_identity and instance_key in uowtransaction.session.identity_map: + if not has_identity and \ + instance_key in uowtransaction.session.identity_map: instance = uowtransaction.session.identity_map[instance_key] existing = attributes.instance_state(instance) if not uowtransaction.is_deleted(existing): diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 0028202c57..088c71b6e5 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -60,15 +60,19 @@ class UOWEventHandler(interfaces.AttributeExtension): sess.expunge(item) def set(self, state, newvalue, oldvalue, initiator): - # process "save_update" cascade rules for when an instance is attached to another instance + # process "save_update" cascade rules for when an instance + # is attached to another instance if oldvalue is newvalue: return newvalue sess = _state_session(state) if sess: prop = _state_mapper(state).get_property(self.key) - if newvalue is not None and prop.cascade.save_update and newvalue not in sess: + if newvalue is not None and \ + prop.cascade.save_update and \ + newvalue not in sess: sess.add(newvalue) - if prop.cascade.delete_orphan and oldvalue in sess.new and \ + if prop.cascade.delete_orphan and \ + oldvalue in sess.new and \ prop.mapper._is_orphan(attributes.instance_state(oldvalue)): sess.expunge(oldvalue) return newvalue @@ -99,7 +103,8 @@ class UOWTransaction(object): return bool(self.states) def is_deleted(self, state): - """return true if the given state is marked as deleted within this UOWTransaction.""" + """return true if the given state is marked as deleted + within this UOWTransaction.""" return state in self.states and self.states[state][0] def remove_state_actions(self, state): @@ -130,8 +135,6 @@ class UOWTransaction(object): return history.as_state() def register_object(self, state, isdelete=False, listonly=False): - - # if object is not in the overall session, do nothing if not self.session._contains_state(state): return @@ -139,7 +142,7 @@ class UOWTransaction(object): mapper = _state_mapper(state) if mapper not in self.mappers: - mapper.per_mapper_flush_actions(self) + mapper._per_mapper_flush_actions(self) self.mappers[mapper].add(state) self.states[state] = (isdelete, listonly) @@ -199,11 +202,9 @@ class UOWTransaction(object): # the per-state actions for those per-mapper actions # that were broken up. for edge in list(self.dependencies): - if None in edge: - self.dependencies.remove(edge) - elif cycles.issuperset(edge): - self.dependencies.remove(edge) - elif edge[0].disabled or edge[1].disabled: + if None in edge or\ + cycles.issuperset(edge) or \ + edge[0].disabled or edge[1].disabled: self.dependencies.remove(edge) elif edge[0] in cycles: self.dependencies.remove(edge) @@ -220,14 +221,24 @@ class UOWTransaction(object): ] ).difference(cycles) - # execute actions + # execute if cycles: - for set_ in topological.sort_as_subsets(self.dependencies, postsort_actions): + for set_ in topological.sort_as_subsets( + self.dependencies, + postsort_actions): while set_: n = set_.pop() n.execute_aggregate(self, set_) else: - for rec in topological.sort(self.dependencies, postsort_actions): + r = list(topological.sort( + self.dependencies, + postsort_actions)) + print "-----------" + print self.dependencies + print r + for rec in topological.sort( + self.dependencies, + postsort_actions): rec.execute(self) @@ -254,7 +265,9 @@ class PreSortRec(object): if key in uow.presort_actions: return uow.presort_actions[key] else: - uow.presort_actions[key] = ret = object.__new__(cls) + uow.presort_actions[key] = \ + ret = \ + object.__new__(cls) return ret class PostSortRec(object): @@ -265,7 +278,9 @@ class PostSortRec(object): if key in uow.postsort_actions: return uow.postsort_actions[key] else: - uow.postsort_actions[key] = ret = object.__new__(cls) + uow.postsort_actions[key] = \ + ret = \ + object.__new__(cls) return ret def execute_aggregate(self, uow, recs): @@ -351,7 +366,7 @@ class SaveUpdateAll(PostSortRec): ) def per_state_flush_actions(self, uow): - for rec in self.mapper.per_state_flush_actions( + for rec in self.mapper._per_state_flush_actions( uow, uow.states_for_mapper_hierarchy(self.mapper, False, False), False): @@ -369,7 +384,7 @@ class DeleteAll(PostSortRec): ) def per_state_flush_actions(self, uow): - for rec in self.mapper.per_state_flush_actions( + for rec in self.mapper._per_state_flush_actions( uow, uow.states_for_mapper_hierarchy(self.mapper, True, False), True): @@ -396,26 +411,27 @@ class ProcessState(PostSortRec): ) class SaveUpdateState(PostSortRec): - def __init__(self, uow, state): + def __init__(self, uow, state, mapper): self.state = state - + self.mapper = mapper + def execute(self, uow): - mapper = self.state.manager.mapper.base_mapper - mapper._save_obj( + self.mapper._save_obj( [self.state], uow ) def execute_aggregate(self, uow, recs): cls_ = self.__class__ - # TODO: have 'mapper' be present on SaveUpdateState already - mapper = self.state.manager.mapper.base_mapper - + mapper = self.mapper our_recs = [r for r in recs if r.__class__ is cls_ and - r.state.manager.mapper.base_mapper is mapper] + r.mapper is mapper] recs.difference_update(our_recs) - mapper._save_obj([self.state] + [r.state for r in our_recs], uow) + mapper._save_obj( + [self.state] + + [r.state for r in our_recs], + uow) def __repr__(self): return "%s(%s)" % ( @@ -424,17 +440,29 @@ class SaveUpdateState(PostSortRec): ) class DeleteState(PostSortRec): - def __init__(self, uow, state): + def __init__(self, uow, state, mapper): self.state = state - + self.mapper = mapper + def execute(self, uow): - mapper = self.state.manager.mapper.base_mapper if uow.states[self.state][0]: - mapper._delete_obj( + self.mapper._delete_obj( [self.state], uow ) + def execute_aggregate(self, uow, recs): + cls_ = self.__class__ + mapper = self.mapper + our_recs = [r for r in recs + if r.__class__ is cls_ and + r.mapper is mapper] + recs.difference_update(our_recs) + states = [self.state] + [r.state for r in our_recs] + mapper._delete_obj( + [s for s in states if uow.states[s][0]], + uow) + def __repr__(self): return "%s(%s)" % ( self.__class__.__name__, diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 9cc12cffe5..ba62ce07d1 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -2198,7 +2198,12 @@ class RowSwitchTest(_base.MappedTest): T7(data='third t7', id=3), T7(data='fourth t7', id=4), ]) + sess.delete(o5) + assert o5 in sess.deleted + assert o5.t7s[0] in sess.deleted + assert o5.t7s[1] in sess.deleted + sess.add(o6) sess.flush() diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index e532b14fe6..42d9fd90f3 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -244,12 +244,8 @@ class SingleCycleTest(UOWTest): self.assert_sql_execution( testing.db, sess.flush, - AllOf( - CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx:{'id':n3.id}), - CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: {'id':n2.id}), - ), + CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", + lambda ctx:[{'id':n2.id}, {'id':n3.id}]), CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", lambda ctx: {'id':n1.id}) ) @@ -329,12 +325,8 @@ class SingleCycleTest(UOWTest): self.assert_sql_execution( testing.db, sess.flush, - AllOf( - CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx:{'id':n3.id}), - CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", - lambda ctx: {'id':n2.id}), - ), + CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", + lambda ctx:[{'id':n2.id},{'id':n3.id}]), CompiledSQL("DELETE FROM nodes WHERE nodes.id = :id", lambda ctx: {'id':n1.id}) ) -- 2.47.3