From: Mike Bayer Date: Wed, 18 Apr 2007 02:16:57 +0000 (+0000) Subject: - making progress with session.merge() as well as combining its X-Git-Tag: rel_0_3_7~61 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=518d91134ce27744a8276934c527d899c3080985;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - making progress with session.merge() as well as combining its usage with entity_name [ticket:543] --- diff --git a/CHANGES b/CHANGES index d285509b0b..5b65adfe89 100644 --- a/CHANGES +++ b/CHANGES @@ -71,6 +71,8 @@ methods on them during lazy loads) - fix to many-to-many relationships targeting polymorphic mappers [ticket:533] + - making progress with session.merge() as well as combining its + usage with entity_name [ticket:543] - sqlite: - removed silly behavior where sqlite would reflect UNIQUE indexes as part of the primary key (?!) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 1b6203e063..51fb8fb2eb 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -62,9 +62,10 @@ class PropertyLoader(StrategizedProperty): of items that correspond to a related database table. """ - def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True): + def __init__(self, argument, secondary, primaryjoin, secondaryjoin, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, remote_side=None, enable_typechecks=True): self.uselist = uselist self.argument = argument + self.entity_name = entity_name self.secondary = secondary self.primaryjoin = primaryjoin self.secondaryjoin = secondaryjoin @@ -120,24 +121,24 @@ class PropertyLoader(StrategizedProperty): return str(self.parent.class_.__name__) + "." + self.key + " (" + str(self.mapper.class_.__name__) + ")" def merge(self, session, source, dest, _recursive): - if not "merge" in self.cascade or source in _recursive: + if not "merge" in self.cascade or self.mapper in _recursive: return - _recursive.add(source) - try: - childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True) - if childlist is None: - return - if self.uselist: - # sets a blank list according to the correct list class - dest_list = getattr(self.parent.class_, self.key).initialize(dest) - for current in list(childlist): - dest_list.append(session.merge(current, _recursive=_recursive)) - else: - current = list(childlist)[0] - if current is not None: - setattr(dest, self.key, session.merge(current, _recursive=_recursive)) - finally: - _recursive.remove(source) + childlist = sessionlib.attribute_manager.get_history(source, self.key, passive=True) + if childlist is None: + return + if self.uselist: + # sets a blank list according to the correct list class + dest_list = getattr(self.parent.class_, self.key).initialize(dest) + for current in list(childlist): + obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive) + if obj is not None: + dest_list.append(obj) + else: + current = list(childlist)[0] + if current is not None: + obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive) + if obj is not None: + setattr(dest, self.key, obj) def cascade_iterator(self, type, object, recursive, halt_on=None): if not type in self.cascade: @@ -188,7 +189,7 @@ class PropertyLoader(StrategizedProperty): def _determine_targets(self): if isinstance(self.argument, type): - self.mapper = mapper.class_mapper(self.argument, compile=False)._check_compile() + self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False)._check_compile() elif isinstance(self.argument, mapper.Mapper): self.mapper = self.argument._check_compile() else: @@ -199,7 +200,7 @@ class PropertyLoader(StrategizedProperty): if self.association is not None: if isinstance(self.association, type): - self.association = mapper.class_mapper(self.association, compile=False)._check_compile() + self.association = mapper.class_mapper(self.association, entity_name=self.entity_name, compile=False)._check_compile() self.target = self.mapper.mapped_table self.select_mapper = self.mapper.get_select_mapper() diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 35b35201d8..1880e6062c 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -446,7 +446,7 @@ class Session(object): for c in [object] + list(_object_mapper(object).cascade_iterator('delete', object)): self.uow.register_deleted(c) - def merge(self, object,entity_name=None, _recursive=None): + def merge(self, object, entity_name=None, _recursive=None): """Copy the state of the given `object` onto the persistent object with the same identifier. @@ -462,21 +462,31 @@ class Session(object): if _recursive is None: _recursive = util.Set() - mapper = _object_mapper(object, entity_name=entity_name) - key = getattr(object, '_instance_key', None) - if key is None: - merged = mapper._create_instance(self) + if entity_name is not None: + mapper = _class_mapper(object.__class__, entity_name=entity_name) else: - if key in self.identity_map: - merged = self.identity_map[key] + mapper = _object_mapper(object) + if mapper in _recursive or object in _recursive: + return None + _recursive.add(mapper) + _recursive.add(object) + try: + key = getattr(object, '_instance_key', None) + if key is None: + merged = mapper._create_instance(self) else: - merged = self.get(mapper.class_, key[1]) - for prop in mapper.props.values(): - prop.merge(self, object, merged, _recursive) - if key is None: - self.save(merged) - return merged - + if key in self.identity_map: + merged = self.identity_map[key] + else: + merged = self.get(mapper.class_, key[1]) + for prop in mapper.props.values(): + prop.merge(self, object, merged, _recursive) + if key is None: + self.save(merged, entity_name=mapper.entity_name) + return merged + finally: + _recursive.remove(mapper) + def identity_key(self, *args, **kwargs): """Get an identity key.