]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
one-to-one support:
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Feb 2006 17:05:09 +0000 (17:05 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 4 Feb 2006 17:05:09 +0000 (17:05 +0000)
rolled the BackrefExtensions into a single GenericBackrefExtension to handle
all combinations of list/nonlist properties (such as one-to-one)
tweak to properties.py which may receive "None" as "added_items()", in the case of a scalar property
instead of a list
PropHistory masquerades as a List on the setattr/append delattr/remove side to make one-to-one's automatically
work

lib/sqlalchemy/attributes.py
lib/sqlalchemy/mapping/properties.py
test/attributes.py

index 87b09507f9b8bea099d224c51d04b8269fb5bd57..f6abe6b31db2f775dc81f4864eb2953d664a94ec 100644 (file)
@@ -24,7 +24,7 @@ AttributeManager can also assign a "callable" history container to an object's a
 which is invoked when first accessed, to provide the object's "committed" value.  
 
 The package includes functions for managing "bi-directional" object relationships as well
-via the ListBackrefExtension, OTMBackrefExtension, and MTOBackrefExtension objects.
+via the GenericBackrefExtension object.
 """
 
 import sqlalchemy.util as util
@@ -80,15 +80,18 @@ class PropHistory(object):
         orig = self.obj.__dict__.get(self.key, None)
         if orig is value:
             return
-        self.orig = orig
+        if self.orig is PropHistory.NONE:
+            self.orig = orig
         self.obj.__dict__[self.key] = value
-        if self.extension is not None and self.orig is not value:
-            self.extension.set(self.obj, value, self.orig)
+        if self.extension is not None:
+            self.extension.set(self.obj, value, orig)
     def delattr(self):
-        self.orig = self.obj.__dict__.get(self.key, None)
+        orig = self.obj.__dict__.get(self.key, None)
+        if self.orig is PropHistory.NONE:
+            self.orig = orig
         self.obj.__dict__[self.key] = None
         if self.extension is not None:
-            self.extension.set(self.obj, None, self.orig)
+            self.extension.set(self.obj, None, orig)
     def append(self, obj):
         self.setattr(obj)
     def remove(self, obj):
@@ -223,25 +226,7 @@ class AttributeExtension(object):
     def set(self, obj, child, oldchild):
         pass
         
-class ListBackrefExtension(AttributeExtension):
-    def __init__(self, key):
-        self.key = key
-    def append(self, obj, child):
-        getattr(child, self.key).append(obj)
-    def delete(self, obj, child):
-        getattr(child, self.key).remove(obj)
-class OTMBackrefExtension(AttributeExtension):
-    def __init__(self, key):
-        self.key = key
-    def append(self, obj, child):
-        prop = child.__class__._attribute_manager.get_history(child, self.key)
-        prop.setattr(obj)
-#        prop.setattr(obj)
-    def delete(self, obj, child):
-        prop = child.__class__._attribute_manager.get_history(child, self.key)
-        prop.delattr()
-
-class MTOBackrefExtension(AttributeExtension):
+class GenericBackrefExtension(AttributeExtension):
     def __init__(self, key):
         self.key = key
     def set(self, obj, child, oldchild):
@@ -251,6 +236,13 @@ class MTOBackrefExtension(AttributeExtension):
         if child is not None:
             prop = child.__class__._attribute_manager.get_history(child, self.key)
             prop.append(obj)
+    def append(self, obj, child):
+        prop = child.__class__._attribute_manager.get_history(child, self.key)
+        prop.append(obj)
+    def delete(self, obj, child):
+        prop = child.__class__._attribute_manager.get_history(child, self.key)
+        prop.remove(obj)
+
             
 class AttributeManager(object):
     """maintains a set of per-attribute history container objects for a set of objects."""
index 5f15a8d2c68850021dd97cda91935adb6b0da109..9eb97aa0c1afe17fc08bdea8cdc3beedd569d798 100644 (file)
@@ -196,12 +196,7 @@ class PropertyLoader(MapperProperty):
             # if a backref name is defined, set up an extension to populate 
             # attributes in the other direction
             if self.backref is not None:
-                if self.direction == PropertyLoader.ONETOMANY:
-                    self.attributeext = attributes.OTMBackrefExtension(self.backref)
-                elif self.direction == PropertyLoader.MANYTOONE:
-                    self.attributeext = attributes.MTOBackrefExtension(self.backref)
-                else:
-                    self.attributeext = attributes.ListBackrefExtension(self.backref)
+                self.attributeext = attributes.GenericBackrefExtension(self.backref)
         
             # set our class attribute
             self._set_class_attribute(parent.class_, key)
@@ -521,10 +516,11 @@ class PropertyLoader(MapperProperty):
                 childlist = getlist(obj, passive=True)
                 if childlist is None: continue
                 for child in childlist.added_items():
-                    self._synchronize(obj, child, None, False)
-                    if self.direction == PropertyLoader.ONETOMANY:
-                        # for a cyclical task, this registration is handled by the objectstore
-                        uowcommit.register_object(child)
+                    if child is not None:
+                        self._synchronize(obj, child, None, False)
+                        if self.direction == PropertyLoader.ONETOMANY:
+                            # for a cyclical task, this registration is handled by the objectstore
+                            uowcommit.register_object(child)
                 if self.direction != PropertyLoader.MANYTOONE or len(childlist.added_items()) == 0:
                     for child in childlist.deleted_items():
                         if not self.private:
index 074a8d8391912de08d9d17c9c2f142718f980f5e..ca34bdfc7af05af24b90b4f0c3cc6e26bdc1919a 100644 (file)
@@ -81,8 +81,8 @@ class AttributesTest(PersistTest):
         class Student(object):pass
         class Course(object):pass
         manager = attributes.AttributeManager()
-        manager.register_attribute(Student, 'courses', uselist=True, extension=attributes.ListBackrefExtension('students'))
-        manager.register_attribute(Course, 'students', uselist=True, extension=attributes.ListBackrefExtension('courses'))
+        manager.register_attribute(Student, 'courses', uselist=True, extension=attributes.GenericBackrefExtension('students'))
+        manager.register_attribute(Course, 'students', uselist=True, extension=attributes.GenericBackrefExtension('courses'))
         
         s = Student()
         c = Course()
@@ -102,8 +102,8 @@ class AttributesTest(PersistTest):
         class Post(object):pass
         class Blog(object):pass
         
-        manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.MTOBackrefExtension('posts'))
-        manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.OTMBackrefExtension('blog'))
+        manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'))
+        manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'))
         b = Blog()
         (p1, p2, p3) = (Post(), Post(), Post())
         b.posts.append(p1)
@@ -118,8 +118,19 @@ class AttributesTest(PersistTest):
         p4.blog = b
         self.assert_(b.posts == [p1, p2, p4])
         
+
+        class Port(object):pass
+        class Jack(object):pass
+        manager.register_attribute(Port, 'jack', uselist=False, extension=attributes.GenericBackrefExtension('port'))
+        manager.register_attribute(Jack, 'port', uselist=False, extension=attributes.GenericBackrefExtension('jack'))
+        p = Port()
+        j = Jack()
+        p.jack = j
+        self.assert_(j.port is p)
+        self.assert_(p.jack is not None)
         
-        
+        j.port = None
+        self.assert_(p.jack is None)
         
 if __name__ == "__main__":
     unittest.main()