From: Mike Bayer Date: Sun, 21 Jan 2007 19:47:25 +0000 (+0000) Subject: added recursion check to merge X-Git-Tag: rel_0_3_4~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5f33323bc639c612e79bf2f91d4e2e7c28cfbaa8;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git added recursion check to merge --- diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index db43a8e27a..4fd9a3e9b5 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -38,7 +38,7 @@ class SynonymProperty(MapperProperty): return s return getattr(obj, self.name) setattr(self.parent.class_, self.key, SynonymProp()) - def merge(self, session, source, dest): + def merge(self, session, source, dest, _recursive): pass class ColumnProperty(StrategizedProperty): @@ -61,7 +61,7 @@ class ColumnProperty(StrategizedProperty): setattr(object, self.key, value) def get_history(self, obj, passive=False): return sessionlib.attribute_manager.get_history(obj, self.key, passive=passive) - def merge(self, session, source, dest): + def merge(self, session, source, dest, _recursive): setattr(dest, self.key, getattr(source, self.key, None)) def compare(self, value): return self.columns[0] == value @@ -127,20 +127,26 @@ class PropertyLoader(StrategizedProperty): def __str__(self): return self.__class__.__name__ + " " + str(self.parent) + "->" + self.key + "->" + str(self.mapper) - def merge(self, session, source, dest): - if not "merge" in self.cascade: + def merge(self, session, source, dest, _recursive): + if not "merge" in self.cascade or source in _recursive: return - childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True) - if childlist is None: - return - if self.uselist: - # sets a blank list according to the correct list class - dest_list = getattr(self.parent.class_, self.key).initialize(dest) - for current in list(childlist): - dest_list.append(session.merge(current)) - else: - setattr(dest, self.key, session.merge(current)) - + _recursive.add(source) + try: + childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True) + if childlist is None: + return + if self.uselist: + # sets a blank list according to the correct list class + dest_list = getattr(self.parent.class_, self.key).initialize(dest) + for current in list(childlist): + dest_list.append(session.merge(current, _recursive=_recursive)) + else: + current = list(childlist)[0] + if current is not None: + setattr(dest, self.key, session.merge(current, _recursive=_recursive)) + finally: + _recursive.remove(source) + def cascade_iterator(self, type, object, recursive, halt_on=None): if not type in self.cascade: return diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 8292206885..f2a7181774 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -323,7 +323,7 @@ class Session(object): for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)): self.uow.register_deleted(c) - def merge(self, object, entity_name=None): + def merge(self, object, entity_name=None, _recursive=None): """copy the state of the given object onto the persistent object with the same identifier. If there is no persistent instance currently associated with the session, it will be loaded. @@ -331,6 +331,8 @@ class Session(object): a newly persistent instance. The given instance does not become associated with the session. This operation cascades to associated instances if the association is mapped with cascade="merge". """ + if _recursive is None: + _recursive = util.Set() mapper = _object_mapper(object) key = getattr(object, '_instance_key', None) if key is None: @@ -341,7 +343,7 @@ class Session(object): else: merged = self.get(mapper.class_, key[1]) for prop in mapper.props.values(): - prop.merge(self, object, merged) + prop.merge(self, object, merged, _recursive) if key is None: self.save(merged) return merged diff --git a/test/orm/merge.py b/test/orm/merge.py index 7a62b147c8..cb36cc3b57 100644 --- a/test/orm/merge.py +++ b/test/orm/merge.py @@ -60,7 +60,7 @@ class MergeTest(AssertMixin): def test_saved_cascade(self): """test merge of a persistent entity with two child persistent entities.""" mapper(User, users, properties={ - 'addresses':relation(mapper(Address, addresses)) + 'addresses':relation(mapper(Address, addresses), backref='user') }) sess = create_session() @@ -108,7 +108,7 @@ class MergeTest(AssertMixin): mapper(User, users, properties={ 'addresses':relation(mapper(Address, addresses)), - 'orders':relation(Order) + 'orders':relation(Order, backref='customer') }) sess = create_session() @@ -132,6 +132,12 @@ class MergeTest(AssertMixin): u.orders[0].items[1].item_name = 'item 2 modified' sess2.merge(u) assert u2.orders[0].items[1].item_name == 'item 2 modified' + + sess2 = create_session() + o2 = sess2.query(Order).get(o.order_id) + o.customer.user_name = 'also fred' + sess2.merge(o) + assert o2.customer.user_name == 'also fred' if __name__ == "__main__":