From: Jason Kirtland Date: Thu, 23 Aug 2007 15:35:03 +0000 (+0000) Subject: Allow custom getter/setters to be specified for a standard AssociationProxy X-Git-Tag: rel_0_4beta6~79 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=51dc8b088d37b7132f207949a0a00cd3db651e37;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Allow custom getter/setters to be specified for a standard AssociationProxy --- diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 2dd8072228..5f75bfeb7d 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -11,6 +11,7 @@ import sqlalchemy.exceptions as exceptions import sqlalchemy.orm as orm import sqlalchemy.util as util + def association_proxy(targetcollection, attr, **kw): """Convenience function for use in mapped classes. Implements a Python property representing a relation as a collection of simpler values. The @@ -60,7 +61,7 @@ class AssociationProxy(object): on an object.""" def __init__(self, targetcollection, attr, creator=None, - proxy_factory=None, proxy_bulk_set=None): + getset_factory=None, proxy_factory=None, proxy_bulk_set=None): """Arguments are: targetcollection @@ -83,6 +84,17 @@ class AssociationProxy(object): If you want to construct instances differently, supply a 'creator' function that takes arguments as above and returns instances. + getset_factory + Optional. Proxied attribute access is automatically handled + by routines that get and set values based on the `attr` argument + for this proxy. + + If you would like to customize this behavior, you may supply a + `getset_factory` callable that produces a tuple of `getter` and + `setter` functions. The factory is called with two arguments, + the abstract type of the underlying collection and this proxy + instance. + proxy_factory Optional. The type of collection to emulate is determined by sniffing the target collection. If your collection type can't be @@ -98,6 +110,7 @@ class AssociationProxy(object): self.target_collection = targetcollection # backwards compat name... self.value_attr = attr self.creator = creator + self.getset_factory = getset_factory self.proxy_factory = proxy_factory self.proxy_bulk_set = proxy_bulk_set @@ -165,22 +178,31 @@ class AssociationProxy(object): def __delete__(self, obj): delattr(obj, self.key) + def _default_getset(self, collection_class): + attr = self.value_attr + getter = util.attrgetter(attr) + if collection_class is dict: + setter = lambda o, k, v: setattr(o, attr, v) + else: + setter = lambda o, v: setattr(o, attr, v) + return getter, setter + def _new(self, lazy_collection): creator = self.creator and self.creator or self.target_class self.collection_class = util.duck_type_collection(lazy_collection()) if self.proxy_factory: return self.proxy_factory(lazy_collection, creator, self.value_attr) - - value_attr = self.value_attr - getter = lambda o: getattr(o, value_attr) - setter = lambda o, v: setattr(o, value_attr, v) + + if self.getset_factory: + getter, setter = self.getset_factory(self.collection_class, self) + else: + getter, setter = self._default_getset(self.collection_class) if self.collection_class is list: return _AssociationList(lazy_collection, creator, getter, setter) elif self.collection_class is dict: - kv_setter = lambda o, k, v: setattr(o, value_attr, v) - return _AssociationDict(lazy_collection, creator, getter, kv_setter) + return _AssociationDict(lazy_collection, creator, getter, setter) elif self.collection_class is util.Set: return _AssociationSet(lazy_collection, creator, getter, setter) else: diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 714535dccf..2a02eb8ff2 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -98,14 +98,9 @@ through the adapter, allowing for some very sophisticated behavior. import copy, inspect, sys, weakref from sqlalchemy import exceptions, schema, util as sautil +from sqlalchemy.util import attrgetter from sqlalchemy.orm import mapper -try: - from operator import attrgetter -except: - def attrgetter(attribute): - return lambda value: getattr(value, attribute) - __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index ae8205f210..44ff3a2c5c 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -53,6 +53,12 @@ except ImportError: Decimal.warn = True decimal_type = float +try: + from operator import attrgetter +except: + def attrgetter(attribute): + return lambda value: getattr(value, attribute) + if sys.version_info >= (2, 5): class PopulateDict(dict): """a dict which populates missing values via a creation function.