From f8314ef9ff08af5f104731de402d6e6bd8c043f3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 5 Jun 2006 02:31:53 +0000 Subject: [PATCH] improvements/fixes to session cascade iteration, fixes to entity_name propigation --- examples/polymorph/polymorph.py | 1 + lib/sqlalchemy/orm/mapper.py | 36 ++++++++++++++++------- lib/sqlalchemy/orm/properties.py | 49 +++++++++++++++++++++++--------- lib/sqlalchemy/orm/session.py | 49 ++++++++++++++------------------ lib/sqlalchemy/orm/unitofwork.py | 18 ++++++++++-- test/entity.py | 42 +++++++++++++++++++++++++++ test/inheritance3.py | 2 +- test/objectstore.py | 2 ++ 8 files changed, 144 insertions(+), 55 deletions(-) diff --git a/examples/polymorph/polymorph.py b/examples/polymorph/polymorph.py index 76a03b99d5..f2f568726b 100644 --- a/examples/polymorph/polymorph.py +++ b/examples/polymorph/polymorph.py @@ -80,6 +80,7 @@ c.employees.append(Person(name='joesmith', status='HHH')) c.employees.append(Engineer(name='wally', status='CGG', engineer_name='engineer2', primary_language='python')) c.employees.append(Manager(name='jsmith', status='ABA', manager_name='manager2')) session.save(c) + print session.new session.flush() #sys.exit() diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 4e80aceeb6..eba2203843 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -387,25 +387,31 @@ class Mapper(object): and assocites this Mapper with its class via the mapper_registry.""" oldinit = self.class_.__init__ def init(self, *args, **kwargs): - self._entity_name = kwargs.pop('_sa_entity_name', None) # this gets the AttributeManager to do some pre-initialization, # in order to save on KeyErrors later on sessionlib.global_attributes.init_attr(self) + entity_name = kwargs.pop('_sa_entity_name', None) if kwargs.has_key('_sa_session'): session = kwargs.pop('_sa_session') else: # works for whatever mapper the class is associated with - mapper = mapper_registry.get(ClassKey(self.__class__, self._entity_name)) + mapper = mapper_registry.get(ClassKey(self.__class__, entity_name)) if mapper is not None: session = mapper.extension.get_session() if session is EXT_PASS: session = None else: session = None + # if a session was found, either via _sa_session or via mapper extension, + # save() this instance to the session, and give it an associated entity_name. + # otherwise, this instance will not have a session or mapper association until it is + # save()d to some session. if session is not None: + self._entity_name = entity_name session._register_new(self) + if oldinit is not None: oldinit(self, *args, **kwargs) # override oldinit, insuring that its not already a Mapper-decorated init method @@ -748,16 +754,19 @@ class Mapper(object): for prop in self.props.values(): prop.register_dependencies(uowcommit, *args, **kwargs) - def cascade_iterator(self, type, object, recursive=None): + def cascade_iterator(self, type, object, callable_=None, recursive=None): if recursive is None: recursive=sets.Set() - if object not in recursive: - recursive.add(object) - yield object for prop in self.props.values(): for c in prop.cascade_iterator(type, object, recursive): yield c + def cascade_callable(self, type, object, callable_, recursive=None): + if recursive is None: + recursive=sets.Set() + for prop in self.props.values(): + prop.cascade_callable(type, object, callable_, recursive) + def _row_identity_key(self, row): return sessionlib.get_row_key(row, self.class_, self.pks_by_table[self.mapped_table], self.entity_name) @@ -929,6 +938,8 @@ class MapperProperty(object): raise NotImplementedError() def cascade_iterator(self, type, object, recursive=None): return [] + def cascade_callable(self, type, object, callable_, recursive=None): + return [] def copy(self): raise NotImplementedError() def get_criterion(self, query, key, value): @@ -1157,13 +1168,16 @@ def hash_key(obj): return obj.hash_key() else: return repr(obj) + +def has_mapper(object): + """returns True if the given object has a mapper association""" + return hasattr(object, '_entity_name') -def object_mapper(object, raiseerror=True, entity_name=None): - """given an object, returns the primary Mapper associated with the object - or the object's class.""" +def object_mapper(object, raiseerror=True): + """given an object, returns the primary Mapper associated with the object instance""" try: - return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', entity_name))] - except KeyError: + return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name'))] + except (KeyError, AttributeError): if raiseerror: raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', None))) else: diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 8c38b897fe..8cdaf39409 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -154,14 +154,25 @@ class PropertyLoader(mapper.MapperProperty): if not type in self.cascade: return childlist = sessionlib.global_attributes.get_history(object, self.key, passive=True) - + + mapper = self.mapper.primary_mapper() for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items(): - if c is not None: - if c not in recursive: - recursive.add(c) - yield c - for c2 in self.mapper.primary_mapper().cascade_iterator(type, c, recursive): - yield c2 + if c is not None and c not in recursive: + recursive.add(c) + yield c + for c2 in mapper.cascade_iterator(type, c, recursive): + yield c2 + + def cascade_callable(self, type, object, callable_, recursive): + if not type in self.cascade: + return + childlist = sessionlib.global_attributes.get_history(object, self.key, passive=True) + mapper = self.mapper.primary_mapper() + for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items(): + if c is not None and c not in recursive: + recursive.add(c) + callable_(c, mapper.entity_name) + mapper.cascade_callable(type, c, callable_, recursive) def copy(self): x = self.__class__.__new__(self.__class__) @@ -237,10 +248,16 @@ class PropertyLoader(mapper.MapperProperty): raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (key, parent.class_.__name__, parent.class_.__name__)) self.do_init_subclass(key, parent) + + def _register_attribute(self, class_, callable_=None): + sessionlib.global_attributes.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True, callable_=callable_) + + def _create_history(self, instance, callable_=None): + return sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True, callable_=callable_) def _set_class_attribute(self, class_, key): """sets attribute behavior on our target class.""" - sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True) + self._register_attribute(class_) def _get_direction(self): """determines our 'direction', i.e. do we represent one to many, many to many, etc.""" @@ -295,7 +312,7 @@ class PropertyLoader(mapper.MapperProperty): if self.is_primary(): return #print "PLAIN PROPLOADER EXEC NON-PRIAMRY", repr(id(self)), repr(self.mapper.class_), self.key - sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True) + self._create_history(instance) def register_dependencies(self, uowcommit): self._dependency_processor.register_dependencies(uowcommit) @@ -327,11 +344,17 @@ class LazyLoader(PropertyLoader): def _set_class_attribute(self, class_, key): # establish a class-level lazy loader on our class #print "SETCLASSATTR LAZY", repr(class_), key - sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, callable_=lambda i: self.setup_loader(i), extension=self.attributeext, cascade=self.cascade, trackparent=True) + self._register_attribute(class_, callable_=lambda i: self.setup_loader(i)) def setup_loader(self, instance): + # make sure our parent mapper is the one thats assigned to this instance, else call that one if not self.localparent.is_assigned(instance): - return mapper.object_mapper(instance).props[self.key].setup_loader(instance) + # if no mapper association with this instance (i.e. not in a session, not loaded by a mapper), + # then we cant set up a lazy loader + if not mapper.has_mapper(instance): + return None + else: + return mapper.object_mapper(instance).props[self.key].setup_loader(instance) def lazyload(): params = {} allparams = True @@ -379,7 +402,7 @@ class LazyLoader(PropertyLoader): #print "EXEC NON-PRIAMRY", repr(self.mapper.class_), self.key # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, # which will override the class-level behavior - sessionlib.global_attributes.create_history(instance, self.key, self.uselist, callable_=self.setup_loader(instance), cascade=self.cascade, trackparent=True) + self._create_history(instance, callable_=self.setup_loader(instance)) else: #print "EXEC PRIMARY", repr(self.mapper.class_), self.key # we are the primary manager for this attribute on this class - reset its per-instance attribute state, @@ -548,7 +571,7 @@ class EagerLoader(LazyLoader): if isnew: # new row loaded from the database. initialize a blank container on the instance. # this will override any per-class lazyloading type of stuff. - h = sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True) + h = self._create_history(instance) if not self.uselist: if isnew: diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 56d699cc66..bd17501653 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -275,12 +275,8 @@ class Session(object): The 'entity_name' keyword argument will further qualify the specific Mapper used to handle this instance. """ - for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object): - if c is object: - self._save_impl(c, entity_name=entity_name) - else: - # TODO: this is running the cascade rules twice - self.save_or_update(c, entity_name=entity_name) + self._save_impl(object, entity_name=entity_name) + object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e)) def update(self, object, entity_name=None): """Brings the given detached (saved) instance into this Session. @@ -288,30 +284,31 @@ class Session(object): Session), an exception is thrown. This operation cascades the "save_or_update" method to associated instances if the relation is mapped with cascade="save-update".""" - for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object): - if c is object: - self._update_impl(c, entity_name=entity_name) - else: - self.save_or_update(c, entity_name=entity_name) + self._update_impl(object, entity_name=entity_name) + object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e)) def save_or_update(self, object, entity_name=None): - for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object): - key = getattr(object, '_instance_key', None) - if key is None: - self._save_impl(c, entity_name=entity_name) - else: - self._update_impl(c, entity_name=entity_name) - + self._save_or_update_impl(object, entity_name=entity_name) + object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e)) + + def _save_or_update_impl(self, object, entity_name=None): + key = getattr(object, '_instance_key', None) + if key is None: + self._save_impl(object, entity_name=entity_name) + else: + self._update_impl(object, entity_name=entity_name) + def delete(self, object, entity_name=None): - for c in object_mapper(object, entity_name=entity_name).cascade_iterator('delete', object): + #self.uow.register_deleted(object) + for c in [object] + list(object_mapper(object).cascade_iterator('delete', object)): self.uow.register_deleted(c) def merge(self, object, entity_name=None): instance = None - for obj in object_mapper(object, entity_name=entity_name).cascade_iterator('merge', object): + for obj in [object] + list(object_mapper(object).cascade_iterator('merge', object)): key = getattr(obj, '_instance_key', None) if key is None: - mapper = object_mapper(object, entity_name=entity_name) + mapper = object_mapper(object) ident = mapper.identity(object) for k in ident: if k is None: @@ -333,10 +330,8 @@ class Session(object): if not self.uow.has_key(object._instance_key): raise exceptions.InvalidRequestError("Instance '%s' is already persistent in a different Session" % repr(object)) else: - entity_name = kwargs.get('entity_name', None) - if entity_name is not None: - m = class_mapper(object.__class__, entity_name=entity_name) - m._assign_entity_name(object) + m = class_mapper(object.__class__, entity_name=kwargs.get('entity_name', None)) + m._assign_entity_name(object) self._register_new(object) def _update_impl(self, object, **kwargs): @@ -422,8 +417,8 @@ def get_id_key(ident, class_, entity_name=None): def get_row_key(row, class_, primary_key, entity_name=None): return Session.get_row_key(row, class_, primary_key, entity_name) -def object_mapper(obj, **kwargs): - return sqlalchemy.orm.object_mapper(obj, **kwargs) +def object_mapper(obj): + return sqlalchemy.orm.object_mapper(obj) def class_mapper(class_, **kwargs): return sqlalchemy.orm.class_mapper(class_, **kwargs) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index c33f344fbe..9e9778cad9 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -49,7 +49,13 @@ class UOWListElement(attributes.ListAttribute): if self.cascade is not None: if not isdelete: if self.cascade.save_update: - sess.save_or_update(item) + # cascade the save_update operation onto the child object, + # relative to the mapper handling the parent object + # TODO: easier way to do this ? + mapper = object_mapper(obj) + prop = mapper.props[self.key] + ename = prop.mapper.entity_name + sess.save_or_update(item, entity_name=ename) def append(self, item, _mapper_nohistory = False): if _mapper_nohistory: self.append_nohistory(item) @@ -67,13 +73,19 @@ class UOWScalarElement(attributes.ScalarAttribute): sess._register_changed(obj) if newvalue is not None and self.cascade is not None: if self.cascade.save_update: - sess.save_or_update(newvalue) + # cascade the save_update operation onto the child object, + # relative to the mapper handling the parent object + # TODO: easier way to do this ? + mapper = object_mapper(obj) + prop = mapper.props[self.key] + ename = prop.mapper.entity_name + sess.save_or_update(newvalue, entity_name=ename) class UOWAttributeManager(attributes.AttributeManager): """overrides AttributeManager to provide unit-of-work "dirty" hooks when scalar attribues are modified, plus factory methods for UOWProperrty/UOWListElement.""" def __init__(self): attributes.AttributeManager.__init__(self) - + def create_prop(self, class_, key, uselist, callable_, **kwargs): return UOWProperty(class_, self, key, uselist, callable_, **kwargs) diff --git a/test/entity.py b/test/entity.py index 22e74f1716..f425c5c8ee 100644 --- a/test/entity.py +++ b/test/entity.py @@ -82,6 +82,48 @@ class EntityTest(AssertMixin): assert u1list[0] is not u2list[0] assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 + def testcascade(self): + """same as testbasic but relies on session cascading""" + class User(object):pass + class Address(object):pass + + a1mapper = mapper(Address, address1, entity_name='address1') + a2mapper = mapper(Address, address2, entity_name='address2') + u1mapper = mapper(User, user1, entity_name='user1', properties ={ + 'addresses':relation(a1mapper) + }) + u2mapper =mapper(User, user2, entity_name='user2', properties={ + 'addresses':relation(a2mapper) + }) + + sess = create_session() + u1 = User() + u1.name = 'this is user 1' + sess.save(u1, entity_name='user1') + a1 = Address() + a1.email='a1@foo.com' + u1.addresses.append(a1) + + u2 = User() + u2.name='this is user 2' + a2 = Address() + a2.email='a2@foo.com' + u2.addresses.append(a2) + sess.save(u2, entity_name='user2') + + sess.flush() + assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)] + assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)] + assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')] + assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')] + + sess.clear() + u1list = sess.query(User, entity_name='user1').select() + u2list = sess.query(User, entity_name='user2').select() + assert len(u1list) == len(u2list) == 1 + assert u1list[0] is not u2list[0] + assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1 + def testpolymorphic(self): """tests that entity_name can be used to have two kinds of relations on the same class.""" class User(object):pass diff --git a/test/inheritance3.py b/test/inheritance3.py index 15c93e7b21..2f029e64f4 100644 --- a/test/inheritance3.py +++ b/test/inheritance3.py @@ -14,7 +14,7 @@ class Issue(BaseObject): class Location(BaseObject): def __repr__(self): - return "%s(%s, %s)" % (self.__class__.__name__, repr(self.issue_id), repr(str(self._name.name))) + return "%s(%s, %s)" % (self.__class__.__name__, str(self.issue_id), repr(str(self._name.name))) def _get_name(self): return self._name diff --git a/test/objectstore.py b/test/objectstore.py index ef20bbdbea..442ee9f4da 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -288,7 +288,9 @@ class PrivateAttrTest(SessionTest): ctx.current.flush([a]) ctx.current.delete(a) + print ctx.current.deleted ctx.current.flush([a]) +# ctx.current.flush() assert b_table.count().scalar() == 0 -- 2.47.2