]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
first take at backreference handlers
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Dec 2005 03:32:24 +0000 (03:32 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Dec 2005 03:32:24 +0000 (03:32 +0000)
lib/sqlalchemy/attributes.py
lib/sqlalchemy/util.py
test/attributes.py

index aa768532de36143454be3d52c1e8872cc24b9798..9bdfcab6a1ec43358ccb672629c390c24a42dcda 100644 (file)
@@ -44,10 +44,11 @@ class PropHistory(object):
     """manages the value of a particular scalar attribute on a particular object instance."""
     # make our own NONE to distinguish from "None"
     NONE = object()
-    def __init__(self, obj, key, **kwargs):
+    def __init__(self, obj, key, backrefmanager=None, **kwargs):
         self.obj = obj
         self.key = key
         self.orig = PropHistory.NONE
+        self.backrefmanager = backrefmanager
     def gethistory(self, *args, **kwargs):
         return self
     def history_contains(self, obj):
@@ -61,9 +62,13 @@ class PropHistory(object):
             raise ("assigning a list to scalar property '%s' on '%s' instance %d" % (self.key, self.obj.__class__.__name__, id(self.obj)))
         self.orig = self.obj.__dict__.get(self.key, None)
         self.obj.__dict__[self.key] = value
+        if self.backrefmanager is not None and self.orig is not value:
+            self.backrefmanager.set(self.obj, value, self.orig)
     def delattr(self):
         self.orig = self.obj.__dict__.get(self.key, None)
         self.obj.__dict__[self.key] = None
+        if self.backrefmanager is not None:
+            self.backrefmanager.set(self.obj, None, self.orig)
     def rollback(self):
         if self.orig is not PropHistory.NONE:
             self.obj.__dict__[self.key] = self.orig
@@ -88,9 +93,10 @@ class PropHistory(object):
 
 class ListElement(util.HistoryArraySet):
     """manages the value of a particular list-based attribute on a particular object instance."""
-    def __init__(self, obj, key, data=None, **kwargs):
+    def __init__(self, obj, key, data=None, backrefmanager=None, **kwargs):
         self.obj = obj
         self.key = key
+        self.backrefmanager = backrefmanager
         # if we are given a list, try to behave nicely with an existing
         # list that might be set on the object already
         try:
@@ -120,11 +126,15 @@ class ListElement(util.HistoryArraySet):
         res = util.HistoryArraySet._setrecord(self, item)
         if res:
             self.list_value_changed(self.obj, self.key, item, self, False)
+            if self.backrefmanager is not None:
+                self.backrefmanager.append(self.obj, item)
         return res
     def _delrecord(self, item):
         res = util.HistoryArraySet._delrecord(self, item)
         if res:
             self.list_value_changed(self.obj, self.key, item, self, True)
+            if self.backrefmanager is not None:
+                self.backrefmanager.delete(self.obj, item)
         return res
 
 class CallableProp(object):
@@ -175,6 +185,38 @@ class CallableProp(object):
     def rollback(self):
         pass
 
+class BackrefManager(object):
+    def __init__(self, key):
+        self.key = key
+    def append(self, parent, child):
+        pass
+    def delete(self, parent, child):
+        pass
+    def set(self, parent, child, oldchild):
+        pass
+
+
+class ListBackrefManager(BackrefManager):
+    def append(self, parent, child):
+        getattr(child, self.key).append(parent)
+    def delete(self, parent, child):
+        getattr(child, self.key).remove(parent)
+
+class OneToManyBackrefManager(BackrefManager):
+    def append(self, parent, child):
+        setattr(child, self.key, parent)
+    def delete(self, parent, child):
+        setattr(child, self.key, None)
+
+class ManyToOneBackrefManager(BackrefManager):
+    def set(self, parent, child, oldchild):
+        if oldchild is not None:
+            try:
+                getattr(oldchild, self.key).remove(parent)
+            except:
+                print "wha? oldchild is ", repr(oldchild)
+        if child is not None:
+            getattr(child, self.key).append(parent)
             
 class AttributeManager(object):
     """maintains a set of per-attribute callable/history manager objects for a set of objects."""
index c5ac8b979ae352e97845985348dd254100aae0ce..443db2e3d5f87cb8d181b41a5653dc097ad69fb8 100644 (file)
@@ -213,9 +213,9 @@ class HistoryArraySet(UserList.UserList):
                 self.records[item] = False
             elif val is True:
                 del self.records[item]
+            return True    
         except KeyError:
-            pass
-        return True
+            return False
     def commit(self):
         for key in self.records.keys():
             value = self.records[key]
@@ -274,11 +274,11 @@ class HistoryArraySet(UserList.UserList):
             self.data.insert(i, item)
     def pop(self, i=-1):
         item = self.data[i]
-        self._delrecord(item) 
-        return self.data.pop(i)
+        if self._delrecord(item):
+            return self.data.pop(i)
     def remove(self, item): 
-        self._delrecord(item)
-        self.data.remove(item)
+        if self._delrecord(item):
+            self.data.remove(item)
     def __add__(self, other):
         raise NotImplementedError()
     def __radd__(self, other):
index a7dd3bf4e440ff5da8bb2b155e9b4129ba4628ab..b546de533fddd3e93dae27adb6d9cb5613be182d 100644 (file)
@@ -77,6 +77,49 @@ class AttributesTest(PersistTest):
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
         self.assert_(len(u.addresses.unchanged_items()) == 1)
 
+    def testbackref(self):
+        class Student(object):pass
+        class Course(object):pass
+        manager = attributes.AttributeManager()
+        manager.register_attribute(Student, 'courses', uselist=True, backrefmanager=attributes.ListBackrefManager('students'))
+        manager.register_attribute(Course, 'students', uselist=True, backrefmanager=attributes.ListBackrefManager('courses'))
+        
+        s = Student()
+        c = Course()
+        s.courses.append(c)
+        self.assert_(c.students == [s])
+        s.courses.remove(c)
+        self.assert_(c.students == [])
+        
+        (s1, s2, s3) = (Student(), Student(), Student())
+        c.students = [s1, s2, s3]
+        self.assert_(s2.courses == [c])
+        self.assert_(s1.courses == [c])
+        s1.courses.remove(c)
+        self.assert_(c.students == [s2,s3])
+        
+        
+        class Post(object):pass
+        class Blog(object):pass
+        
+        manager.register_attribute(Post, 'blog', uselist=False, backrefmanager=attributes.ManyToOneBackrefManager('posts'))
+        manager.register_attribute(Blog, 'posts', uselist=True, backrefmanager=attributes.OneToManyBackrefManager('blog'))
+        b = Blog()
+        (p1, p2, p3) = (Post(), Post(), Post())
+        b.posts.append(p1)
+        b.posts.append(p2)
+        b.posts.append(p3)
+        self.assert_(b.posts == [p1, p2, p3])
+        self.assert_(p2.blog is b)
+        
+        p3.blog = None
+        self.assert_(b.posts == [p1, p2])
+        p4 = Post()
+        p4.blog = b
+        self.assert_(b.posts == [p1, p2, p4])
+        
+        
+        
         
 if __name__ == "__main__":
     unittest.main()