]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
build in 'backref' property argument
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 9 Dec 2005 05:08:51 +0000 (05:08 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 9 Dec 2005 05:08:51 +0000 (05:08 +0000)
lib/sqlalchemy/attributes.py
lib/sqlalchemy/mapping/properties.py
lib/sqlalchemy/util.py
test/manytomany.py
test/objectstore.py

index 72b02638784498dfa8b2be7803024fd5b1daa280..c42a54be59c0e678551a3449acdff1010f108e50 100644 (file)
@@ -123,7 +123,7 @@ class ListElement(util.HistoryArraySet):
         self.obj.__dict__[self.key] = value
         self.set_data(value)
     def delattr(self, value):
-        pass    
+        pass
     def _setrecord(self, item):
         res = util.HistoryArraySet._setrecord(self, item)
         if res:
@@ -194,34 +194,33 @@ class AttributeExtension(object):
         pass
     def set(self, obj, child, oldchild):
         pass
-
+        
 class ListBackrefExtension(AttributeExtension):
     def __init__(self, key):
         self.key = key
     def append(self, obj, child):
-        getattr(child, self.key).append_nohistory(obj)
+        getattr(child, self.key).append(obj)
     def delete(self, obj, child):
-        getattr(child, self.key).remove_nohistory(obj)
-
+        getattr(child, self.key).remove(obj)
 class OTMBackrefExtension(AttributeExtension):
     def __init__(self, key):
         self.key = key
     def append(self, obj, child):
         prop = child.__class__._attribute_manager.get_history(child, self.key)
-        prop.setattr_clean(obj)
+        prop.setattr(obj)
 #        prop.setattr(obj)
     def delete(self, obj, child):
         prop = child.__class__._attribute_manager.get_history(child, self.key)
-        prop.delattr_clean()
+        prop.delattr()
 
 class MTOBackrefExtension(AttributeExtension):
     def __init__(self, key):
         self.key = key
     def set(self, obj, child, oldchild):
         if oldchild is not None:
-            getattr(oldchild, self.key).remove_nohistory(obj)
+            getattr(oldchild, self.key).remove(obj)
         if child is not None:
-            getattr(child, self.key).append_nohistory(obj)
+            getattr(child, self.key).append(obj)
 #            getattr(child, self.key).append(obj)
             
 class AttributeManager(object):
index 4b82157d3309efce77f6154ec24e617a611eb631..08897f2e61b1bb059aca0e20bc1d23eafffd947e 100644 (file)
@@ -21,6 +21,7 @@ import sqlalchemy.sql as sql
 import sqlalchemy.schema as schema
 import sqlalchemy.engine as engine
 import sqlalchemy.util as util
+import sqlalchemy.attributes as attributes
 import mapper
 import objectstore
 import random
@@ -56,13 +57,13 @@ class ColumnProperty(MapperProperty):
 mapper.ColumnProperty = ColumnProperty
 
 class PropertyLoader(MapperProperty):
-    LEFT = 0
-    RIGHT = 1
-    CENTER = 2
+    LEFT = 0  # one-to-many
+    RIGHT = 1  # many-to-one
+    CENTER = 2  # many-to-many
 
     """describes an object property that holds a single item or list of items that correspond
     to a related database table."""
-    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, selectalias=None, order_by=None, attributeext=None):
+    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, selectalias=None, order_by=None, attributeext=None, backref=None, is_backref=False):
         self.uselist = uselist
         self.argument = argument
         self.secondary = secondary
@@ -76,6 +77,8 @@ class PropertyLoader(MapperProperty):
         self.selectalias = selectalias
         self.order_by=util.to_list(order_by)
         self.attributeext=attributeext
+        self.backref = backref
+        self.is_backref = is_backref
         self._hash_key = "%s(%s, %s, %s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist), repr(private), hash_key(self.order_by))
 
     def _copy(self):
@@ -126,10 +129,31 @@ class PropertyLoader(MapperProperty):
             self.uselist = True
 
         self._compile_synchronizers()
-                
+
+        # primary property handler, set up class attributes
         if self._is_primary():
+            # if a backref name is defined, set up an extension to populate 
+            # attributes in the other direction
+            if self.backref is not None:
+                if self.direction == PropertyLoader.LEFT:
+                    self.attributeext = attributes.OTMBackrefExtension(self.backref)
+                elif self.direction == PropertyLoader.RIGHT:
+                    self.attributeext = attributes.MTOBackrefExtension(self.backref)
+                else:
+                    self.attributeext = attributes.ListBackrefExtension(self.backref)
+        
+            # set our class attribute
             self._set_class_attribute(parent.class_, key)
-    
+
+            if self.backref is not None:
+                # try to set a LazyLoader on our mapper referencing the parent mapper
+                if not self.mapper.props.has_key(self.backref):
+                    self.mapper.add_property(self.backref, LazyLoader(self.parent, self.secondary, self.primaryjoin, self.secondaryjoin, backref=self.key, is_backref=True));
+                else:
+                    # else set one of us as the "backreference"
+                    if not self.mapper.props[self.backref].is_backref:
+                        self.is_backref=True
+                    
     def _is_primary(self):
         """a return value of True indicates we are the primary PropertyLoader for this loader's
         attribute on our mapper's class.  It means we can set the object's attribute behavior
@@ -277,6 +301,11 @@ class PropertyLoader(MapperProperty):
             # or delete any objects, but just marks a dependency on the two
             # related mappers.  its dependency processor then populates the
             # association table.
+            
+            if self.is_backref:
+                # if we are the "backref" half of a two-way backref 
+                # relationship, let the other mapper handle inserting the rows
+                return
             stub = PropertyLoader.MapperStub()
             uowcommit.register_dependency(self.parent, stub)
             uowcommit.register_dependency(self.mapper, stub)
@@ -674,10 +703,11 @@ class EagerLazyOption(MapperOption):
         else:
             class_ = LazyLoader
 
-        for arg in ('primaryjoin', 'secondaryjoin', 'foreignkey', 'uselist', 'private', 'live', 'isoption', 'association', 'selectalias', 'order_by', 'attributeext'):
-            self.kwargs.setdefault(arg, getattr(oldprop, arg))
+        # create a clone of the class using mostly the arguments from the original
         self.kwargs['isoption'] = True
-        mapper.set_property(key, class_(submapper, oldprop.secondary, **self.kwargs ))
+        self.kwargs['argument'] = submapper
+        kwargs = util.constructor_args(oldprop, **self.kwargs)
+        mapper.set_property(key, class_(**kwargs ))
 
 class Aliasizer(sql.ClauseVisitor):
     """converts a table instance within an expression to be an alias of that table."""
index d280368f69101ad3d028d181d44d29c0237a36bc..bc23d6080ae9d1ac159776aec5fcb5d75990e1e8 100644 (file)
@@ -16,7 +16,7 @@
 # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 
 __all__ = ['OrderedProperties', 'OrderedDict']
-import thread, weakref, UserList,string
+import thread, weakref, UserList,string, inspect
 
 def to_list(x):
     if x is None:
@@ -345,3 +345,39 @@ class ScopedRegistry(object):
     def _clear_application(self):
         self.application = createfunc()
                 
+
+
+def constructor_args(instance, **kwargs):
+    classobj = instance.__class__
+        
+    argspec = inspect.getargspec(classobj.__init__.im_func)
+
+    argnames = argspec[0] or []
+    defaultvalues = argspec[3] or []
+
+    (requiredargs, namedargs) = (
+            argnames[0:len(argnames) - len(defaultvalues)], 
+            argnames[len(argnames) - len(defaultvalues):]
+            )
+
+    newparams = {}
+
+    for arg in requiredargs:
+        if arg == 'self': 
+            continue
+        elif kwargs.has_key(arg):
+            newparams[arg] = kwargs[arg]
+        else:
+            newparams[arg] = getattr(instance, arg)
+
+    for arg in namedargs:
+        if kwargs.has_key(arg):
+            newparams[arg] = kwargs[arg]
+        else:
+            if hasattr(instance, arg):
+                newparams[arg] = getattr(instance, arg)
+            else:
+                raise "instance has no attribute '%s'" % arg
+
+    return newparams
+    
\ No newline at end of file
index cd38f38c3beab3055bd53980d53dda239f88ccd5..f4498dd8f46500b0d95cc5bfc6491f8fae6eabc2 100644 (file)
@@ -98,12 +98,12 @@ class ManyToManyTest(testbase.AssertMixin):
         "break off" a new "mapper stub" to indicate a third depedendent processor."""
         Place.mapper = mapper(Place, place)
         Transition.mapper = mapper(Transition, transition, properties = dict(
-            inputs = relation(Place.mapper, place_output, lazy=True, attributeext=attr.ListBackrefExtension('inputs')),
-            outputs = relation(Place.mapper, place_input, lazy=True, attributeext=attr.ListBackrefExtension('outputs')),
+            inputs = relation(Place.mapper, place_output, lazy=True, backref='inputs'),
+            outputs = relation(Place.mapper, place_input, lazy=True, backref='outputs'),
             )
         )
-        Place.mapper.add_property('inputs', relation(Transition.mapper, place_output, lazy=True, attributeext=attr.ListBackrefExtension('inputs')))
-        Place.mapper.add_property('outputs', relation(Transition.mapper, place_input, lazy=True, attributeext=attr.ListBackrefExtension('outputs')))
+        #Place.mapper.add_property('inputs', relation(Transition.mapper, place_output, lazy=True, attributeext=attr.ListBackrefExtension('inputs')))
+        #Place.mapper.add_property('outputs', relation(Transition.mapper, place_input, lazy=True, attributeext=attr.ListBackrefExtension('outputs')))
 
         Place.eagermapper = Place.mapper.options(
             eagerload('inputs', selectalias='ip_alias'), 
index 56da630c37273984983bf757f193c8d307ea9640..979a4ceefd35b58b90857e8de95ea63f556eeaad 100644 (file)
@@ -57,9 +57,8 @@ class HistoryTest(AssertMixin):
         class Address(object):pass
         am = mapper(Address, addresses)
         m = mapper(User, users, properties = dict(
-            addresses = relation(am, attributeext=attributes.OTMBackrefExtension('user')))
+            addresses = relation(am, backref='user'))
         )
-        am.add_property('user', relation(m, attributeext=attributes.MTOBackrefExtension('addresses')))
         
         u = User()
         a = Address()
@@ -68,6 +67,7 @@ class HistoryTest(AssertMixin):
         #print repr(u.addresses.added_items())
         self.assert_(u.addresses == [a])
         objectstore.commit()
+
         
 
 class PKTest(AssertMixin):