]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- New association proxy implementation, implementing complete proxies to list, dict...
authorJason Kirtland <jek@discorporate.us>
Thu, 3 May 2007 00:57:59 +0000 (00:57 +0000)
committerJason Kirtland <jek@discorporate.us>
Thu, 3 May 2007 00:57:59 +0000 (00:57 +0000)
- Added util.duck_type_collection

CHANGES
doc/build/content/plugins.txt
doc/build/gen_docstrings.py
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/util.py
test/ext/alltests.py
test/ext/associationproxy.py [new file with mode: 0644]

diff --git a/CHANGES b/CHANGES
index 7f2c0cb6f895fa077944b5de6a48b34d5c33ff48..c507c202ecdb2c3a88593e60918b33b311739035 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -2,6 +2,8 @@
     - support for column-level CHARACTER SET and COLLATE declarations,
       as well as ASCII, UNICODE, NATIONAL and BINARY shorthand.
 -extensions
+    - new association proxy implementation, implementing complete
+      proxies to list, dict and set-based relation collections
     - added orderinglist, a custom list class that synchronizes an object
       attribute with that object's position in the list
     - small fix to SelectResultsExt to not bypass itself during
index 4ee02fede3b6e66d62508d5193e48466fb11db64..040c703fd968d21b960bd81ce6b7f800e47390e0 100644 (file)
@@ -278,7 +278,7 @@ To continue the `MyClass` example:
 
 ### associationproxy
 
-**Author:** Mike Bayer<br/>
+**Author:** Mike Bayer and Jason Kirtland<br/>
 **Version:** 0.3.1 or greater
 
 `associationproxy` is used to create a transparent proxy to the associated object in an association relationship, thereby decreasing the verbosity of the pattern in cases where explicit access to the association object is not required.  The association relationship pattern is a richer form of a many-to-many relationship, which is described in [datamapping_association](rel:datamapping_association).  It is strongly recommended to fully understand the association object pattern in its explicit form before using this extension; see the examples in the SQLAlchemy distribution under the directory `examples/association/`.
@@ -286,7 +286,7 @@ To continue the `MyClass` example:
 When dealing with association relationships, the **association object** refers to the object that maps to a row in the association table (i.e. the many-to-many table), while the **associated object** refers to the "endpoint" of the association, i.e. the ultimate object referenced by the parent.  The proxy can return collections of objects attached to association objects, and can also create new association objects given only the associated object.  An example using the Keyword mapping described in the data mapping documentation is as follows:
 
     {python}
-    from sqlalchemy.ext.associationproxy import AssociationProxy
+    from sqlalchemy.ext.associationproxy import association_proxy
     
     class User(object):
         pass
@@ -300,7 +300,7 @@ When dealing with association relationships, the **association object** refers t
         # the collection is called 'keyword_associations', the endpoint
         # attribute of each association object is called 'keyword'.  the 
         # class itself of the association object will be figured out automatically  .
-        keywords = AssociationProxy('keyword_associations', 'keyword')
+        keywords = association_proxy('keyword_associations', 'keyword')
 
     class KeywordAssociation(object):
         pass
@@ -323,7 +323,7 @@ When dealing with association relationships, the **association object** refers t
     mapper(Keyword, keywords_table)
 
     # now, Keywords can be attached to an Article directly;
-    # KeywordAssociation will be created by the AssociationProxy, and have the 
+    # KeywordAssociation will be created by the association_proxy, and have the 
     # 'keyword' attribute set to the new Keyword.
     # note that these KeywordAssociation objects will not have a User attached to them.
     article = Article()
@@ -341,9 +341,9 @@ When dealing with association relationships, the **association object** refers t
     article.keyword_associations.append(KeywordAssociation())
     print [ka for ka in article.keyword_associations]
     
-Note that the above operations on the `keywords` collection are proxying operations to and from the `keyword_associations` collection, which exists normally and can be accessed directly.  `AssociationProxy` will also detect if the collection is list or scalar based and will configure the proxied property to act the same way.
+Note that the above operations on the `keywords` collection are proxying operations to and from the `keyword_associations` collection, which exists normally and can be accessed directly.  `association_proxy` will also detect if the collection is list or scalar based and will configure the proxied property to act the same way.
 
-For the common case where the association object's creation needs to be specified by the application, `AssociationProxy` takes an optional callable `creator()` which takes a single associated object as an argument, and returns a new association object.
+For the common case where the association object's creation needs to be specified by the application, `association_proxy` takes an optional callable `creator()` which takes a single associated object as an argument, and returns a new association object.
 
     {python}
     def create_keyword_association(keyword):
@@ -353,7 +353,21 @@ For the common case where the association object's creation needs to be specifie
         
     class Article(object):
         # create "keywords" proxied association
-        keywords = AssociationProxy('keyword_associations', 'keyword', creator=create_keyword_association)
+        keywords = association_proxy('keyword_associations', 'keyword', creator=create_keyword_association)
+
+Proxy properties are implemented by the `AssociationProxy` class, which is
+also available in the module.  The `association_proxy` function is not present
+in SQLAlchemy versions 0.3.1 through 0.3.7, instead instantiate the class
+directly:
+
+    {python}
+    from sqlalchemy.ext.associationproxy import AssociationProxy
+
+    class Article(object):
+       keywords = AssociationProxy('keyword_associations', 'keyword')
+
+
+The `association_proxy` function is
 
 ### orderinglist
 
index 330b9dcbaef2383fa681f80cdf0d58c30e5efc4f..042cabbf5f6a561442acc1ef489999c04a12e653 100644 (file)
@@ -8,6 +8,7 @@ import sqlalchemy.ext.sessioncontext as sessioncontext
 import sqlalchemy.mods.threadlocal as threadlocal
 import sqlalchemy.ext.selectresults as selectresults
 import sqlalchemy.ext.orderinglist as orderinglist
+import sqlalchemy.ext.associationproxy as associationproxy
 
 def make_doc(obj, classes=None, functions=None, **kwargs):
     """generate a docstring.ObjectDoc structure for an individual module, list of classes, and list of functions."""
@@ -38,6 +39,7 @@ def make_all_docs():
         make_doc(obj=selectresults),
         make_doc(obj=proxy),
         make_doc(obj=orderinglist, classes=[orderinglist.OrderingList]),
+        make_doc(obj=associationproxy, classes=[associationproxy.AssociationProxy]),
     ] + [make_doc(getattr(__import__('sqlalchemy.databases.%s' % m).databases, m)) for m in databases.__all__]
     return objects
     
index 65b95ccbadbfd5dd4f46c411b2b723231804e5e1..0913d6c488903d48b9328de6790a99500b26ec92 100644 (file)
@@ -6,116 +6,620 @@ transparent proxied access to the endpoint of an association object.
 See the example ``examples/association/proxied_association.py``.
 """
 
-from sqlalchemy.orm import class_mapper
+from sqlalchemy.orm.attributes import InstrumentedList
+import sqlalchemy.exceptions as exceptions
+import sqlalchemy.orm as orm
+import sqlalchemy.util as util
+
+def association_proxy(targetcollection, attr, **kw):
+    """Convenience function for use in mapped classes.  Implements a Python
+    property representing a relation as a collection of simpler values.  The
+    proxied property will mimic the collection type of the target (list, dict
+    or set), or in the case of a one to one relation, a simple scalar value.
+
+    targetcollection
+      Name of the relation attribute we'll proxy to, usually created with
+      'relation()' in a mapper setup.
+
+    attr
+      Attribute on the associated instances we'll proxy for.  For example,
+      given a target collection of [obj1, obj2], a list created by this proxy
+      property would look like
+      [getattr(obj1, attr), getattr(obj2, attr)]
+
+      If the relation is one-to-one or otherwise uselist=False, then simply:
+      getattr(obj, attr)
+
+    creator (optional)
+      When new items are added to this proxied collection, new instances of
+      the class collected by the target collection will be created.  For
+      list and set collections, the target class constructor will be called
+      with the 'value' for the new instance.  For dict types, two arguments
+      are passed: key and value.
+
+      If you want to construct instances differently, supply a 'creator'
+      function that takes arguments as above and returns instances.
+
+      For scalar relations, creator() will be called if the target is None.
+      If the target is present, set operations are proxied to setattr() on the
+      associated object.
+
+      If you have an associated object with multiple attributes, you may set up
+      multiple association proxies mapping to different attributes.  See the
+      unit tests for examples, and for examples of how creator() functions can
+      be used to construct the scalar relation on-demand in this situation.
+
+    Passes along any other arguments to AssociationProxy
+    """
+
+    return AssociationProxy(targetcollection, attr, **kw)
+
 
 class AssociationProxy(object):
-    """A property object that automatically sets up ``AssociationLists`` on a parent object."""
+    """A property object that automatically sets up `AssociationLists`
+    on an object."""
 
-    def __init__(self, targetcollection, attr, creator=None):
-        """Create a new association property.
+    def __init__(self, targetcollection, attr, creator=None,
+                 proxy_factory=None, proxy_bulk_set=None):
+        """Arguments are:
 
-        targetcollection
-          The attribute name which stores the collection of Associations.
+          targetcollection
+            Name of the collection we'll proxy to, usually created with
+            'relation()' in a mapper setup.
 
-        attr
-          Name of the attribute on the Association in which to get/set target values.
+          attr
+            Attribute on the collected instances we'll proxy for.  For example,
+            given a target collection of [obj1, obj2],
+            a list created by this proxy property would look like
+            [getattr(obj1, attr), getattr(obj2, attr)]
 
-        creator
-          Optional callable which is used to create a new association
-          object.  This callable is given a single argument which is
-          an instance of the *proxied* object.  If creator is not
-          given, the association object is created using the class
-          associated with the targetcollection attribute, using its
-          ``__init__()`` constructor and setting the proxied
-          attribute.
+          creator
+            Optional. When new items are added to this proxied collection, new
+            instances of the class collected by the target collection will be
+            created.  For list and set collections, the target class
+            constructor will be called with the 'value' for the new instance.
+            For dict types, two arguments are passed: key and value.
+
+            If you want to construct instances differently, supply a 'creator'
+            function that takes arguments as above and returns instances.
+
+          proxy_factory
+            Optional.  The type of collection to emulate is determined by
+            sniffing the target collection.  If your collection type can't be
+            determined by duck typing or you'd like to use a different collection
+            implementation, you may supply a factory function to produce those
+            collections.  Only applicable to non-scalar relations.
+
+          proxy_bulk_set
+            Optional, use with proxy_factory.  See the _set() method for
+            details.
         """
         
-        self.targetcollection = targetcollection
-        self.attr = attr
+        self.target_collection = targetcollection # backwards compat name...
+        self.value_attr = attr
         self.creator = creator
-        
-    def __init_deferred(self):
-        prop = class_mapper(self._owner_class).props[self.targetcollection]
-        self._cls = prop.mapper.class_
-        self._uselist = prop.uselist
+        self.proxy_factory = proxy_factory
+        self.proxy_bulk_set = proxy_bulk_set
 
-    def _get_class(self):
-        try:
-            return self._cls
-        except AttributeError:
-            self.__init_deferred()
-            return self._cls
+        self.scalar = None
+        self.owning_class = None
+        self.key = '_%s_%s_%s' % (type(self).__name__,
+                                  targetcollection, id(self))
+        self.collection_class = None
 
-    def _get_uselist(self):
-        try:
-            return self._uselist
-        except AttributeError:
-            self.__init_deferred()
-            return self._uselist
+    def _get_property(self):
+        return orm.class_mapper(self.owning_class).props[self.target_collection]
 
-    cls = property(_get_class)
-    uselist = property(_get_uselist)
+    def _target_class(self):
+        return self._get_property().mapper.class_
+    target_class = property(_target_class)
 
-    def create(self, target, **kw):
-        if self.creator is not None:
-            return self.creator(target, **kw)
-        else:
-            assoc = self.cls(**kw)
-            setattr(assoc, self.attr, target)
-            return assoc
-
-    def __get__(self, obj, owner):
-        self._owner_class = owner
+        
+    def __get__(self, obj, class_):
         if obj is None:
-            return self
-        storage_key = '_AssociationProxy_%s_%s' % (self.targetcollection, self.attr)
-        if self.uselist:
+            self.owning_class = class_
+            return
+        elif self.scalar is None:
+            self.scalar = not self._get_property().uselist
+
+        if self.scalar:
+            return getattr(getattr(obj, self.target_collection), self.value_attr)
+        else:
             try:
-                return getattr(obj, storage_key)
+                return getattr(obj, self.key)
             except AttributeError:
-                a = _AssociationList(self, obj)
-                setattr(obj, storage_key, a)
-                return a
+                proxy = self._new(getattr(obj, self.target_collection))
+                setattr(obj, self.key, proxy)
+                return proxy
+
+    def __set__(self, obj, values):
+        if self.scalar:
+            creator = self.creator and self.creator or self.target_class
+            target = getattr(obj, self.target_collection)
+            if target is None:
+                setattr(obj, self.target_collection, creator(values))
+            else:
+                setattr(target, self.value_attr, values)
+        else:
+            proxy = self.__get__(obj, None)
+            proxy.clear()
+            self._set(proxy, values)
+
+    def __delete__(self, obj):
+        delattr(obj, self.key)
+
+    def _new(self, collection):
+        creator = self.creator and self.creator or self.target_class
+
+        # Prefer class typing here to spot dicts with the required append()
+        # method.
+        if isinstance(collection.data, dict):
+            self.collection_class = dict
         else:
-            return getattr(getattr(obj, self.targetcollection), self.attr)
+            self.collection_class = util.duck_type_collection(collection.data)
+
+        if self.proxy_factory:
+            return self.proxy_factory(collection, creator, self.value_attr)
 
-    def __set__(self, obj, value):
-        if self.uselist:
-            setattr(obj, self.targetcollection, [self.create(x) for x in value])
+        value_attr = self.value_attr
+        getter = lambda o: getattr(o, value_attr)
+        setter = lambda o, v: setattr(o, value_attr, v)
+        
+        if self.collection_class is list:
+            return _AssociationList(collection, creator, getter, setter)
+        elif self.collection_class is dict:
+            kv_setter = lambda o, k, v: setattr(o, value_attr, v)
+            return _AssociationDict(collection, creator, getter, setter)
+        elif self.collection_class is util.Set:
+            return _AssociationSet(collection, creator, getter, setter)
         else:
-            setattr(obj, self.targetcollection, self.create(value))
+            raise exceptions.ArgumentError(
+                'could not guess which interface to use for '
+                'collection_class "%s" backing "%s"; specify a '
+                'proxy_factory and proxy_bulk_set manually' %
+                (self.collection_class.__name__, self.target_collection))
 
-    def __del__(self, obj):
-        delattr(obj, self.targetcollection)
+    def _set(self, proxy, values):
+        if self.proxy_bulk_set:
+            self.proxy_bulk_set(proxy, values)
+        elif self.collection_class is list:
+            proxy.extend(values)
+        elif self.collection_class is dict:
+            proxy.update(values)
+        elif self.collection_class is util.Set:
+            proxy.update(values)
+        else:
+            raise exceptions.ArgumentError(
+               'no proxy_bulk_set supplied for custom '
+               'collection_class implementation')
 
 class _AssociationList(object):
-    """Generic proxying list which proxies list operations to a
-    different list-holding attribute of the parent object, converting
-    Association objects to and from a target attribute on each
-    Association object.
+    """Generic proxying list which proxies list operations to a another list,
+    converting association objects to and from a simplified value.
     """
 
-    def __init__(self, proxy, parent):
-        """Create a new ``AssociationList``."""
-        self.proxy = proxy
-        self.parent = parent
+    def __init__(self, collection, creator, getter, setter):
+        """
+        collection
+          A list-based collection of entities (usually an object attribute
+          managed by a SQLAlchemy relation())
+          
+        creator
+          A function that creates new target entities.  Given one parameter:
+          value.  The assertion is assumed:
+            obj = creator(somevalue)
+            assert getter(obj) == somevalue
+
+        getter
+          A function.  Given an associated object, return the 'value'.
+
+        setter
+          A function.  Given an associated object and a value, store
+          that value on the object.
+        """
+
+        self.col = collection
+        self.creator = creator
+        self.getter = getter
+        self.setter = setter
+
+    # For compatibility with 0.3.1 through 0.3.7- pass kw through to creator.
+    # (see append() below)
+    def _create(self, value, **kw):
+        return self.creator(value, **kw)
+
+    def _get(self, object):
+        return self.getter(object)
+
+    def _set(self, object, value):
+        return self.setter(object, value)
 
-    def append(self, item, **kw):
-        a = self.proxy.create(item, **kw)
-        getattr(self.parent, self.proxy.targetcollection).append(a)
+    def __len__(self):
+        return len(self.col)
+
+    def __nonzero__(self):
+        return True if self.col else False
+
+    def __getitem__(self, index):
+        return self._get(self.col[index])
+    
+    def __setitem__(self, index, value):
+        self._set(self.col[index], value)
+
+    def __delitem__(self, index):
+        del self.col[index]
+
+    def __contains__(self, value):
+        for member in self.col:
+            if self._get(member) == value:
+                return True
+        return False
+
+    def __getslice__(self, start, end):
+        return [self._get(member) for member in self.col[start:end]]
+
+    def __setslice__(self, start, end, values):
+        members = [self._create(v) for v in values]
+        self.col[start:end] = members
+
+    def __delslice__(self, start, end):
+        del self.col[start:end]
 
     def __iter__(self):
-        return iter([getattr(x, self.proxy.attr) for x in getattr(self.parent, self.proxy.targetcollection)])
+        """Iterate over proxied values.  For the actual domain objects,
+        iterate over .col instead or just use the underlying collection
+        directly from its property on the parent."""
+        for member in self.col:
+            yield self._get(member)
+        raise StopIteration
+
+    # For compatibility with 0.3.1 through 0.3.7- pass kw through to creator
+    # on append() only.  (Can't on __setitem__, __contains__, etc., obviously.)
+    def append(self, value, **kw):
+        item = self._create(value, **kw)
+        self.col.append(item)
+
+    def extend(self, values):
+        for v in values:
+            self.append(v)
+
+    def insert(self, index, value):
+        self.col[index:index] = [self._create(value)]
+
+    def clear(self):
+        del self.col[0:len(self.col)]
+
+    def __eq__(self, other): return list(self) == other
+    def __ne__(self, other): return list(self) != other
+    def __lt__(self, other): return list(self) < other
+    def __le__(self, other): return list(self) <= other
+    def __gt__(self, other): return list(self) > other
+    def __ge__(self, other): return list(self) >= other
+    def __cmp__(self, other): return cmp(list(self), other)
+
+    def copy(self):
+        return list(self)
 
     def __repr__(self):
-        return repr([getattr(x, self.proxy.attr) for x in getattr(self.parent, self.proxy.targetcollection)])
+        return repr(list(self))
+
+    def hash(self):
+        raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+_NotProvided = object()
+class _AssociationDict(object):
+    """Generic proxying list which proxies dict operations to a another dict,
+    converting association objects to and from a simplified value.
+    """
+
+    def __init__(self, collection, creator, getter, setter):
+        """
+        collection
+          A list-based collection of entities (usually an object attribute
+          managed by a SQLAlchemy relation())
+          
+        creator
+          A function that creates new target entities.  Given two parameters:
+          key and value.  The assertion is assumed:
+            obj = creator(somekey, somevalue)
+            assert getter(somekey) == somevalue
+
+        getter
+          A function.  Given an associated object and a key, return the 'value'.
+
+        setter
+          A function.  Given an associated object, a key and a value, store
+          that value on the object.
+        """
+
+        self.col = collection
+        self.creator = creator
+        self.getter = getter
+        self.setter = setter
+
+    def _create(self, key, value):
+        return self.creator(key, value)
+
+    def _get(self, object):
+        return self.getter(object)
+
+    def _set(self, object, key, value):
+        return self.setter(object, key, value)
 
     def __len__(self):
-        return len(getattr(self.parent, self.proxy.targetcollection))
+        return len(self.col)
 
-    def __getitem__(self, index):
-        return getattr(getattr(self.parent, self.proxy.targetcollection)[index], self.proxy.attr)
+    def __nonzero__(self):
+        return True if self.col else False
 
-    def __setitem__(self, index, value):
-        a = self.proxy.create(item)
-        getattr(self.parent, self.proxy.targetcollection)[index] = a
+    def __getitem__(self, key):
+        return self._get(self.col[key])
+    
+    def __setitem__(self, key, value):
+        if key in self.col:
+            self._set(self.col[key], key, value)
+        else:
+            self.col[key] = self._create(key, value)
+
+    def __delitem__(self, key):
+        del self.col[key]
+
+    def __contains__(self, key):
+        return key in self.col
+    has_key = __contains__
+
+    def __iter__(self):
+        return iter(self.col)
+
+    def clear(self):
+        self.col.clear()
+
+    def __eq__(self, other): return dict(self) == other
+    def __ne__(self, other): return dict(self) != other
+    def __lt__(self, other): return dict(self) < other
+    def __le__(self, other): return dict(self) <= other
+    def __gt__(self, other): return dict(self) > other
+    def __ge__(self, other): return dict(self) >= other
+    def __cmp__(self, other): return cmp(dict(self), other)
+
+    def __repr__(self):
+        return repr(dict(self.items()))
+
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def setdefault(self, key, default=None):
+        if key not in self.col:
+            self.col[key] = self._create(key, default)
+            return default
+        else:
+            return self[key]
+
+    def keys(self):
+        return self.col.keys()
+    def iterkeys(self):
+        return self.col.iterkeys()
+
+    def values(self):
+        return [ self._get(member) for member in self.col.values() ]
+    def itervalues(self):
+        for key in self.col:
+            yield self._get(self.col[key])
+        raise StopIteration
+
+    def items(self):
+        return [(k, self._get(self.col[k])) for k in self]
+    def iteritems(self):
+        for key in self.col:
+            yield (key, self._get(self.col[key]))
+        raise StopIteration
+
+    def pop(self, key, default=_NotProvided):
+        if default is _NotProvided:
+            member = self.col.pop(key)
+        else:
+            member = self.col.pop(key, default)
+        return self._get(member)
+
+    def popitem(self):
+        item = self.col.popitem()
+        return (item[0], self._get(item[1]))
+    
+    def update(self, *a, **kw):
+        if len(a) > 1:
+            raise TypeError('update expected at most 1 arguments, got %i' %
+                            len(a))
+        elif len(a) == 1:
+            seq_or_map = a[0]
+            for item in seq_or_map:
+                if isinstance(item, tuple):
+                    self[item[0]] = item[1]
+                else:
+                    self[item] = seq_or_map[item]
+
+        for key, value in kw:
+            self[key] = value
+
+    def copy(self):
+        return dict(self.items())
+
+    def hash(self):
+        raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+class _AssociationSet(object):
+    """Generic proxying list which proxies set operations to a another set,
+    converting association objects to and from a simplified value.
+    """
+
+    def __init__(self, collection, creator, getter, setter):
+        """
+        collection
+          A list-based collection of entities (usually an object attribute
+          managed by a SQLAlchemy relation())
+          
+        creator
+          A function that creates new target entities.  Given one parameter:
+          value.  The assertion is assumed:
+            obj = creator(somevalue)
+            assert getter(obj) == somevalue
+
+        getter
+          A function.  Given an associated object, return the 'value'.
+
+        setter
+          A function.  Given an associated object and a value, store
+          that value on the object.
+        """
+
+        self.col = collection
+        self.creator = creator
+        self.getter = getter
+        self.setter = setter
+
+    def _create(self, value):
+        return self.creator(value)
+
+    def _get(self, object):
+        return self.getter(object)
+
+    def _set(self, object, value):
+        return self.setter(object, value)
+
+    def __len__(self):
+        return len(self.col)
+
+    def __nonzero__(self):
+        return True if self.col else False
+
+    def __contains__(self, value):
+        for member in self.col:
+            if self._get(member) == value:
+                return True
+        return False
+
+    def __iter__(self):
+        """Iterate over proxied values.  For the actual domain objects,
+        iterate over .col instead or just use the underlying collection
+        directly from its property on the parent."""
+        for member in self.col:
+            yield self._get(member)
+        raise StopIteration
+
+    def add(self, value):
+        if value not in self:
+            # must shove this through InstrumentedList.append() which will
+            # eventually call the collection_class .add()
+            self.col.append(self._create(value))
+
+    # for discard and remove, choosing a more expensive check strategy rather
+    # than call self.creator()
+    def discard(self, value):
+        for member in self.col:
+            if self._get(member) == value:
+                self.col.discard(member)
+                break
+
+    def remove(self, value):
+        for member in self.col:
+            if self._get(member) == value:
+                self.col.discard(member)
+                return
+        raise KeyError(value)
+
+    def pop(self):
+        if not self.col:
+            raise KeyError('pop from an empty set')
+        # grumble, pop() is borked on InstrumentedList (#548)
+        if isinstance(self.col, InstrumentedList):
+            member = list(self.col)[0]
+            self.col.remove(member)
+        else:
+            member = self.col.pop()
+        return self._get(member)
+
+    def update(self, other):
+        for value in other:
+            self.add(value)
+
+    __ior__ = update
+
+    def _set(self):
+        return util.Set(iter(self))
+
+    def union(self, other):
+        return util.Set(self).union(other)
+
+    __or__ = union
+
+    def difference(self, other):
+        return util.Set(self).difference(other)
+
+    __sub__ = difference
+
+    def difference_update(self, other):
+        for value in other:
+            self.discard(value)
+
+    __isub__ = difference_update
+
+    def intersection(self, other):
+        return util.Set(self).intersection(other)
+
+    __and__ = intersection
+
+    def intersection_update(self, other):
+        want, have = self.intersection(other), util.Set(self)
+
+        remove, add = have - want, want - have
+
+        for value in remove:
+            self.remove(value)
+        for value in add:
+            self.add(value)
+
+    __iand__ = intersection_update
+
+    def symmetric_difference(self, other):
+        return util.Set(self).symmetric_difference(other)
+
+    __xor__ = symmetric_difference
+
+    def symmetric_difference_update(self, other):
+        want, have = self.symmetric_difference(other), util.Set(self)
+
+        remove, add = have - want, want - have
+
+        for value in remove:
+            self.remove(value)
+        for value in add:
+            self.add(value)
+
+    __ixor__ = symmetric_difference_update
+
+    def issubset(self, other):
+        return util.Set(self).issubset(other)
+    
+    def issuperset(self, other):
+        return util.Set(self).issuperset(other)
+                
+    def clear(self):
+        self.col.clear()
+
+    def copy(self):
+        return util.Set(self)
+
+    def __eq__(self, other): return util.Set(self) == other
+    def __ne__(self, other): return util.Set(self) != other
+    def __lt__(self, other): return util.Set(self) < other
+    def __le__(self, other): return util.Set(self) <= other
+    def __gt__(self, other): return util.Set(self) > other
+    def __ge__(self, other): return util.Set(self) >= other
+
+    def __repr__(self):
+        return repr(util.Set(self))
+
+    def hash(self):
+        raise TypeError("%s objects are unhashable" % type(self).__name__)
index e4b0efad4b83838c491fe66e64043170a4e2dc5a..ea5a468d2afdd3a0101a615622af5e48742cd707 100644 (file)
@@ -103,12 +103,30 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True):
     necessary.  If 'flexi_bool' is True, the string '0' is considered false
     when coercing to boolean.
     """
+
     if key in kw and type(kw[key]) is not type_ and kw[key] is not None:
         if type_ is bool and flexi_bool and kw[key] == '0':
             kw[key] = False
         else:
             kw[key] = type_(kw[key])
 
+def duck_type_collection(col, default=None):
+    """Given an instance or class, guess if it is or is acting as one of
+    the basic collection types: list, set and dict.  If the __emulates__
+    property is present, return that preferentially.
+    """
+    
+    if hasattr(col, '__emulates__'):
+        return getattr(col, '__emulates__')
+    elif hasattr(col, 'append'):
+        return list
+    elif hasattr(col, 'add'):
+        return Set
+    elif hasattr(col, 'set'):
+        return dict
+    else:
+        return default
+    
 class SimpleProperty(object):
     """A *default* property accessor."""
 
index 7bfdc93ffb86f0e324e5a8ca5e3940d01926be36..713601c3bef69ac55f348ee4ab21d164d1c212d1 100644 (file)
@@ -5,7 +5,8 @@ def suite():
     unittest_modules = ['ext.activemapper',
                         'ext.selectresults',
                         'ext.assignmapper',
-                        'ext.orderinglist']
+                        'ext.orderinglist',
+                        'ext.associationproxy']
     doctest_modules = ['sqlalchemy.ext.sqlsoup']
 
     alltests = unittest.TestSuite()
diff --git a/test/ext/associationproxy.py b/test/ext/associationproxy.py
new file mode 100644 (file)
index 0000000..9374247
--- /dev/null
@@ -0,0 +1,534 @@
+from testbase import PersistTest
+import sqlalchemy.util as util
+import unittest
+import testbase
+from sqlalchemy import *
+from sqlalchemy.ext.associationproxy import *
+
+db = testbase.db
+
+class DictCollection(dict):
+    def append(self, obj):
+        self[obj.foo] = obj
+    def __iter__(self):
+        return self.itervalues()
+
+class SetCollection(set):
+    pass
+
+class ListCollection(list):
+    pass
+
+class ObjectCollection(object):
+    def __init__(self):
+        self.values = list()
+    def append(self, obj):
+        self.values.append(obj)
+    def __iter__(self):
+        return iter(self.values)
+    def clear(self):
+        self.values.clear()
+
+class _CollectionOperations(PersistTest):
+    def setUp(self):
+        collection_class = self.collection_class
+
+        metadata = BoundMetaData(db)
+    
+        parents_table = Table('Parent', metadata,
+                              Column('id', Integer, primary_key=True),
+                              Column('name', String))
+        children_table = Table('Children', metadata,
+                               Column('id', Integer, primary_key=True),
+                               Column('parent_id', Integer,
+                                      ForeignKey('Parent.id')),
+                               Column('foo', String),
+                               Column('name', String))
+
+        class Parent(object):
+            children = association_proxy('_children', 'name')
+        
+            def __init__(self, name):
+                self.name = name
+
+        class Child(object):
+            if collection_class and issubclass(collection_class, dict):
+                def __init__(self, foo, name):
+                    self.foo = foo
+                    self.name = name
+            else:
+                def __init__(self, name):
+                    self.name = name
+
+        mapper(Parent, parents_table, properties={
+            '_children': relation(Child, lazy=False,
+                                  collection_class=collection_class)})
+        mapper(Child, children_table)
+
+        metadata.create_all()
+
+        self.metadata = metadata
+        self.session = create_session()
+        self.Parent, self.Child = Parent, Child
+
+    def tearDown(self):
+        self.metadata.drop_all()
+
+    def roundtrip(self, obj):
+        self.session.save(obj)
+        self.session.flush()
+        id, type_ = obj.id, type(obj)
+        self.session.clear()
+        return self.session.query(type_).get(id)
+
+    def _test_sequence_ops(self):
+        Parent, Child = self.Parent, self.Child
+        
+        p1 = Parent('P1')
+
+        self.assert_(not p1._children)
+        self.assert_(not p1.children)
+
+        ch = Child('regular')
+        p1._children.append(ch)
+
+        self.assert_(ch in p1._children)
+        self.assert_(len(p1._children) == 1)
+
+        self.assert_(p1.children)
+        self.assert_(len(p1.children) == 1)
+        self.assert_(ch not in p1.children)
+        self.assert_('regular' in p1.children)
+
+        p1.children.append('proxied')
+
+        self.assert_('proxied' in p1.children)
+        self.assert_('proxied' not in p1._children)
+        self.assert_(len(p1.children) == 2)
+        self.assert_(len(p1._children) == 2)
+
+        self.assert_(p1._children[0].name == 'regular')
+        self.assert_(p1._children[1].name == 'proxied')
+    
+        del p1._children[1]
+
+        self.assert_(len(p1._children) == 1)
+        self.assert_(len(p1.children) == 1)
+        self.assert_(p1._children[0] == ch)
+
+        del p1.children[0]
+
+        self.assert_(len(p1._children) == 0)
+        self.assert_(len(p1.children) == 0)
+    
+        p1.children = ['a','b','c']
+        self.assert_(len(p1._children) == 3)
+        self.assert_(len(p1.children) == 3)
+
+        del ch
+        p1 = self.roundtrip(p1)
+
+        self.assert_(len(p1._children) == 3)
+        self.assert_(len(p1.children) == 3)
+        
+class DefaultTest(_CollectionOperations):
+    def __init__(self, *args, **kw):
+        super(DefaultTest, self).__init__(*args, **kw)
+        self.collection_class = None
+
+    def test_sequence_ops(self):
+        self._test_sequence_ops()
+
+class ListTest(_CollectionOperations):
+    def __init__(self, *args, **kw):
+        super(ListTest, self).__init__(*args, **kw)
+        self.collection_class = list
+
+    def test_sequence_ops(self):
+        self._test_sequence_ops()
+
+class CustomListTest(ListTest):
+    def __init__(self, *args, **kw):
+        super(CustomListTest, self).__init__(*args, **kw)
+        self.collection_class = list
+
+# No-can-do until ticket #213
+class DictTest(_CollectionOperations):
+    pass
+
+class CustomDictTest(DictTest):
+    def __init__(self, *args, **kw):
+        super(DictTest, self).__init__(*args, **kw)
+        self.collection_class = DictCollection
+
+    def test_mapping_ops(self):
+        Parent, Child = self.Parent, self.Child
+
+        p1 = Parent('P1')
+
+        self.assert_(not p1._children)
+        self.assert_(not p1.children)
+
+        ch = Child('a', 'regular')
+        p1._children.append(ch)
+
+        print repr(p1._children)
+        self.assert_(ch in p1._children.values())
+        self.assert_(len(p1._children) == 1)
+
+        self.assert_(p1.children)
+        self.assert_(len(p1.children) == 1)
+        self.assert_(ch not in p1.children)
+        self.assert_('a' in p1.children)
+        self.assert_(p1.children['a'] == 'regular')
+        self.assert_(p1._children['a'] == ch)
+
+        p1.children['b'] = 'proxied'
+
+        self.assert_('proxied' in p1.children.values())
+        self.assert_('b' in p1.children)
+        self.assert_('proxied' not in p1._children)
+        self.assert_(len(p1.children) == 2)
+        self.assert_(len(p1._children) == 2)
+
+        self.assert_(p1._children['a'].name == 'regular')
+        self.assert_(p1._children['b'].name == 'proxied')
+    
+        del p1._children['b']
+
+        self.assert_(len(p1._children) == 1)
+        self.assert_(len(p1.children) == 1)
+        self.assert_(p1._children['a'] == ch)
+
+        del p1.children['a']
+
+        self.assert_(len(p1._children) == 0)
+        self.assert_(len(p1.children) == 0)
+
+        p1.children = {'d': 'v d', 'e': 'v e', 'f': 'v f'}
+        self.assert_(len(p1._children) == 3)
+        self.assert_(len(p1.children) == 3)
+
+        del ch
+        p1 = self.roundtrip(p1)
+        self.assert_(len(p1._children) == 3)
+        self.assert_(len(p1.children) == 3)
+    
+
+class SetTest(_CollectionOperations):
+    def __init__(self, *args, **kw):
+        super(SetTest, self).__init__(*args, **kw)
+        self.collection_class = set
+
+    def test_set_operations(self):
+        Parent, Child = self.Parent, self.Child
+
+        p1 = Parent('P1')
+
+        self.assert_(not p1._children)
+        self.assert_(not p1.children)
+
+        ch1 = Child('regular')
+        p1._children.append(ch1)
+
+        self.assert_(ch1 in p1._children)
+        self.assert_(len(p1._children) == 1)
+
+        self.assert_(p1.children)
+        self.assert_(len(p1.children) == 1)
+        self.assert_(ch1 not in p1.children)
+        self.assert_('regular' in p1.children)
+
+        p1.children.add('proxied')
+
+        self.assert_('proxied' in p1.children)
+        self.assert_('proxied' not in p1._children)
+        self.assert_(len(p1.children) == 2)
+        self.assert_(len(p1._children) == 2)
+
+        self.assert_(set([o.name for o in p1._children]) == set(['regular', 'proxied']))
+
+        ch2 = None
+        for o in p1._children:
+            if o.name == 'proxied':
+                ch2 = o
+                break
+
+        p1._children.remove(ch2)
+
+        self.assert_(len(p1._children) == 1)
+        self.assert_(len(p1.children) == 1)
+        self.assert_(p1._children == set([ch1]))
+
+        p1.children.remove('regular')
+
+        self.assert_(len(p1._children) == 0)
+        self.assert_(len(p1.children) == 0)
+    
+        p1.children = ['a','b','c']
+        self.assert_(len(p1._children) == 3)
+        self.assert_(len(p1.children) == 3)
+
+        del ch1
+        p1 = self.roundtrip(p1)
+
+        self.assert_(len(p1._children) == 3)
+        self.assert_(len(p1.children) == 3)
+
+        self.assert_('a' in p1.children)
+        self.assert_('b' in p1.children)
+        self.assert_('d' not in p1.children)
+
+        self.assert_(p1.children == set(['a','b','c']))
+
+        try:
+            p1.children.remove('d')
+            self.fail()
+        except KeyError:
+            pass
+
+        self.assert_(len(p1.children) == 3)
+        p1.children.discard('d')
+        self.assert_(len(p1.children) == 3)
+        p1 = self.roundtrip(p1)
+        self.assert_(len(p1.children) == 3)
+
+        popped = p1.children.pop()
+        self.assert_(len(p1.children) == 2)
+        self.assert_(popped not in p1.children)
+        p1 = self.roundtrip(p1)
+        self.assert_(len(p1.children) == 2)
+        self.assert_(popped not in p1.children)
+    
+        p1.children = ['a','b','c']
+        p1 = self.roundtrip(p1)
+        self.assert_(p1.children == set(['a','b','c']))
+    
+        p1.children.discard('b')
+        p1 = self.roundtrip(p1)
+        self.assert_(p1.children == set(['a', 'c']))
+
+        p1.children.remove('a')
+        p1 = self.roundtrip(p1)
+        self.assert_(p1.children == set(['c']))
+
+    def test_set_comparisons(self):
+        Parent, Child = self.Parent, self.Child
+
+        p1 = Parent('P1')
+        p1.children = ['a','b','c']
+        control = set(['a','b','c'])
+
+        for other in (set(['a','b','c']), set(['a','b','c','d']),
+                      set(['a']), set(['a','b']),
+                      set(['c','d']), set(['e', 'f', 'g']),
+                      set()):
+
+            self.assertEqual(p1.children.union(other),
+                             control.union(other))
+            self.assertEqual(p1.children.difference(other),
+                             control.difference(other))
+            self.assertEqual((p1.children - other),
+                             (control - other))
+            self.assertEqual(p1.children.intersection(other),
+                             control.intersection(other))
+            self.assertEqual(p1.children.symmetric_difference(other),
+                             control.symmetric_difference(other))
+            self.assertEqual(p1.children.issubset(other),
+                             control.issubset(other))
+            self.assertEqual(p1.children.issuperset(other),
+                             control.issuperset(other))
+            
+            self.assert_((p1.children == other)  ==  (control == other))
+            self.assert_((p1.children != other)  ==  (control != other))
+            self.assert_((p1.children < other)   ==  (control < other))
+            self.assert_((p1.children <= other)  ==  (control <= other))
+            self.assert_((p1.children > other)   ==  (control > other))
+            self.assert_((p1.children >= other)  ==  (control >= other))
+
+    def test_set_mutation(self):
+        Parent, Child = self.Parent, self.Child
+
+        # mutations
+        for op in ('update', 'intersection_update',
+                   'difference_update', 'symmetric_difference_update'):
+            for base in (['a', 'b', 'c'], []):
+                for other in (set(['a','b','c']), set(['a','b','c','d']),
+                              set(['a']), set(['a','b']),
+                              set(['c','d']), set(['e', 'f', 'g']),
+                              set()):
+                    p = Parent('p')
+                    p.children = base[:]
+                    control = set(base[:])
+
+                    getattr(p.children, op)(other)
+                    getattr(control, op)(other)
+                    try:
+                        self.assert_(p.children == control)
+                    except:
+                        print 'Test %s.%s(%s):' % (set(base), op, other)
+                        print 'want', repr(control)
+                        print 'got', repr(p.children)
+                        raise
+
+                    p = self.roundtrip(p)
+
+                    try:
+                        self.assert_(p.children == control)
+                    except:
+                        print 'Test %s.%s(%s):' % (base, op, other)
+                        print 'want', repr(control)
+                        print 'got', repr(p.children)
+                        raise
+    
+    # workaround for bug #548
+    def test_set_pop(self):
+        Parent, Child = self.Parent, self.Child
+        p = Parent('p1')
+        p.children.add('a')
+        p.children.pop()
+        self.assert_(True)
+
+class CustomSetTest(SetTest):
+    def __init__(self, *args, **kw):
+        super(CustomSetTest, self).__init__(*args, **kw)
+        self.collection_class = SetCollection
+
+class CustomObjectTest(_CollectionOperations):
+    def __init__(self, *args, **kw):
+        super(CustomObjectTest, self).__init__(*args, **kw)
+        self.collection_class = ObjectCollection
+
+    def test_basic(self):
+        Parent, Child = self.Parent, self.Child
+
+        p = Parent('p1')
+        self.assert_(len(list(p.children)) == 0)
+
+        p.children.append('child')
+        self.assert_(len(list(p.children)) == 1)
+
+        p = self.roundtrip(p)
+        self.assert_(len(list(p.children)) == 1)
+
+        # We didn't provide an alternate _AssociationList implementation for
+        # our ObjectCollection, so indexing will fail.
+        try:
+            v = p.children[1]
+            self.fail()
+        except TypeError:
+            pass
+
+class ScalarTest(PersistTest):
+    def test_scalar_proxy(self):
+        metadata = BoundMetaData(db)
+    
+        parents_table = Table('Parent', metadata,
+                              Column('id', Integer, primary_key=True),
+                              Column('name', String))
+        children_table = Table('Children', metadata,
+                               Column('id', Integer, primary_key=True),
+                               Column('parent_id', Integer,
+                                      ForeignKey('Parent.id')),
+                               Column('foo', String),
+                               Column('bar', String),
+                               Column('baz', String))
+
+        class Parent(object):
+            foo = association_proxy('child', 'foo')
+            bar = association_proxy('child', 'bar',
+                                    creator=lambda v: Child(bar=v))
+            baz = association_proxy('child', 'baz',
+                                    creator=lambda v: Child(baz=v))
+        
+            def __init__(self, name):
+                self.name = name
+
+        class Child(object):
+            def __init__(self, **kw):
+                for attr in kw:
+                    setattr(self, attr, kw[attr])
+
+        mapper(Parent, parents_table, properties={
+            'child': relation(Child, lazy=False,
+                                 backref='parent', uselist=False)})
+        mapper(Child, children_table)
+
+        metadata.create_all()
+        session = create_session()
+
+        def roundtrip(obj):
+            session.save(obj)
+            session.flush()
+            id, type_ = obj.id, type(obj)
+            session.clear()
+            return session.query(type_).get(id)
+            
+        p = Parent('p')
+
+        # No child
+        try:
+            v = p.foo
+            self.fail()
+        except:
+            pass
+
+        p.child = Child(foo='a', bar='b', baz='c')
+
+        self.assert_(p.foo == 'a')
+        self.assert_(p.bar == 'b')
+        self.assert_(p.baz == 'c')
+
+        p.bar = 'x'
+        self.assert_(p.foo == 'a')
+        self.assert_(p.bar == 'x')
+        self.assert_(p.baz == 'c')
+        
+        p = roundtrip(p)
+
+        self.assert_(p.foo == 'a')
+        self.assert_(p.bar == 'x')
+        self.assert_(p.baz == 'c')
+
+        p.child = None
+
+        # No child again
+        try:
+            v = p.foo
+            self.fail()
+        except:
+            pass
+
+        # Bogus creator for this scalar type
+        try:
+            p.foo = 'zzz'
+            self.fail()
+        except TypeError:
+            pass
+
+        p.bar = 'yyy'
+
+        self.assert_(p.foo is None)
+        self.assert_(p.bar == 'yyy')
+        self.assert_(p.baz is None)
+
+        del p.child
+
+        p = roundtrip(p)
+
+        self.assert_(p.child is None)
+
+        p.baz = 'xxx'
+
+        self.assert_(p.foo is None)
+        self.assert_(p.bar is None)
+        self.assert_(p.baz == 'xxx')
+
+        p = roundtrip(p)
+
+        self.assert_(p.foo is None)
+        self.assert_(p.bar is None)
+        self.assert_(p.baz == 'xxx')
+
+if __name__ == "__main__":
+    testbase.main()