From 69ad2955bdb33eb45939a01d95bcff240a2d9fb6 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 9 Dec 2005 05:08:51 +0000 Subject: [PATCH] build in 'backref' property argument --- lib/sqlalchemy/attributes.py | 17 +++++----- lib/sqlalchemy/mapping/properties.py | 48 ++++++++++++++++++++++------ lib/sqlalchemy/util.py | 38 +++++++++++++++++++++- test/manytomany.py | 8 ++--- test/objectstore.py | 4 +-- 5 files changed, 90 insertions(+), 25 deletions(-) diff --git a/lib/sqlalchemy/attributes.py b/lib/sqlalchemy/attributes.py index 72b0263878..c42a54be59 100644 --- a/lib/sqlalchemy/attributes.py +++ b/lib/sqlalchemy/attributes.py @@ -123,7 +123,7 @@ class ListElement(util.HistoryArraySet): self.obj.__dict__[self.key] = value self.set_data(value) def delattr(self, value): - pass + pass def _setrecord(self, item): res = util.HistoryArraySet._setrecord(self, item) if res: @@ -194,34 +194,33 @@ class AttributeExtension(object): pass def set(self, obj, child, oldchild): pass - + class ListBackrefExtension(AttributeExtension): def __init__(self, key): self.key = key def append(self, obj, child): - getattr(child, self.key).append_nohistory(obj) + getattr(child, self.key).append(obj) def delete(self, obj, child): - getattr(child, self.key).remove_nohistory(obj) - + getattr(child, self.key).remove(obj) class OTMBackrefExtension(AttributeExtension): def __init__(self, key): self.key = key def append(self, obj, child): prop = child.__class__._attribute_manager.get_history(child, self.key) - prop.setattr_clean(obj) + prop.setattr(obj) # prop.setattr(obj) def delete(self, obj, child): prop = child.__class__._attribute_manager.get_history(child, self.key) - prop.delattr_clean() + prop.delattr() class MTOBackrefExtension(AttributeExtension): def __init__(self, key): self.key = key def set(self, obj, child, oldchild): if oldchild is not None: - getattr(oldchild, self.key).remove_nohistory(obj) + getattr(oldchild, self.key).remove(obj) if child is not None: - getattr(child, self.key).append_nohistory(obj) + getattr(child, self.key).append(obj) # getattr(child, self.key).append(obj) class AttributeManager(object): diff --git a/lib/sqlalchemy/mapping/properties.py b/lib/sqlalchemy/mapping/properties.py index 4b82157d33..08897f2e61 100644 --- a/lib/sqlalchemy/mapping/properties.py +++ b/lib/sqlalchemy/mapping/properties.py @@ -21,6 +21,7 @@ import sqlalchemy.sql as sql import sqlalchemy.schema as schema import sqlalchemy.engine as engine import sqlalchemy.util as util +import sqlalchemy.attributes as attributes import mapper import objectstore import random @@ -56,13 +57,13 @@ class ColumnProperty(MapperProperty): mapper.ColumnProperty = ColumnProperty class PropertyLoader(MapperProperty): - LEFT = 0 - RIGHT = 1 - CENTER = 2 + LEFT = 0 # one-to-many + RIGHT = 1 # many-to-one + CENTER = 2 # many-to-many """describes an object property that holds a single item or list of items that correspond to a related database table.""" - def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, selectalias=None, order_by=None, attributeext=None): + def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, selectalias=None, order_by=None, attributeext=None, backref=None, is_backref=False): self.uselist = uselist self.argument = argument self.secondary = secondary @@ -76,6 +77,8 @@ class PropertyLoader(MapperProperty): self.selectalias = selectalias self.order_by=util.to_list(order_by) self.attributeext=attributeext + self.backref = backref + self.is_backref = is_backref self._hash_key = "%s(%s, %s, %s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist), repr(private), hash_key(self.order_by)) def _copy(self): @@ -126,10 +129,31 @@ class PropertyLoader(MapperProperty): self.uselist = True self._compile_synchronizers() - + + # primary property handler, set up class attributes if self._is_primary(): + # if a backref name is defined, set up an extension to populate + # attributes in the other direction + if self.backref is not None: + if self.direction == PropertyLoader.LEFT: + self.attributeext = attributes.OTMBackrefExtension(self.backref) + elif self.direction == PropertyLoader.RIGHT: + self.attributeext = attributes.MTOBackrefExtension(self.backref) + else: + self.attributeext = attributes.ListBackrefExtension(self.backref) + + # set our class attribute self._set_class_attribute(parent.class_, key) - + + if self.backref is not None: + # try to set a LazyLoader on our mapper referencing the parent mapper + if not self.mapper.props.has_key(self.backref): + self.mapper.add_property(self.backref, LazyLoader(self.parent, self.secondary, self.primaryjoin, self.secondaryjoin, backref=self.key, is_backref=True)); + else: + # else set one of us as the "backreference" + if not self.mapper.props[self.backref].is_backref: + self.is_backref=True + def _is_primary(self): """a return value of True indicates we are the primary PropertyLoader for this loader's attribute on our mapper's class. It means we can set the object's attribute behavior @@ -277,6 +301,11 @@ class PropertyLoader(MapperProperty): # or delete any objects, but just marks a dependency on the two # related mappers. its dependency processor then populates the # association table. + + if self.is_backref: + # if we are the "backref" half of a two-way backref + # relationship, let the other mapper handle inserting the rows + return stub = PropertyLoader.MapperStub() uowcommit.register_dependency(self.parent, stub) uowcommit.register_dependency(self.mapper, stub) @@ -674,10 +703,11 @@ class EagerLazyOption(MapperOption): else: class_ = LazyLoader - for arg in ('primaryjoin', 'secondaryjoin', 'foreignkey', 'uselist', 'private', 'live', 'isoption', 'association', 'selectalias', 'order_by', 'attributeext'): - self.kwargs.setdefault(arg, getattr(oldprop, arg)) + # create a clone of the class using mostly the arguments from the original self.kwargs['isoption'] = True - mapper.set_property(key, class_(submapper, oldprop.secondary, **self.kwargs )) + self.kwargs['argument'] = submapper + kwargs = util.constructor_args(oldprop, **self.kwargs) + mapper.set_property(key, class_(**kwargs )) class Aliasizer(sql.ClauseVisitor): """converts a table instance within an expression to be an alias of that table.""" diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index d280368f69..bc23d6080a 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -16,7 +16,7 @@ # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. __all__ = ['OrderedProperties', 'OrderedDict'] -import thread, weakref, UserList,string +import thread, weakref, UserList,string, inspect def to_list(x): if x is None: @@ -345,3 +345,39 @@ class ScopedRegistry(object): def _clear_application(self): self.application = createfunc() + + +def constructor_args(instance, **kwargs): + classobj = instance.__class__ + + argspec = inspect.getargspec(classobj.__init__.im_func) + + argnames = argspec[0] or [] + defaultvalues = argspec[3] or [] + + (requiredargs, namedargs) = ( + argnames[0:len(argnames) - len(defaultvalues)], + argnames[len(argnames) - len(defaultvalues):] + ) + + newparams = {} + + for arg in requiredargs: + if arg == 'self': + continue + elif kwargs.has_key(arg): + newparams[arg] = kwargs[arg] + else: + newparams[arg] = getattr(instance, arg) + + for arg in namedargs: + if kwargs.has_key(arg): + newparams[arg] = kwargs[arg] + else: + if hasattr(instance, arg): + newparams[arg] = getattr(instance, arg) + else: + raise "instance has no attribute '%s'" % arg + + return newparams + \ No newline at end of file diff --git a/test/manytomany.py b/test/manytomany.py index cd38f38c3b..f4498dd8f4 100644 --- a/test/manytomany.py +++ b/test/manytomany.py @@ -98,12 +98,12 @@ class ManyToManyTest(testbase.AssertMixin): "break off" a new "mapper stub" to indicate a third depedendent processor.""" Place.mapper = mapper(Place, place) Transition.mapper = mapper(Transition, transition, properties = dict( - inputs = relation(Place.mapper, place_output, lazy=True, attributeext=attr.ListBackrefExtension('inputs')), - outputs = relation(Place.mapper, place_input, lazy=True, attributeext=attr.ListBackrefExtension('outputs')), + inputs = relation(Place.mapper, place_output, lazy=True, backref='inputs'), + outputs = relation(Place.mapper, place_input, lazy=True, backref='outputs'), ) ) - Place.mapper.add_property('inputs', relation(Transition.mapper, place_output, lazy=True, attributeext=attr.ListBackrefExtension('inputs'))) - Place.mapper.add_property('outputs', relation(Transition.mapper, place_input, lazy=True, attributeext=attr.ListBackrefExtension('outputs'))) + #Place.mapper.add_property('inputs', relation(Transition.mapper, place_output, lazy=True, attributeext=attr.ListBackrefExtension('inputs'))) + #Place.mapper.add_property('outputs', relation(Transition.mapper, place_input, lazy=True, attributeext=attr.ListBackrefExtension('outputs'))) Place.eagermapper = Place.mapper.options( eagerload('inputs', selectalias='ip_alias'), diff --git a/test/objectstore.py b/test/objectstore.py index 56da630c37..979a4ceefd 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -57,9 +57,8 @@ class HistoryTest(AssertMixin): class Address(object):pass am = mapper(Address, addresses) m = mapper(User, users, properties = dict( - addresses = relation(am, attributeext=attributes.OTMBackrefExtension('user'))) + addresses = relation(am, backref='user')) ) - am.add_property('user', relation(m, attributeext=attributes.MTOBackrefExtension('addresses'))) u = User() a = Address() @@ -68,6 +67,7 @@ class HistoryTest(AssertMixin): #print repr(u.addresses.added_items()) self.assert_(u.addresses == [a]) objectstore.commit() + class PKTest(AssertMixin): -- 2.47.2