]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Expand custom assocproxy getter/setter support to scalar proxies
authorJason Kirtland <jek@discorporate.us>
Thu, 23 Aug 2007 15:48:51 +0000 (15:48 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 23 Aug 2007 15:48:51 +0000 (15:48 +0000)
lib/sqlalchemy/ext/associationproxy.py

index 5f75bfeb7d630cc9bc61cf254cc32f213125f958..0130721e13cc5d31a1c05b8b5ce34c74987c8dcc 100644 (file)
@@ -148,9 +148,11 @@ class AssociationProxy(object):
             return
         elif self.scalar is None:
             self.scalar = self._target_is_scalar()
+            if self.scalar:
+                self._initialize_scalar_accessors()
 
         if self.scalar:
-            return getattr(getattr(obj, self.target_collection), self.value_attr)
+            return self._scalar_get(getattr(obj, self.target_collection))
         else:
             try:
                 return getattr(obj, self.key)
@@ -162,6 +164,8 @@ class AssociationProxy(object):
     def __set__(self, obj, values):
         if self.scalar is None:
             self.scalar = self._target_is_scalar()
+            if self.scalar:
+                self._initialize_scalar_accessors()
 
         if self.scalar:
             creator = self.creator and self.creator or self.target_class
@@ -169,7 +173,7 @@ class AssociationProxy(object):
             if target is None:
                 setattr(obj, self.target_collection, creator(values))
             else:
-                setattr(target, self.value_attr, values)
+                self._scalar_set(target, values)
         else:
             proxy = self.__get__(obj, None)
             proxy.clear()
@@ -178,6 +182,13 @@ class AssociationProxy(object):
     def __delete__(self, obj):
         delattr(obj, self.key)
 
+    def _initialize_scalar_accessors(self):
+        if self.getset_factory:
+            get, set = self.getset_factory(None, self)
+        else:
+            get, set = self._default_getset(None)
+        self._scalar_get, self._scalar_set = get, set
+
     def _default_getset(self, collection_class):
         attr = self.value_attr
         getter = util.attrgetter(attr)