From: Mike Bayer Date: Tue, 6 Dec 2005 06:45:44 +0000 (+0000) Subject: working the backref attributes thing. many-to-many unittest works now... X-Git-Tag: rel_0_1_0~253 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=478b0e15ed70ae109e76b696efe151b7acac036b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git working the backref attributes thing. many-to-many unittest works now... --- diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index 9bdfcab6a1..72b0263878 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -44,17 +44,19 @@ class PropHistory(object): """manages the value of a particular scalar attribute on a particular object instance.""" # make our own NONE to distinguish from "None" NONE = object() - def __init__(self, obj, key, backrefmanager=None, **kwargs): + def __init__(self, obj, key, extension=None, **kwargs): self.obj = obj self.key = key self.orig = PropHistory.NONE - self.backrefmanager = backrefmanager + self.extension = extension def gethistory(self, *args, **kwargs): return self def history_contains(self, obj): return self.orig is obj or self.obj.__dict__[self.key] is obj def setattr_clean(self, value): self.obj.__dict__[self.key] = value + def delattr_clean(self): + del self.obj.__dict__[self.key] def getattr(self): return self.obj.__dict__[self.key] def setattr(self, value): @@ -62,13 +64,13 @@ class PropHistory(object): raise ("assigning a list to scalar property '%s' on '%s' instance %d" % (self.key, self.obj.__class__.__name__, id(self.obj))) self.orig = self.obj.__dict__.get(self.key, None) self.obj.__dict__[self.key] = value - if self.backrefmanager is not None and self.orig is not value: - self.backrefmanager.set(self.obj, value, self.orig) + if self.extension is not None and self.orig is not value: + self.extension.set(self.obj, value, self.orig) def delattr(self): self.orig = self.obj.__dict__.get(self.key, None) self.obj.__dict__[self.key] = None - if self.backrefmanager is not None: - self.backrefmanager.set(self.obj, None, self.orig) + if self.extension is not None: + self.extension.set(self.obj, None, self.orig) def rollback(self): if self.orig is not PropHistory.NONE: self.obj.__dict__[self.key] = self.orig @@ -93,10 +95,10 @@ class PropHistory(object): class ListElement(util.HistoryArraySet): """manages the value of a particular list-based attribute on a particular object instance.""" - def __init__(self, obj, key, data=None, backrefmanager=None, **kwargs): + def __init__(self, obj, key, data=None, extension=None, **kwargs): self.obj = obj self.key = key - self.backrefmanager = backrefmanager + self.extension = extension # if we are given a list, try to behave nicely with an existing # list that might be set on the object already try: @@ -126,15 +128,15 @@ class ListElement(util.HistoryArraySet): res = util.HistoryArraySet._setrecord(self, item) if res: self.list_value_changed(self.obj, self.key, item, self, False) - if self.backrefmanager is not None: - self.backrefmanager.append(self.obj, item) + if self.extension is not None: + self.extension.append(self.obj, item) return res def _delrecord(self, item): res = util.HistoryArraySet._delrecord(self, item) if res: self.list_value_changed(self.obj, self.key, item, self, True) - if self.backrefmanager is not None: - self.backrefmanager.delete(self.obj, item) + if self.extension is not None: + self.extension.delete(self.obj, item) return res class CallableProp(object): @@ -185,38 +187,42 @@ class CallableProp(object): def rollback(self): pass -class BackrefManager(object): - def __init__(self, key): - self.key = key - def append(self, parent, child): +class AttributeExtension(object): + def append(self, obj, child): pass - def delete(self, parent, child): + def delete(self, obj, child): pass - def set(self, parent, child, oldchild): + def set(self, obj, child, oldchild): pass +class ListBackrefExtension(AttributeExtension): + def __init__(self, key): + self.key = key + def append(self, obj, child): + getattr(child, self.key).append_nohistory(obj) + def delete(self, obj, child): + getattr(child, self.key).remove_nohistory(obj) -class ListBackrefManager(BackrefManager): - def append(self, parent, child): - getattr(child, self.key).append(parent) - def delete(self, parent, child): - getattr(child, self.key).remove(parent) - -class OneToManyBackrefManager(BackrefManager): - def append(self, parent, child): - setattr(child, self.key, parent) - def delete(self, parent, child): - setattr(child, self.key, None) +class OTMBackrefExtension(AttributeExtension): + def __init__(self, key): + self.key = key + def append(self, obj, child): + prop = child.__class__._attribute_manager.get_history(child, self.key) + prop.setattr_clean(obj) +# prop.setattr(obj) + def delete(self, obj, child): + prop = child.__class__._attribute_manager.get_history(child, self.key) + prop.delattr_clean() -class ManyToOneBackrefManager(BackrefManager): - def set(self, parent, child, oldchild): +class MTOBackrefExtension(AttributeExtension): + def __init__(self, key): + self.key = key + def set(self, obj, child, oldchild): if oldchild is not None: - try: - getattr(oldchild, self.key).remove(parent) - except: - print "wha? oldchild is ", repr(oldchild) + getattr(oldchild, self.key).remove_nohistory(obj) if child is not None: - getattr(child, self.key).append(parent) + getattr(child, self.key).append_nohistory(obj) +# getattr(child, self.key).append(obj) class AttributeManager(object): """maintains a set of per-attribute callable/history manager objects for a set of objects.""" @@ -305,6 +311,7 @@ class AttributeManager(object): except AttributeError: attr = {} class_._class_managed_attributes = attr + class_._attribute_manager = self return attr diff --git a/lib/sqlalchemy/mapping/__init__.py b/lib/sqlalchemy/mapping/__init__.py index ebc35c06b7..15624af24e 100644 --- a/lib/sqlalchemy/mapping/__init__.py +++ b/lib/sqlalchemy/mapping/__init__.py @@ -53,10 +53,15 @@ def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=Non def _relation_mapper(class_, table=None, secondary=None, primaryjoin=None, secondaryjoin=None, - foreignkey=None, uselist=None, private=False, live=False, association=None, lazy=True, selectalias=None, order_by=None, **kwargs): + foreignkey=None, uselist=None, private=False, + live=False, association=None, lazy=True, + selectalias=None, order_by=None, attributeext=None, **kwargs): - return _relation_loader(mapper(class_, table, **kwargs), secondary, primaryjoin, secondaryjoin, - foreignkey=foreignkey, uselist=uselist, private=private, live=live, association=association, lazy=lazy, selectalias=selectalias, order_by=order_by) + return _relation_loader(mapper(class_, table, **kwargs), + secondary, primaryjoin, secondaryjoin, + foreignkey=foreignkey, uselist=uselist, private=private, + live=live, association=association, lazy=lazy, + selectalias=selectalias, order_by=order_by, attributeext=attributeext) class assignmapper(object): """provides a property object that will instantiate a Mapper for a given class the first diff --git a/lib/sqlalchemy/mapping/objectstore.py b/lib/sqlalchemy/mapping/objectstore.py index 0efd54aea5..b86f7a325a 100644 --- a/lib/sqlalchemy/mapping/objectstore.py +++ b/lib/sqlalchemy/mapping/objectstore.py @@ -666,7 +666,14 @@ class UOWTask(object): return s def __repr__(self): - return ("UOWTask/%d Table: '%s'" % (id(self), self.mapper and self.mapper.primarytable.name or '(none)')) + if self.mapper is not None: + if self.mapper.__class__.__name__ == 'Mapper': + name = self.mapper.primarytable.name + else: + name = repr(self.mapper) + else: + name = '(none)' + return ("UOWTask/%d Table: '%s'" % (id(self), name)) class DependencySorter(topological.QueueDependencySorter): pass diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index 5c037f0c55..d1fd2315ed 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -62,7 +62,7 @@ class PropertyLoader(MapperProperty): """describes an object property that holds a single item or list of items that correspond to a related database table.""" - def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, selectalias=None, order_by=None, **kwargs): + def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, selectalias=None, order_by=None, attributeext=None): self.uselist = uselist self.argument = argument self.secondary = secondary @@ -75,6 +75,7 @@ class PropertyLoader(MapperProperty): self.association = association self.selectalias = selectalias self.order_by=util.to_list(order_by) + self.attributeext=attributeext self._hash_key = "%s(%s, %s, %s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist), repr(private), hash_key(self.order_by)) def _copy(self): @@ -137,7 +138,7 @@ class PropertyLoader(MapperProperty): def _set_class_attribute(self, class_, key): """sets attribute behavior on our target class.""" - objectstore.uow().register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private) + objectstore.uow().register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private, extension=self.attributeext) def _get_direction(self): if self.parent.primarytable is self.target: @@ -454,7 +455,7 @@ class LazyLoader(PropertyLoader): def _set_class_attribute(self, class_, key): # establish a class-level lazy loader on our class #print "SETCLASSATTR LAZY", repr(class_), key - objectstore.global_attributes.register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private, live=self.live, callable_=lambda i: self.setup_loader(i)) + objectstore.global_attributes.register_attribute(class_, key, uselist = self.uselist, deleteremoved = self.private, live=self.live, callable_=lambda i: self.setup_loader(i), extension=self.attributeext) def setup_loader(self, instance): def lazyload(): @@ -612,7 +613,7 @@ class EagerLoader(PropertyLoader): if not self.uselist: if isnew: - h.setattr(self._instance(row, imap)) + h.setattr_clean(self._instance(row, imap)) return elif isnew: result_list = h diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 443db2e3d5..c247c86edf 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -243,6 +243,10 @@ class HistoryArraySet(UserList.UserList): if not self.records.has_key(item): self.records[item] = None self.data.append(item) + def remove_nohistory(self, item): + if self.records.has_key(item): + del self.records[item] + self.data.remove(item) def has_item(self, item): return self.records.has_key(item) def __setitem__(self, i, item): diff --git a/test/attributes.py b/test/attributes.py index b546de533f..074a8d8391 100644 --- a/test/attributes.py +++ b/test/attributes.py @@ -81,8 +81,8 @@ class AttributesTest(PersistTest): class Student(object):pass class Course(object):pass manager = attributes.AttributeManager() - manager.register_attribute(Student, 'courses', uselist=True, backrefmanager=attributes.ListBackrefManager('students')) - manager.register_attribute(Course, 'students', uselist=True, backrefmanager=attributes.ListBackrefManager('courses')) + manager.register_attribute(Student, 'courses', uselist=True, extension=attributes.ListBackrefExtension('students')) + manager.register_attribute(Course, 'students', uselist=True, extension=attributes.ListBackrefExtension('courses')) s = Student() c = Course() @@ -102,8 +102,8 @@ class AttributesTest(PersistTest): class Post(object):pass class Blog(object):pass - manager.register_attribute(Post, 'blog', uselist=False, backrefmanager=attributes.ManyToOneBackrefManager('posts')) - manager.register_attribute(Blog, 'posts', uselist=True, backrefmanager=attributes.OneToManyBackrefManager('blog')) + manager.register_attribute(Post, 'blog', uselist=False, extension=attributes.MTOBackrefExtension('posts')) + manager.register_attribute(Blog, 'posts', uselist=True, extension=attributes.OTMBackrefExtension('blog')) b = Blog() (p1, p2, p3) = (Post(), Post(), Post()) b.posts.append(p1) diff --git a/test/manytomany.py b/test/manytomany.py index e4b599d6b5..cd38f38c3b 100644 --- a/test/manytomany.py +++ b/test/manytomany.py @@ -1,6 +1,7 @@ from sqlalchemy import * import testbase import string +import sqlalchemy.attributes as attr class Place(object): '''represents a place''' @@ -97,12 +98,12 @@ class ManyToManyTest(testbase.AssertMixin): "break off" a new "mapper stub" to indicate a third depedendent processor.""" Place.mapper = mapper(Place, place) Transition.mapper = mapper(Transition, transition, properties = dict( - inputs = relation(Place.mapper, place_output, lazy=True), - outputs = relation(Place.mapper, place_input, lazy=True), + inputs = relation(Place.mapper, place_output, lazy=True, attributeext=attr.ListBackrefExtension('inputs')), + outputs = relation(Place.mapper, place_input, lazy=True, attributeext=attr.ListBackrefExtension('outputs')), ) ) - Place.mapper.add_property('inputs', relation(Transition.mapper, place_output, lazy=True)) - Place.mapper.add_property('outputs', relation(Transition.mapper, place_input, lazy=True)) + Place.mapper.add_property('inputs', relation(Transition.mapper, place_output, lazy=True, attributeext=attr.ListBackrefExtension('inputs'))) + Place.mapper.add_property('outputs', relation(Transition.mapper, place_input, lazy=True, attributeext=attr.ListBackrefExtension('outputs'))) Place.eagermapper = Place.mapper.options( eagerload('inputs', selectalias='ip_alias'), @@ -125,10 +126,10 @@ class ManyToManyTest(testbase.AssertMixin): p1.outputs.append(t1) objectstore.commit() - - l = Place.eagermapper.select() - print repr(l) + self.assert_result([t1], Transition, {'outputs': (Place, [{'name':'place3'}, {'name':'place1'}])}) + self.assert_result([p2], Place, {'inputs': (Transition, [{'name':'transition1'},{'name':'transition2'}])}) + if __name__ == "__main__": testbase.main() diff --git a/test/objectstore.py b/test/objectstore.py index 0942e41eec..56da630c37 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -19,6 +19,9 @@ class HistoryTest(AssertMixin): addresses.drop() users.drop() db.echo = testbase.echo + def setUp(self): + objectstore.clear() + clear_mappers() def testattr(self): """tests the rolling back of scalar and list attributes. this kind of thing @@ -49,6 +52,24 @@ class HistoryTest(AssertMixin): ] self.assert_result([u], data[0], *data[1:]) + def testbackref(self): + class User(object):pass + class Address(object):pass + am = mapper(Address, addresses) + m = mapper(User, users, properties = dict( + addresses = relation(am, attributeext=attributes.OTMBackrefExtension('user'))) + ) + am.add_property('user', relation(m, attributeext=attributes.MTOBackrefExtension('addresses'))) + + u = User() + a = Address() + a.user = u + #print repr(a.__class__._attribute_manager.get_history(a, 'user').added_items()) + #print repr(u.addresses.added_items()) + self.assert_(u.addresses == [a]) + objectstore.commit() + + class PKTest(AssertMixin): def setUpAll(self): db.echo = False