]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow custom getter/setters to be specified for a standard AssociationProxy
authorJason Kirtland <jek@discorporate.us>
Thu, 23 Aug 2007 15:35:03 +0000 (15:35 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 23 Aug 2007 15:35:03 +0000 (15:35 +0000)
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/util.py

index 2dd8072228988371136b460fd948e84e0ef2bba8..5f75bfeb7d630cc9bc61cf254cc32f213125f958 100644 (file)
@@ -11,6 +11,7 @@ import sqlalchemy.exceptions as exceptions
 import sqlalchemy.orm as orm
 import sqlalchemy.util as util
 
+
 def association_proxy(targetcollection, attr, **kw):
     """Convenience function for use in mapped classes.  Implements a Python
     property representing a relation as a collection of simpler values.  The
@@ -60,7 +61,7 @@ class AssociationProxy(object):
     on an object."""
 
     def __init__(self, targetcollection, attr, creator=None,
-                 proxy_factory=None, proxy_bulk_set=None):
+                 getset_factory=None, proxy_factory=None, proxy_bulk_set=None):
         """Arguments are:
 
           targetcollection
@@ -83,6 +84,17 @@ class AssociationProxy(object):
             If you want to construct instances differently, supply a 'creator'
             function that takes arguments as above and returns instances.
 
+          getset_factory
+            Optional.  Proxied attribute access is automatically handled
+            by routines that get and set values based on the `attr` argument
+            for this proxy.
+
+            If you would like to customize this behavior, you may supply a
+            `getset_factory` callable that produces a tuple of `getter` and
+            `setter` functions.  The factory is called with two arguments,
+            the abstract type of the underlying collection and this proxy
+            instance.
+
           proxy_factory
             Optional.  The type of collection to emulate is determined by
             sniffing the target collection.  If your collection type can't be
@@ -98,6 +110,7 @@ class AssociationProxy(object):
         self.target_collection = targetcollection # backwards compat name...
         self.value_attr = attr
         self.creator = creator
+        self.getset_factory = getset_factory
         self.proxy_factory = proxy_factory
         self.proxy_bulk_set = proxy_bulk_set
 
@@ -165,22 +178,31 @@ class AssociationProxy(object):
     def __delete__(self, obj):
         delattr(obj, self.key)
 
+    def _default_getset(self, collection_class):
+        attr = self.value_attr
+        getter = util.attrgetter(attr)
+        if collection_class is dict:
+            setter = lambda o, k, v: setattr(o, attr, v)
+        else:
+            setter = lambda o, v: setattr(o, attr, v)
+        return getter, setter
+
     def _new(self, lazy_collection):
         creator = self.creator and self.creator or self.target_class
         self.collection_class = util.duck_type_collection(lazy_collection())
 
         if self.proxy_factory:
             return self.proxy_factory(lazy_collection, creator, self.value_attr)
-
-        value_attr = self.value_attr
-        getter = lambda o: getattr(o, value_attr)
-        setter = lambda o, v: setattr(o, value_attr, v)
+        
+        if self.getset_factory:
+            getter, setter = self.getset_factory(self.collection_class, self)
+        else:
+            getter, setter = self._default_getset(self.collection_class)
         
         if self.collection_class is list:
             return _AssociationList(lazy_collection, creator, getter, setter)
         elif self.collection_class is dict:
-            kv_setter = lambda o, k, v: setattr(o, value_attr, v)
-            return _AssociationDict(lazy_collection, creator, getter, kv_setter)
+            return _AssociationDict(lazy_collection, creator, getter, setter)
         elif self.collection_class is util.Set:
             return _AssociationSet(lazy_collection, creator, getter, setter)
         else:
index 714535dccfbacdb2954a1f61d6d9ebc07e6194a9..2a02eb8ff20f7e0c02ad4ac6b156ad8725135866 100644 (file)
@@ -98,14 +98,9 @@ through the adapter, allowing for some very sophisticated behavior.
 import copy, inspect, sys, weakref
 
 from sqlalchemy import exceptions, schema, util as sautil
+from sqlalchemy.util import attrgetter
 from sqlalchemy.orm import mapper
 
-try:
-    from operator import attrgetter
-except:
-    def attrgetter(attribute):
-        return lambda value: getattr(value, attribute)
-
 
 __all__ = ['collection', 'collection_adapter',
            'mapped_collection', 'column_mapped_collection',
index ae8205f2109052bd395a3ef10373d33b37bf188e..44ff3a2c5cebdc09d12bd76da9d075beca8d8f07 100644 (file)
@@ -53,6 +53,12 @@ except ImportError:
     Decimal.warn = True
     decimal_type = float
 
+try:
+    from operator import attrgetter
+except:
+    def attrgetter(attribute):
+        return lambda value: getattr(value, attribute)
+
 if sys.version_info >= (2, 5):
     class PopulateDict(dict):
         """a dict which populates missing values via a creation function.