]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Got basic backrefs going with dynamic attributes
authorJason Kirtland <jek@discorporate.us>
Sat, 28 Jul 2007 00:33:59 +0000 (00:33 +0000)
committerJason Kirtland <jek@discorporate.us>
Sat, 28 Jul 2007 00:33:59 +0000 (00:33 +0000)
lib/sqlalchemy/orm/dynamic.py
test/orm/dynamic.py

index 5b613686710376d9d68608787158ba9af89589a4..b1293148288d86be920e594634fe8f1218566d97 100644 (file)
@@ -9,7 +9,7 @@ class DynamicCollectionAttribute(attributes.InstrumentedAttribute):
     def __init__(self, class_, attribute_manager, key, typecallable, target_mapper, **kwargs):
         super(DynamicCollectionAttribute, self).__init__(class_, attribute_manager, key, typecallable, **kwargs)
         self.target_mapper = target_mapper
-        
+
     def get(self, obj, passive=False):
         if passive:
             return self.get_history(obj, passive=True).added_items()
@@ -30,7 +30,7 @@ class DynamicCollectionAttribute(attributes.InstrumentedAttribute):
 
         # TODO: emit events ???
         state['modified'] = True
-    
+
     def delete(self, *args, **kwargs):
         raise NotImplementedError()
         
@@ -40,6 +40,17 @@ class DynamicCollectionAttribute(attributes.InstrumentedAttribute):
         except KeyError:
             obj.__dict__[self.key] = c = CollectionHistory(self, obj)
             return c
+
+    def append(self, obj, value, initiator):
+        if initiator is not self:
+            self.get_history(obj)._added_items.append(value)
+            self.fire_append_event(obj, value, self)
+    
+    def remove(self, obj, value, initiator):
+        if initiator is not self:
+            self.get_history(obj)._deleted_items.append(value)
+            self.fire_remove_event(obj, value, self)
+
             
 class AppenderQuery(Query):
     def __init__(self, attr, instance):
@@ -61,15 +72,16 @@ class AppenderQuery(Query):
             return iter(self.attr.get_history(self.instance)._added_items)
         else:
             return iter(self._clone())
-    
+
     def __getitem__(self, index):
         if not has_identity(self.instance):
-            return iter(self.attr.get_history(self.instance)._added_items.__getitem__(index))
+            # TODO: hmm
+            return self.attr.get_history(self.instance)._added_items.__getitem__(index)
         else:
             return self._clone().__getitem__(index)
         
     def _clone(self):
-        # note we're returning an entirely new query class here
+        # note we're returning an entirely new Query class instance here
         # without any assignment capabilities;
         # the class of this query is determined by the session.
         sess = object_session(self.instance)
@@ -90,15 +102,15 @@ class AppenderQuery(Query):
         return oldlist
         
     def append(self, item):
-        self.attr.get_history(self.instance)._added_items.append(item)
-        self.attr.fire_append_event(self.instance, item, self.attr)
-    
+        self.attr.append(self.instance, item, self.attr)
+
+    # TODO:jek: I think this should probably be axed, time will tell.
     def remove(self, item):
-        self.attr.get_history(self.instance)._deleted_items.append(item)
-        self.attr.fire_remove_event(self.instance, item, self.attr)
+        self.attr.remove(self.instance, item, self.attr)
             
 class CollectionHistory(attributes.AttributeHistory): 
-    """override AttributeHistory to receive append/remove events directly"""
+    """Overrides AttributeHistory to receive append/remove events directly."""
+
     def __init__(self, attr, obj):
         self._deleted_items = []
         self._added_items = []
@@ -120,4 +132,4 @@ class CollectionHistory(attributes.AttributeHistory):
 
     def deleted_items(self):
         return self._deleted_items
-    
\ No newline at end of file
+    
index 434ac22963e53bef7bb8569c701f65cddaffcc92..2cd616c12d62c68d74b0a66803c0fd6b401e8098 100644 (file)
@@ -49,11 +49,81 @@ class FlushTest(FixtureTest):
                 User(name='ed', addresses=[Address(email_address='foo@bar.com')])
             ] == sess.query(User).all()
 
-        # one query for the query(User).all(), one query for each address iter(),
-        # also one query for a count() on each address (the count() is an artifact of the
-        # fixtures.Base class, its not intrinsic to the property)
+        # one query for the query(User).all(), one query for each address
+        # iter(), also one query for a count() on each address (the count()
+        # is an artifact of the fixtures.Base class, its not intrinsic to the
+        # property)
         self.assert_sql_count(testbase.db, go, 5)
+
+    def test_backref_unsaved_u(self):
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), lazy='dynamic',
+                                 backref='user')
+        })
+        sess = create_session()
+
+        u = User(name='buffy')
+
+        a = Address(email_address='foo@bar.com')
+        a.user = u
+
+        sess.save(u)
+        sess.flush()
+
+    def test_backref_unsaved_a(self):
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), lazy='dynamic',
+                                 backref='user')
+        })
+        sess = create_session()
+
+        u = User(name='buffy')
+
+        a = Address(email_address='foo@bar.com')
+        a.user = u
+
+        self.assert_(list(u.addresses) == [a])
+        self.assert_(u.addresses[0] == a)
+
+        sess.save(a)
+        sess.flush()
+        
+        self.assert_(list(u.addresses) == [a])
+
+        a.user = None
+        self.assert_(list(u.addresses) == [a])
+
+        sess.flush()
+        self.assert_(list(u.addresses) == [])
+        
+
+    def test_backref_unsaved_u(self):
+        mapper(User, users, properties={
+            'addresses':relation(mapper(Address, addresses), lazy='dynamic',
+                                 backref='user')
+        })
+        sess = create_session()
+
+        u = User(name='buffy')
+
+        a = Address(email_address='foo@bar.com')
+        a.user = u
+
+        self.assert_(list(u.addresses) == [a])
+        self.assert_(u.addresses[0] == a)
+
+        sess.save(u)
+        sess.flush()
+        
+        assert list(u.addresses) == [a]
+
+        a.user = None
+        self.assert_(list(u.addresses) == [a])
+
+        sess.flush()
+        self.assert_(list(u.addresses) == [])
+
         
 if __name__ == '__main__':
     testbase.main()
-    
\ No newline at end of file
+