From fe49ba2f8c962cf56a58b0e4e4a8547e36c11e4f Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 4 Feb 2006 17:05:09 +0000 Subject: [PATCH] one-to-one support: rolled the BackrefExtensions into a single GenericBackrefExtension to handle all combinations of list/nonlist properties (such as one-to-one) tweak to properties.py which may receive "None" as "added_items()", in the case of a scalar property instead of a list PropHistory masquerades as a List on the setattr/append delattr/remove side to make one-to-one's automatically work --- lib/sqlalchemy/attributes.py | 42 +++++++++++----------------- lib/sqlalchemy/mapping/properties.py | 16 ++++------- test/attributes.py | 21 ++++++++++---- 3 files changed, 39 insertions(+), 40 deletions(-) diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index 87b09507f9..f6abe6b31d 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -24,7 +24,7 @@ AttributeManager can also assign a "callable" history container to an object's a which is invoked when first accessed, to provide the object's "committed" value. The package includes functions for managing "bi-directional" object relationships as well -via the ListBackrefExtension, OTMBackrefExtension, and MTOBackrefExtension objects. +via the GenericBackrefExtension object. """ import sqlalchemy.util as util @@ -80,15 +80,18 @@ class PropHistory(object): orig = self.obj.__dict__.get(self.key, None) if orig is value: return - self.orig = orig + if self.orig is PropHistory.NONE: + self.orig = orig self.obj.__dict__[self.key] = value - if self.extension is not None and self.orig is not value: - self.extension.set(self.obj, value, self.orig) + if self.extension is not None: + self.extension.set(self.obj, value, orig) def delattr(self): - self.orig = self.obj.__dict__.get(self.key, None) + orig = self.obj.__dict__.get(self.key, None) + if self.orig is PropHistory.NONE: + self.orig = orig self.obj.__dict__[self.key] = None if self.extension is not None: - self.extension.set(self.obj, None, self.orig) + self.extension.set(self.obj, None, orig) def append(self, obj): self.setattr(obj) def remove(self, obj): @@ -223,25 +226,7 @@ class AttributeExtension(object): def set(self, obj, child, oldchild): pass -class ListBackrefExtension(AttributeExtension): - def __init__(self, key): - self.key = key - def append(self, obj, child): - getattr(child, self.key).append(obj) - def delete(self, obj, child): - getattr(child, self.key).remove(obj) -class OTMBackrefExtension(AttributeExtension): - def __init__(self, key): - self.key = key - def append(self, obj, child): - prop = child.__class__._attribute_manager.get_history(child, self.key) - prop.setattr(obj) -# prop.setattr(obj) - def delete(self, obj, child): - prop = child.__class__._attribute_manager.get_history(child, self.key) - prop.delattr() - -class MTOBackrefExtension(AttributeExtension): +class GenericBackrefExtension(AttributeExtension): def __init__(self, key): self.key = key def set(self, obj, child, oldchild): @@ -251,6 +236,13 @@ class MTOBackrefExtension(AttributeExtension): if child is not None: prop = child.__class__._attribute_manager.get_history(child, self.key) prop.append(obj) + def append(self, obj, child): + prop = child.__class__._attribute_manager.get_history(child, self.key) + prop.append(obj) + def delete(self, obj, child): + prop = child.__class__._attribute_manager.get_history(child, self.key) + prop.remove(obj) + class AttributeManager(object): """maintains a set of per-attribute history container objects for a set of objects.""" diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index 5f15a8d2c6..9eb97aa0c1 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -196,12 +196,7 @@ class PropertyLoader(MapperProperty): # if a backref name is defined, set up an extension to populate # attributes in the other direction if self.backref is not None: - if self.direction == PropertyLoader.ONETOMANY: - self.attributeext = attributes.OTMBackrefExtension(self.backref) - elif self.direction == PropertyLoader.MANYTOONE: - self.attributeext = attributes.MTOBackrefExtension(self.backref) - else: - self.attributeext = attributes.ListBackrefExtension(self.backref) + self.attributeext = attributes.GenericBackrefExtension(self.backref) # set our class attribute self._set_class_attribute(parent.class_, key) @@ -521,10 +516,11 @@ class PropertyLoader(MapperProperty): childlist = getlist(obj, passive=True) if childlist is None: continue for child in childlist.added_items(): - self._synchronize(obj, child, None, False) - if self.direction == PropertyLoader.ONETOMANY: - # for a cyclical task, this registration is handled by the objectstore - uowcommit.register_object(child) + if child is not None: + self._synchronize(obj, child, None, False) + if self.direction == PropertyLoader.ONETOMANY: + # for a cyclical task, this registration is handled by the objectstore + uowcommit.register_object(child) if self.direction != PropertyLoader.MANYTOONE or len(childlist.added_items()) == 0: for child in childlist.deleted_items(): if not self.private: diff --git a/test/attributes.py b/test/attributes.py index 074a8d8391..ca34bdfc7a 100644 --- a/test/attributes.py +++ b/test/attributes.py @@ -81,8 +81,8 @@ class AttributesTest(PersistTest): class Student(object):pass class Course(object):pass manager = attributes.AttributeManager() - manager.register_attribute(Student, 'courses', uselist=True, extension=attributes.ListBackrefExtension('students')) - manager.register_attribute(Course, 'students', uselist=True, extension=attributes.ListBackrefExtension('courses')) + manager.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students')) + manager.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses')) s = Student() c = Course() @@ -102,8 +102,8 @@ class AttributesTest(PersistTest): class Post(object):pass class Blog(object):pass - manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.MTOBackrefExtension('posts')) - manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.OTMBackrefExtension('blog')) + manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts')) + manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog')) b = Blog() (p1, p2, p3) = (Post(), Post(), Post()) b.posts.append(p1) @@ -118,8 +118,19 @@ class AttributesTest(PersistTest): p4.blog = b self.assert_(b.posts == [p1, p2, p4]) + + class Port(object):pass + class Jack(object):pass + manager.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port')) + manager.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack')) + p = Port() + j = Jack() + p.jack = j + self.assert_(j.port is p) + self.assert_(p.jack is not None) - + j.port = None + self.assert_(p.jack is None) if __name__ == "__main__": unittest.main() -- 2.47.2