]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improvements/fixes to session cascade iteration,
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Jun 2006 02:31:53 +0000 (02:31 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Jun 2006 02:31:53 +0000 (02:31 +0000)
fixes to entity_name propigation

examples/polymorph/polymorph.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/unitofwork.py
test/entity.py
test/inheritance3.py
test/objectstore.py

index 76a03b99d5aabdfa8c651ff16681087570e9631d..f2f568726b23050de71a93942d8a1b926fa70a45 100644 (file)
@@ -80,6 +80,7 @@ c.employees.append(Person(name='joesmith', status='HHH'))
 c.employees.append(Engineer(name='wally', status='CGG', engineer_name='engineer2', primary_language='python'))
 c.employees.append(Manager(name='jsmith', status='ABA', manager_name='manager2'))
 session.save(c)
+
 print session.new
 session.flush()
 #sys.exit()
index 4e80aceeb68e104e87f61477b8e5fe3d711930cc..eba2203843d74ef052e7d96fd690500274ab2be0 100644 (file)
@@ -387,25 +387,31 @@ class Mapper(object):
         and assocites this Mapper with its class via the mapper_registry."""
         oldinit = self.class_.__init__
         def init(self, *args, **kwargs):
-            self._entity_name = kwargs.pop('_sa_entity_name', None)
 
             # this gets the AttributeManager to do some pre-initialization,
             # in order to save on KeyErrors later on
             sessionlib.global_attributes.init_attr(self)
             
+            entity_name = kwargs.pop('_sa_entity_name', None)
             if kwargs.has_key('_sa_session'):
                 session = kwargs.pop('_sa_session')
             else:
                 # works for whatever mapper the class is associated with
-                mapper = mapper_registry.get(ClassKey(self.__class__, self._entity_name))
+                mapper = mapper_registry.get(ClassKey(self.__class__, entity_name))
                 if mapper is not None:
                     session = mapper.extension.get_session()
                     if session is EXT_PASS:
                         session = None
                 else:
                     session = None
+            # if a session was found, either via _sa_session or via mapper extension,
+            # save() this instance to the session, and give it an associated entity_name.
+            # otherwise, this instance will not have a session or mapper association until it is
+            # save()d to some session.
             if session is not None:
+                self._entity_name = entity_name
                 session._register_new(self)
+
             if oldinit is not None:
                 oldinit(self, *args, **kwargs)
         # override oldinit, insuring that its not already a Mapper-decorated init method
@@ -748,16 +754,19 @@ class Mapper(object):
         for prop in self.props.values():
             prop.register_dependencies(uowcommit, *args, **kwargs)
     
-    def cascade_iterator(self, type, object, recursive=None):
+    def cascade_iterator(self, type, object, callable_=None, recursive=None):
         if recursive is None:
             recursive=sets.Set()
-        if object not in recursive:
-            recursive.add(object)
-            yield object
         for prop in self.props.values():
             for c in prop.cascade_iterator(type, object, recursive):
                 yield c
 
+    def cascade_callable(self, type, object, callable_, recursive=None):
+        if recursive is None:
+            recursive=sets.Set()
+        for prop in self.props.values():
+            prop.cascade_callable(type, object, callable_, recursive)
+            
     def _row_identity_key(self, row):
         return sessionlib.get_row_key(row, self.class_, self.pks_by_table[self.mapped_table], self.entity_name)
 
@@ -929,6 +938,8 @@ class MapperProperty(object):
         raise NotImplementedError()
     def cascade_iterator(self, type, object, recursive=None):
         return []
+    def cascade_callable(self, type, object, callable_, recursive=None):
+        return []
     def copy(self):
         raise NotImplementedError()
     def get_criterion(self, query, key, value):
@@ -1157,13 +1168,16 @@ def hash_key(obj):
         return obj.hash_key()
     else:
         return repr(obj)
+
+def has_mapper(object):
+    """returns True if the given object has a mapper association"""
+    return hasattr(object, '_entity_name')
         
-def object_mapper(object, raiseerror=True, entity_name=None):
-    """given an object, returns the primary Mapper associated with the object
-    or the object's class."""
+def object_mapper(object, raiseerror=True):
+    """given an object, returns the primary Mapper associated with the object instance"""
     try:
-        return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name', entity_name))]
-    except KeyError:
+        return mapper_registry[ClassKey(object.__class__, getattr(object, '_entity_name'))]
+    except (KeyError, AttributeError):
         if raiseerror:
             raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', None)))
         else:
index 8c38b897fef3afcb1c56bc7f29e920ad141ac839..8cdaf394096cb65b7676ca4cfc54731e9b195abb 100644 (file)
@@ -154,14 +154,25 @@ class PropertyLoader(mapper.MapperProperty):
         if not type in self.cascade:
             return
         childlist = sessionlib.global_attributes.get_history(object, self.key, passive=True)
-            
+        
+        mapper = self.mapper.primary_mapper()
         for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items():
-            if c is not None:
-                if c not in recursive:
-                    recursive.add(c)
-                    yield c
-                    for c2 in self.mapper.primary_mapper().cascade_iterator(type, c, recursive):
-                        yield c2
+            if c is not None and c not in recursive:
+                recursive.add(c)
+                yield c
+                for c2 in mapper.cascade_iterator(type, c, recursive):
+                    yield c2
+
+    def cascade_callable(self, type, object, callable_, recursive):
+        if not type in self.cascade:
+            return
+        childlist = sessionlib.global_attributes.get_history(object, self.key, passive=True)
+        mapper = self.mapper.primary_mapper()
+        for c in childlist.added_items() + childlist.deleted_items() + childlist.unchanged_items():
+            if c is not None and c not in recursive:
+                recursive.add(c)
+                callable_(c, mapper.entity_name)
+                mapper.cascade_callable(type, c, callable_, recursive)
 
     def copy(self):
         x = self.__class__.__new__(self.__class__)
@@ -237,10 +248,16 @@ class PropertyLoader(mapper.MapperProperty):
             raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'.  New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (key, parent.class_.__name__, parent.class_.__name__))
 
         self.do_init_subclass(key, parent)
+
+    def _register_attribute(self, class_, callable_=None):
+        sessionlib.global_attributes.register_attribute(class_, self.key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade,  trackparent=True, callable_=callable_)
+
+    def _create_history(self, instance, callable_=None):
+        return sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade,  trackparent=True, callable_=callable_)
         
     def _set_class_attribute(self, class_, key):
         """sets attribute behavior on our target class."""
-        sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, extension=self.attributeext, cascade=self.cascade, trackparent=True)
+        self._register_attribute(class_)
         
     def _get_direction(self):
         """determines our 'direction', i.e. do we represent one to many, many to many, etc."""
@@ -295,7 +312,7 @@ class PropertyLoader(mapper.MapperProperty):
         if self.is_primary():
             return
         #print "PLAIN PROPLOADER EXEC NON-PRIAMRY", repr(id(self)), repr(self.mapper.class_), self.key
-        sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True)
+        self._create_history(instance)
 
     def register_dependencies(self, uowcommit):
         self._dependency_processor.register_dependencies(uowcommit)
@@ -327,11 +344,17 @@ class LazyLoader(PropertyLoader):
     def _set_class_attribute(self, class_, key):
         # establish a class-level lazy loader on our class
         #print "SETCLASSATTR LAZY", repr(class_), key
-        sessionlib.global_attributes.register_attribute(class_, key, uselist = self.uselist, callable_=lambda i: self.setup_loader(i), extension=self.attributeext, cascade=self.cascade, trackparent=True)
+        self._register_attribute(class_, callable_=lambda i: self.setup_loader(i))
 
     def setup_loader(self, instance):
+        # make sure our parent mapper is the one thats assigned to this instance, else call that one
         if not self.localparent.is_assigned(instance):
-            return mapper.object_mapper(instance).props[self.key].setup_loader(instance)
+            # if no mapper association with this instance (i.e. not in a session, not loaded by a mapper),
+            # then we cant set up a lazy loader
+            if not mapper.has_mapper(instance):
+                return None
+            else:
+                return mapper.object_mapper(instance).props[self.key].setup_loader(instance)
         def lazyload():
             params = {}
             allparams = True
@@ -379,7 +402,7 @@ class LazyLoader(PropertyLoader):
                 #print "EXEC NON-PRIAMRY", repr(self.mapper.class_), self.key
                 # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
                 # which will override the class-level behavior
-                sessionlib.global_attributes.create_history(instance, self.key, self.uselist, callable_=self.setup_loader(instance), cascade=self.cascade, trackparent=True)
+                self._create_history(instance, callable_=self.setup_loader(instance))
             else:
                 #print "EXEC PRIMARY", repr(self.mapper.class_), self.key
                 # we are the primary manager for this attribute on this class - reset its per-instance attribute state, 
@@ -548,7 +571,7 @@ class EagerLoader(LazyLoader):
         if isnew:
             # new row loaded from the database.  initialize a blank container on the instance.
             # this will override any per-class lazyloading type of stuff.
-            h = sessionlib.global_attributes.create_history(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True)
+            h = self._create_history(instance)
             
         if not self.uselist:
             if isnew:
index 56d699cc6607f55ead57e271ad816da45812f33b..bd17501653c663975144f1de6a574b499cc2e5c8 100644 (file)
@@ -275,12 +275,8 @@ class Session(object):
         The 'entity_name' keyword argument will further qualify the specific Mapper used to handle this
         instance.
         """
-        for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object):
-            if c is object:
-                self._save_impl(c, entity_name=entity_name)
-            else:
-                # TODO: this is running the cascade rules twice
-                self.save_or_update(c, entity_name=entity_name)
+        self._save_impl(object, entity_name=entity_name)
+        object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
 
     def update(self, object, entity_name=None):
         """Brings the given detached (saved) instance into this Session.
@@ -288,30 +284,31 @@ class Session(object):
         Session), an exception is thrown. 
         This operation cascades the "save_or_update" method to associated instances if the relation is mapped 
         with cascade="save-update"."""
-        for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object):
-            if c is object:
-                self._update_impl(c, entity_name=entity_name)
-            else:
-                self.save_or_update(c, entity_name=entity_name)
+        self._update_impl(object, entity_name=entity_name)
+        object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
 
     def save_or_update(self, object, entity_name=None):
-        for c in object_mapper(object, entity_name=entity_name).cascade_iterator('save-update', object):
-            key = getattr(object, '_instance_key', None)
-            if key is None:
-                self._save_impl(c, entity_name=entity_name)
-            else:
-                self._update_impl(c, entity_name=entity_name)
-
+        self._save_or_update_impl(object, entity_name=entity_name)
+        object_mapper(object).cascade_callable('save-update', object, lambda c, e:self._save_or_update_impl(c, e))
+    
+    def _save_or_update_impl(self, object, entity_name=None):
+        key = getattr(object, '_instance_key', None)
+        if key is None:
+            self._save_impl(object, entity_name=entity_name)
+        else:
+            self._update_impl(object, entity_name=entity_name)
+        
     def delete(self, object, entity_name=None):
-        for c in object_mapper(object, entity_name=entity_name).cascade_iterator('delete', object):
+        #self.uow.register_deleted(object)
+        for c in [object] + list(object_mapper(object).cascade_iterator('delete', object)):
             self.uow.register_deleted(c)
 
     def merge(self, object, entity_name=None):
         instance = None
-        for obj in object_mapper(object, entity_name=entity_name).cascade_iterator('merge', object):
+        for obj in [object] + list(object_mapper(object).cascade_iterator('merge', object)):
             key = getattr(obj, '_instance_key', None)
             if key is None:
-                mapper = object_mapper(object, entity_name=entity_name)
+                mapper = object_mapper(object)
                 ident = mapper.identity(object)
                 for k in ident:
                     if k is None:
@@ -333,10 +330,8 @@ class Session(object):
             if not self.uow.has_key(object._instance_key):
                 raise exceptions.InvalidRequestError("Instance '%s' is already persistent in a different Session" % repr(object))
         else:
-            entity_name = kwargs.get('entity_name', None)
-            if entity_name is not None:
-                m = class_mapper(object.__class__, entity_name=entity_name)
-                m._assign_entity_name(object)
+            m = class_mapper(object.__class__, entity_name=kwargs.get('entity_name', None))
+            m._assign_entity_name(object)
             self._register_new(object)
 
     def _update_impl(self, object, **kwargs):
@@ -422,8 +417,8 @@ def get_id_key(ident, class_, entity_name=None):
 def get_row_key(row, class_, primary_key, entity_name=None):
     return Session.get_row_key(row, class_, primary_key, entity_name)
 
-def object_mapper(obj, **kwargs):
-    return sqlalchemy.orm.object_mapper(obj, **kwargs)
+def object_mapper(obj):
+    return sqlalchemy.orm.object_mapper(obj)
 
 def class_mapper(class_, **kwargs):
     return sqlalchemy.orm.class_mapper(class_, **kwargs)
index c33f344fbea9ea15bd0c71389332ac08e167e924..9e9778cad98c898697e5d2b6c0405db483763d32 100644 (file)
@@ -49,7 +49,13 @@ class UOWListElement(attributes.ListAttribute):
             if self.cascade is not None:
                 if not isdelete:
                     if self.cascade.save_update:
-                        sess.save_or_update(item)
+                        # cascade the save_update operation onto the child object,
+                        # relative to the mapper handling the parent object
+                        # TODO: easier way to do this ?
+                        mapper = object_mapper(obj)
+                        prop = mapper.props[self.key]
+                        ename = prop.mapper.entity_name
+                        sess.save_or_update(item, entity_name=ename)
     def append(self, item, _mapper_nohistory = False):
         if _mapper_nohistory:
             self.append_nohistory(item)
@@ -67,13 +73,19 @@ class UOWScalarElement(attributes.ScalarAttribute):
             sess._register_changed(obj)
             if newvalue is not None and self.cascade is not None:
                 if self.cascade.save_update:
-                    sess.save_or_update(newvalue)
+                    # cascade the save_update operation onto the child object,
+                    # relative to the mapper handling the parent object
+                    # TODO: easier way to do this ?
+                    mapper = object_mapper(obj)
+                    prop = mapper.props[self.key]
+                    ename = prop.mapper.entity_name
+                    sess.save_or_update(newvalue, entity_name=ename)
             
 class UOWAttributeManager(attributes.AttributeManager):
     """overrides AttributeManager to provide unit-of-work "dirty" hooks when scalar attribues are modified, plus factory methods for UOWProperrty/UOWListElement."""
     def __init__(self):
         attributes.AttributeManager.__init__(self)
-        
+
     def create_prop(self, class_, key, uselist, callable_, **kwargs):
         return UOWProperty(class_, self, key, uselist, callable_, **kwargs)
 
index 22e74f1716301f32ccbed724d8565630da196401..f425c5c8eedb22dce5a5439efb69946b59b7656b 100644 (file)
@@ -82,6 +82,48 @@ class EntityTest(AssertMixin):
         assert u1list[0] is not u2list[0]
         assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
 
+    def testcascade(self):
+        """same as testbasic but relies on session cascading"""
+        class User(object):pass
+        class Address(object):pass
+
+        a1mapper = mapper(Address, address1, entity_name='address1')
+        a2mapper = mapper(Address, address2, entity_name='address2')    
+        u1mapper = mapper(User, user1, entity_name='user1', properties ={
+            'addresses':relation(a1mapper)
+        })
+        u2mapper =mapper(User, user2, entity_name='user2', properties={
+            'addresses':relation(a2mapper)
+        })
+
+        sess = create_session()
+        u1 = User()
+        u1.name = 'this is user 1'
+        sess.save(u1, entity_name='user1')
+        a1 = Address()
+        a1.email='a1@foo.com'
+        u1.addresses.append(a1)
+
+        u2 = User()
+        u2.name='this is user 2'
+        a2 = Address()
+        a2.email='a2@foo.com'
+        u2.addresses.append(a2)
+        sess.save(u2, entity_name='user2')
+        
+        sess.flush()
+        assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
+        assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
+        assert address1.select().execute().fetchall() == [(u1.user_id, a1.user_id, 'a1@foo.com')]
+        assert address2.select().execute().fetchall() == [(u2.user_id, a2.user_id, 'a2@foo.com')]
+
+        sess.clear()
+        u1list = sess.query(User, entity_name='user1').select()
+        u2list = sess.query(User, entity_name='user2').select()
+        assert len(u1list) == len(u2list) == 1
+        assert u1list[0] is not u2list[0]
+        assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
+
     def testpolymorphic(self):
         """tests that entity_name can be used to have two kinds of relations on the same class."""
         class User(object):pass
index 15c93e7b213fbdf95d6ef284e489b760f57b4a82..2f029e64f40078fe3d916c5c59d3bb637f49f3be 100644 (file)
@@ -14,7 +14,7 @@ class Issue(BaseObject):
 
 class Location(BaseObject):
     def __repr__(self):
-        return "%s(%s, %s)" % (self.__class__.__name__, repr(self.issue_id), repr(str(self._name.name)))
+        return "%s(%s, %s)" % (self.__class__.__name__, str(self.issue_id), repr(str(self._name.name)))
 
     def _get_name(self):
         return self._name
index ef20bbdbead3dd8650ec9c62d6726cba849bf35b..442ee9f4da3dcd9b8e62cfea3d7b675a99b8ad91 100644 (file)
@@ -288,7 +288,9 @@ class PrivateAttrTest(SessionTest):
         ctx.current.flush([a])
     
         ctx.current.delete(a)
+        print ctx.current.deleted
         ctx.current.flush([a])
+#        ctx.current.flush()
         
         assert b_table.count().scalar() == 0