From 19aae75e6aea39e59357f334039baf6861647e40 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 6 Dec 2005 03:32:24 +0000 Subject: [PATCH] first take at backreference handlers --- lib/sqlalchemy/attributes.py | 46 ++++++++++++++++++++++++++++++++++-- lib/sqlalchemy/util.py | 12 +++++----- test/attributes.py | 43 +++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 8 deletions(-) diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index aa768532de..9bdfcab6a1 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -44,10 +44,11 @@ class PropHistory(object): """manages the value of a particular scalar attribute on a particular object instance.""" # make our own NONE to distinguish from "None" NONE = object() - def __init__(self, obj, key, **kwargs): + def __init__(self, obj, key, backrefmanager=None, **kwargs): self.obj = obj self.key = key self.orig = PropHistory.NONE + self.backrefmanager = backrefmanager def gethistory(self, *args, **kwargs): return self def history_contains(self, obj): @@ -61,9 +62,13 @@ class PropHistory(object): raise ("assigning a list to scalar property '%s' on '%s' instance %d" % (self.key, self.obj.__class__.__name__, id(self.obj))) self.orig = self.obj.__dict__.get(self.key, None) self.obj.__dict__[self.key] = value + if self.backrefmanager is not None and self.orig is not value: + self.backrefmanager.set(self.obj, value, self.orig) def delattr(self): self.orig = self.obj.__dict__.get(self.key, None) self.obj.__dict__[self.key] = None + if self.backrefmanager is not None: + self.backrefmanager.set(self.obj, None, self.orig) def rollback(self): if self.orig is not PropHistory.NONE: self.obj.__dict__[self.key] = self.orig @@ -88,9 +93,10 @@ class PropHistory(object): class ListElement(util.HistoryArraySet): """manages the value of a particular list-based attribute on a particular object instance.""" - def __init__(self, obj, key, data=None, **kwargs): + def __init__(self, obj, key, data=None, backrefmanager=None, **kwargs): self.obj = obj self.key = key + self.backrefmanager = backrefmanager # if we are given a list, try to behave nicely with an existing # list that might be set on the object already try: @@ -120,11 +126,15 @@ class ListElement(util.HistoryArraySet): res = util.HistoryArraySet._setrecord(self, item) if res: self.list_value_changed(self.obj, self.key, item, self, False) + if self.backrefmanager is not None: + self.backrefmanager.append(self.obj, item) return res def _delrecord(self, item): res = util.HistoryArraySet._delrecord(self, item) if res: self.list_value_changed(self.obj, self.key, item, self, True) + if self.backrefmanager is not None: + self.backrefmanager.delete(self.obj, item) return res class CallableProp(object): @@ -175,6 +185,38 @@ class CallableProp(object): def rollback(self): pass +class BackrefManager(object): + def __init__(self, key): + self.key = key + def append(self, parent, child): + pass + def delete(self, parent, child): + pass + def set(self, parent, child, oldchild): + pass + + +class ListBackrefManager(BackrefManager): + def append(self, parent, child): + getattr(child, self.key).append(parent) + def delete(self, parent, child): + getattr(child, self.key).remove(parent) + +class OneToManyBackrefManager(BackrefManager): + def append(self, parent, child): + setattr(child, self.key, parent) + def delete(self, parent, child): + setattr(child, self.key, None) + +class ManyToOneBackrefManager(BackrefManager): + def set(self, parent, child, oldchild): + if oldchild is not None: + try: + getattr(oldchild, self.key).remove(parent) + except: + print "wha? oldchild is ", repr(oldchild) + if child is not None: + getattr(child, self.key).append(parent) class AttributeManager(object): """maintains a set of per-attribute callable/history manager objects for a set of objects.""" diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index c5ac8b979a..443db2e3d5 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -213,9 +213,9 @@ class HistoryArraySet(UserList.UserList): self.records[item] = False elif val is True: del self.records[item] + return True except KeyError: - pass - return True + return False def commit(self): for key in self.records.keys(): value = self.records[key] @@ -274,11 +274,11 @@ class HistoryArraySet(UserList.UserList): self.data.insert(i, item) def pop(self, i=-1): item = self.data[i] - self._delrecord(item) - return self.data.pop(i) + if self._delrecord(item): + return self.data.pop(i) def remove(self, item): - self._delrecord(item) - self.data.remove(item) + if self._delrecord(item): + self.data.remove(item) def __add__(self, other): raise NotImplementedError() def __radd__(self, other): diff --git a/test/attributes.py b/test/attributes.py index a7dd3bf4e4..b546de533f 100644 --- a/test/attributes.py +++ b/test/attributes.py @@ -77,6 +77,49 @@ class AttributesTest(PersistTest): self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') self.assert_(len(u.addresses.unchanged_items()) == 1) + def testbackref(self): + class Student(object):pass + class Course(object):pass + manager = attributes.AttributeManager() + manager.register_attribute(Student, 'courses', uselist=True, backrefmanager=attributes.ListBackrefManager('students')) + manager.register_attribute(Course, 'students', uselist=True, backrefmanager=attributes.ListBackrefManager('courses')) + + s = Student() + c = Course() + s.courses.append(c) + self.assert_(c.students == [s]) + s.courses.remove(c) + self.assert_(c.students == []) + + (s1, s2, s3) = (Student(), Student(), Student()) + c.students = [s1, s2, s3] + self.assert_(s2.courses == [c]) + self.assert_(s1.courses == [c]) + s1.courses.remove(c) + self.assert_(c.students == [s2,s3]) + + + class Post(object):pass + class Blog(object):pass + + manager.register_attribute(Post, 'blog', uselist=False, backrefmanager=attributes.ManyToOneBackrefManager('posts')) + manager.register_attribute(Blog, 'posts', uselist=True, backrefmanager=attributes.OneToManyBackrefManager('blog')) + b = Blog() + (p1, p2, p3) = (Post(), Post(), Post()) + b.posts.append(p1) + b.posts.append(p2) + b.posts.append(p3) + self.assert_(b.posts == [p1, p2, p3]) + self.assert_(p2.blog is b) + + p3.blog = None + self.assert_(b.posts == [p1, p2]) + p4 = Post() + p4.blog = b + self.assert_(b.posts == [p1, p2, p4]) + + + if __name__ == "__main__": unittest.main() -- 2.47.2