]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
working the backref attributes thing. many-to-many unittest works now...
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Dec 2005 06:45:44 +0000 (06:45 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 6 Dec 2005 06:45:44 +0000 (06:45 +0000)
lib/sqlalchemy/attributes.py
lib/sqlalchemy/mapping/__init__.py
lib/sqlalchemy/mapping/objectstore.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/util.py
test/attributes.py
test/manytomany.py
test/objectstore.py

index 9bdfcab6a1ec43358ccb672629c390c24a42dcda..72b02638784498dfa8b2be7803024fd5b1daa280 100644 (file)
@@ -44,17 +44,19 @@ class PropHistory(object):
     """manages the value of a particular scalar attribute on a particular object instance."""
     # make our own NONE to distinguish from "None"
     NONE = object()
-    def __init__(self, obj, key, backrefmanager=None, **kwargs):
+    def __init__(self, obj, key, extension=None, **kwargs):
         self.obj = obj
         self.key = key
         self.orig = PropHistory.NONE
-        self.backrefmanager = backrefmanager
+        self.extension = extension
     def gethistory(self, *args, **kwargs):
         return self
     def history_contains(self, obj):
         return self.orig is obj or self.obj.__dict__[self.key] is obj
     def setattr_clean(self, value):
         self.obj.__dict__[self.key] = value
+    def delattr_clean(self):
+        del self.obj.__dict__[self.key]
     def getattr(self):
         return self.obj.__dict__[self.key]
     def setattr(self, value):
@@ -62,13 +64,13 @@ class PropHistory(object):
             raise ("assigning a list to scalar property '%s' on '%s' instance %d" % (self.key, self.obj.__class__.__name__, id(self.obj)))
         self.orig = self.obj.__dict__.get(self.key, None)
         self.obj.__dict__[self.key] = value
-        if self.backrefmanager is not None and self.orig is not value:
-            self.backrefmanager.set(self.obj, value, self.orig)
+        if self.extension is not None and self.orig is not value:
+            self.extension.set(self.obj, value, self.orig)
     def delattr(self):
         self.orig = self.obj.__dict__.get(self.key, None)
         self.obj.__dict__[self.key] = None
-        if self.backrefmanager is not None:
-            self.backrefmanager.set(self.obj, None, self.orig)
+        if self.extension is not None:
+            self.extension.set(self.obj, None, self.orig)
     def rollback(self):
         if self.orig is not PropHistory.NONE:
             self.obj.__dict__[self.key] = self.orig
@@ -93,10 +95,10 @@ class PropHistory(object):
 
 class ListElement(util.HistoryArraySet):
     """manages the value of a particular list-based attribute on a particular object instance."""
-    def __init__(self, obj, key, data=None, backrefmanager=None, **kwargs):
+    def __init__(self, obj, key, data=None, extension=None, **kwargs):
         self.obj = obj
         self.key = key
-        self.backrefmanager = backrefmanager
+        self.extension = extension
         # if we are given a list, try to behave nicely with an existing
         # list that might be set on the object already
         try:
@@ -126,15 +128,15 @@ class ListElement(util.HistoryArraySet):
         res = util.HistoryArraySet._setrecord(self, item)
         if res:
             self.list_value_changed(self.obj, self.key, item, self, False)
-            if self.backrefmanager is not None:
-                self.backrefmanager.append(self.obj, item)
+            if self.extension is not None:
+                self.extension.append(self.obj, item)
         return res
     def _delrecord(self, item):
         res = util.HistoryArraySet._delrecord(self, item)
         if res:
             self.list_value_changed(self.obj, self.key, item, self, True)
-            if self.backrefmanager is not None:
-                self.backrefmanager.delete(self.obj, item)
+            if self.extension is not None:
+                self.extension.delete(self.obj, item)
         return res
 
 class CallableProp(object):
@@ -185,38 +187,42 @@ class CallableProp(object):
     def rollback(self):
         pass
 
-class BackrefManager(object):
-    def __init__(self, key):
-        self.key = key
-    def append(self, parent, child):
+class AttributeExtension(object):
+    def append(self, obj, child):
         pass
-    def delete(self, parent, child):
+    def delete(self, obj, child):
         pass
-    def set(self, parent, child, oldchild):
+    def set(self, obj, child, oldchild):
         pass
 
+class ListBackrefExtension(AttributeExtension):
+    def __init__(self, key):
+        self.key = key
+    def append(self, obj, child):
+        getattr(child, self.key).append_nohistory(obj)
+    def delete(self, obj, child):
+        getattr(child, self.key).remove_nohistory(obj)
 
-class ListBackrefManager(BackrefManager):
-    def append(self, parent, child):
-        getattr(child, self.key).append(parent)
-    def delete(self, parent, child):
-        getattr(child, self.key).remove(parent)
-
-class OneToManyBackrefManager(BackrefManager):
-    def append(self, parent, child):
-        setattr(child, self.key, parent)
-    def delete(self, parent, child):
-        setattr(child, self.key, None)
+class OTMBackrefExtension(AttributeExtension):
+    def __init__(self, key):
+        self.key = key
+    def append(self, obj, child):
+        prop = child.__class__._attribute_manager.get_history(child, self.key)
+        prop.setattr_clean(obj)
+#        prop.setattr(obj)
+    def delete(self, obj, child):
+        prop = child.__class__._attribute_manager.get_history(child, self.key)
+        prop.delattr_clean()
 
-class ManyToOneBackrefManager(BackrefManager):
-    def set(self, parent, child, oldchild):
+class MTOBackrefExtension(AttributeExtension):
+    def __init__(self, key):
+        self.key = key
+    def set(self, obj, child, oldchild):
         if oldchild is not None:
-            try:
-                getattr(oldchild, self.key).remove(parent)
-            except:
-                print "wha? oldchild is ", repr(oldchild)
+            getattr(oldchild, self.key).remove_nohistory(obj)
         if child is not None:
-            getattr(child, self.key).append(parent)
+            getattr(child, self.key).append_nohistory(obj)
+#            getattr(child, self.key).append(obj)
             
 class AttributeManager(object):
     """maintains a set of per-attribute callable/history manager objects for a set of objects."""
@@ -305,6 +311,7 @@ class AttributeManager(object):
         except AttributeError:
             attr = {}
             class_._class_managed_attributes = attr
+            class_._attribute_manager = self
         return attr
 
 
index ebc35c06b7edc67685663488d77a818d81465ac6..15624af24eea1d85b85a920838488c1a31a21e55 100644 (file)
@@ -53,10 +53,15 @@ def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=Non
 
 def _relation_mapper(class_, table=None, secondary=None, 
                     primaryjoin=None, secondaryjoin=None, 
-                    foreignkey=None, uselist=None, private=False, live=False, association=None, lazy=True, selectalias=None, order_by=None, **kwargs):
+                    foreignkey=None, uselist=None, private=False, 
+                    live=False, association=None, lazy=True, 
+                    selectalias=None, order_by=None, attributeext=None, **kwargs):
 
-    return _relation_loader(mapper(class_, table, **kwargs), secondary, primaryjoin, secondaryjoin, 
-                    foreignkey=foreignkey, uselist=uselist, private=private, live=live, association=association, lazy=lazy, selectalias=selectalias, order_by=order_by)
+    return _relation_loader(mapper(class_, table, **kwargs), 
+                    secondary, primaryjoin, secondaryjoin, 
+                    foreignkey=foreignkey, uselist=uselist, private=private, 
+                    live=live, association=association, lazy=lazy, 
+                    selectalias=selectalias, order_by=order_by, attributeext=attributeext)
 
 class assignmapper(object):
     """provides a property object that will instantiate a Mapper for a given class the first
index 0efd54aea55cae8e09f807bfc6679c9e41f84f7f..b86f7a325ac29e4e87d3dc32d4a432f8eb0d8328 100644 (file)
@@ -666,7 +666,14 @@ class UOWTask(object):
         return s
         
     def __repr__(self):
-        return ("UOWTask/%d Table: '%s'" % (id(self), self.mapper and self.mapper.primarytable.name or '(none)'))
+        if self.mapper is not None:
+            if self.mapper.__class__.__name__ == 'Mapper':
+                name = self.mapper.primarytable.name
+            else:
+                name = repr(self.mapper)
+        else:
+            name = '(none)'
+        return ("UOWTask/%d Table: '%s'" % (id(self), name))
 
 class DependencySorter(topological.QueueDependencySorter):
     pass
index 5c037f0c552c1c6aaf8824c6c94f375609fd474b..d1fd2315ed787235799aa91f62fd487a1e66815f 100644 (file)
@@ -62,7 +62,7 @@ class PropertyLoader(MapperProperty):
 
     """describes an object property that holds a single item or list of items that correspond
     to a related database table."""
-    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, selectalias=None, order_by=None, **kwargs):
+    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, selectalias=None, order_by=None, attributeext=None):
         self.uselist = uselist
         self.argument = argument
         self.secondary = secondary
@@ -75,6 +75,7 @@ class PropertyLoader(MapperProperty):
         self.association = association
         self.selectalias = selectalias
         self.order_by=util.to_list(order_by)
+        self.attributeext=attributeext
         self._hash_key = "%s(%s, %s, %s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist), repr(private), hash_key(self.order_by))
 
     def _copy(self):
@@ -137,7 +138,7 @@ class PropertyLoader(MapperProperty):
         
     def _set_class_attribute(self, class_, key):
         """sets attribute behavior on our target class."""
-        objectstore.uow().register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private)
+        objectstore.uow().register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private, extension=self.attributeext)
         
     def _get_direction(self):
         if self.parent.primarytable is self.target:
@@ -454,7 +455,7 @@ class LazyLoader(PropertyLoader):
     def _set_class_attribute(self, class_, key):
         # establish a class-level lazy loader on our class
         #print "SETCLASSATTR LAZY", repr(class_), key
-        objectstore.global_attributes.register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private, live=self.live, callable_=lambda i: self.setup_loader(i))
+        objectstore.global_attributes.register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private, live=self.live, callable_=lambda i: self.setup_loader(i), extension=self.attributeext)
 
     def setup_loader(self, instance):
         def lazyload():
@@ -612,7 +613,7 @@ class EagerLoader(PropertyLoader):
             
         if not self.uselist:
             if isnew:
-                h.setattr(self._instance(row, imap))
+                h.setattr_clean(self._instance(row, imap))
             return
         elif isnew:
             result_list = h
index 443db2e3d5f87cb8d181b41a5653dc097ad69fb8..c247c86edfbd46060f52b55fb8fc2c7230f9279a 100644 (file)
@@ -243,6 +243,10 @@ class HistoryArraySet(UserList.UserList):
         if not self.records.has_key(item):
             self.records[item] = None
             self.data.append(item)
+    def remove_nohistory(self, item):
+        if self.records.has_key(item):
+            del self.records[item]
+            self.data.remove(item)
     def has_item(self, item):
         return self.records.has_key(item)
     def __setitem__(self, i, item): 
index b546de533fddd3e93dae27adb6d9cb5613be182d..074a8d8391912de08d9d17c9c2f142718f980f5e 100644 (file)
@@ -81,8 +81,8 @@ class AttributesTest(PersistTest):
         class Student(object):pass
         class Course(object):pass
         manager = attributes.AttributeManager()
-        manager.register_attribute(Student, 'courses', uselist=True, backrefmanager=attributes.ListBackrefManager('students'))
-        manager.register_attribute(Course, 'students', uselist=True, backrefmanager=attributes.ListBackrefManager('courses'))
+        manager.register_attribute(Student, 'courses', uselist=True, extension=attributes.ListBackrefExtension('students'))
+        manager.register_attribute(Course, 'students', uselist=True, extension=attributes.ListBackrefExtension('courses'))
         
         s = Student()
         c = Course()
@@ -102,8 +102,8 @@ class AttributesTest(PersistTest):
         class Post(object):pass
         class Blog(object):pass
         
-        manager.register_attribute(Post, 'blog', uselist=False, backrefmanager=attributes.ManyToOneBackrefManager('posts'))
-        manager.register_attribute(Blog, 'posts', uselist=True, backrefmanager=attributes.OneToManyBackrefManager('blog'))
+        manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.MTOBackrefExtension('posts'))
+        manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.OTMBackrefExtension('blog'))
         b = Blog()
         (p1, p2, p3) = (Post(), Post(), Post())
         b.posts.append(p1)
index e4b599d6b51e075f125bdf36fb227b557e8e3884..cd38f38c3beab3055bd53980d53dda239f88ccd5 100644 (file)
@@ -1,6 +1,7 @@
 from sqlalchemy import *
 import testbase
 import string
+import sqlalchemy.attributes as attr
 
 class Place(object):
     '''represents a place'''
@@ -97,12 +98,12 @@ class ManyToManyTest(testbase.AssertMixin):
         "break off" a new "mapper stub" to indicate a third depedendent processor."""
         Place.mapper = mapper(Place, place)
         Transition.mapper = mapper(Transition, transition, properties = dict(
-            inputs = relation(Place.mapper, place_output, lazy=True),
-            outputs = relation(Place.mapper, place_input, lazy=True),
+            inputs = relation(Place.mapper, place_output, lazy=True, attributeext=attr.ListBackrefExtension('inputs')),
+            outputs = relation(Place.mapper, place_input, lazy=True, attributeext=attr.ListBackrefExtension('outputs')),
             )
         )
-        Place.mapper.add_property('inputs', relation(Transition.mapper, place_output, lazy=True))
-        Place.mapper.add_property('outputs', relation(Transition.mapper, place_input, lazy=True))
+        Place.mapper.add_property('inputs', relation(Transition.mapper, place_output, lazy=True, attributeext=attr.ListBackrefExtension('inputs')))
+        Place.mapper.add_property('outputs', relation(Transition.mapper, place_input, lazy=True, attributeext=attr.ListBackrefExtension('outputs')))
 
         Place.eagermapper = Place.mapper.options(
             eagerload('inputs', selectalias='ip_alias'), 
@@ -125,10 +126,10 @@ class ManyToManyTest(testbase.AssertMixin):
         p1.outputs.append(t1)
         
         objectstore.commit()
-
         
-        l = Place.eagermapper.select()
-        print repr(l)
+        self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])})
+        self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])})
+        
 
 if __name__ == "__main__":    
     testbase.main()
index 0942e41eec8d83b6f3babd682638233e6f7b1728..56da630c37273984983bf757f193c8d307ea9640 100644 (file)
@@ -19,6 +19,9 @@ class HistoryTest(AssertMixin):
         addresses.drop()
         users.drop()
         db.echo = testbase.echo
+    def setUp(self):
+        objectstore.clear()
+        clear_mappers()
 
     def testattr(self):
         """tests the rolling back of scalar and list attributes.  this kind of thing
@@ -49,6 +52,24 @@ class HistoryTest(AssertMixin):
         ]
         self.assert_result([u], data[0], *data[1:])
 
+    def testbackref(self):
+        class User(object):pass
+        class Address(object):pass
+        am = mapper(Address, addresses)
+        m = mapper(User, users, properties = dict(
+            addresses = relation(am, attributeext=attributes.OTMBackrefExtension('user')))
+        )
+        am.add_property('user', relation(m, attributeext=attributes.MTOBackrefExtension('addresses')))
+        
+        u = User()
+        a = Address()
+        a.user = u
+        #print repr(a.__class__._attribute_manager.get_history(a, 'user').added_items())
+        #print repr(u.addresses.added_items())
+        self.assert_(u.addresses == [a])
+        objectstore.commit()
+        
+
 class PKTest(AssertMixin):
     def setUpAll(self):
         db.echo = False