See the example ``examples/association/proxied_association.py``.
"""
-from sqlalchemy.orm import class_mapper
+from sqlalchemy.orm.attributes import InstrumentedList
+import sqlalchemy.exceptions as exceptions
+import sqlalchemy.orm as orm
+import sqlalchemy.util as util
+
+def association_proxy(targetcollection, attr, **kw):
+ """Convenience function for use in mapped classes. Implements a Python
+ property representing a relation as a collection of simpler values. The
+ proxied property will mimic the collection type of the target (list, dict
+ or set), or in the case of a one to one relation, a simple scalar value.
+
+ targetcollection
+ Name of the relation attribute we'll proxy to, usually created with
+ 'relation()' in a mapper setup.
+
+ attr
+ Attribute on the associated instances we'll proxy for. For example,
+ given a target collection of [obj1, obj2], a list created by this proxy
+ property would look like
+ [getattr(obj1, attr), getattr(obj2, attr)]
+
+ If the relation is one-to-one or otherwise uselist=False, then simply:
+ getattr(obj, attr)
+
+ creator (optional)
+ When new items are added to this proxied collection, new instances of
+ the class collected by the target collection will be created. For
+ list and set collections, the target class constructor will be called
+ with the 'value' for the new instance. For dict types, two arguments
+ are passed: key and value.
+
+ If you want to construct instances differently, supply a 'creator'
+ function that takes arguments as above and returns instances.
+
+ For scalar relations, creator() will be called if the target is None.
+ If the target is present, set operations are proxied to setattr() on the
+ associated object.
+
+ If you have an associated object with multiple attributes, you may set up
+ multiple association proxies mapping to different attributes. See the
+ unit tests for examples, and for examples of how creator() functions can
+ be used to construct the scalar relation on-demand in this situation.
+
+ Passes along any other arguments to AssociationProxy
+ """
+
+ return AssociationProxy(targetcollection, attr, **kw)
+
class AssociationProxy(object):
- """A property object that automatically sets up ``AssociationLists`` on a parent object."""
+ """A property object that automatically sets up `AssociationLists`
+ on an object."""
- def __init__(self, targetcollection, attr, creator=None):
- """Create a new association property.
+ def __init__(self, targetcollection, attr, creator=None,
+ proxy_factory=None, proxy_bulk_set=None):
+ """Arguments are:
- targetcollection
- The attribute name which stores the collection of Associations.
+ targetcollection
+ Name of the collection we'll proxy to, usually created with
+ 'relation()' in a mapper setup.
- attr
- Name of the attribute on the Association in which to get/set target values.
+ attr
+ Attribute on the collected instances we'll proxy for. For example,
+ given a target collection of [obj1, obj2],
+ a list created by this proxy property would look like
+ [getattr(obj1, attr), getattr(obj2, attr)]
- creator
- Optional callable which is used to create a new association
- object. This callable is given a single argument which is
- an instance of the *proxied* object. If creator is not
- given, the association object is created using the class
- associated with the targetcollection attribute, using its
- ``__init__()`` constructor and setting the proxied
- attribute.
+ creator
+ Optional. When new items are added to this proxied collection, new
+ instances of the class collected by the target collection will be
+ created. For list and set collections, the target class
+ constructor will be called with the 'value' for the new instance.
+ For dict types, two arguments are passed: key and value.
+
+ If you want to construct instances differently, supply a 'creator'
+ function that takes arguments as above and returns instances.
+
+ proxy_factory
+ Optional. The type of collection to emulate is determined by
+ sniffing the target collection. If your collection type can't be
+ determined by duck typing or you'd like to use a different collection
+ implementation, you may supply a factory function to produce those
+ collections. Only applicable to non-scalar relations.
+
+ proxy_bulk_set
+ Optional, use with proxy_factory. See the _set() method for
+ details.
"""
- self.targetcollection = targetcollection
- self.attr = attr
+ self.target_collection = targetcollection # backwards compat name...
+ self.value_attr = attr
self.creator = creator
-
- def __init_deferred(self):
- prop = class_mapper(self._owner_class).props[self.targetcollection]
- self._cls = prop.mapper.class_
- self._uselist = prop.uselist
+ self.proxy_factory = proxy_factory
+ self.proxy_bulk_set = proxy_bulk_set
- def _get_class(self):
- try:
- return self._cls
- except AttributeError:
- self.__init_deferred()
- return self._cls
+ self.scalar = None
+ self.owning_class = None
+ self.key = '_%s_%s_%s' % (type(self).__name__,
+ targetcollection, id(self))
+ self.collection_class = None
- def _get_uselist(self):
- try:
- return self._uselist
- except AttributeError:
- self.__init_deferred()
- return self._uselist
+ def _get_property(self):
+ return orm.class_mapper(self.owning_class).props[self.target_collection]
- cls = property(_get_class)
- uselist = property(_get_uselist)
+ def _target_class(self):
+ return self._get_property().mapper.class_
+ target_class = property(_target_class)
- def create(self, target, **kw):
- if self.creator is not None:
- return self.creator(target, **kw)
- else:
- assoc = self.cls(**kw)
- setattr(assoc, self.attr, target)
- return assoc
-
- def __get__(self, obj, owner):
- self._owner_class = owner
+
+ def __get__(self, obj, class_):
if obj is None:
- return self
- storage_key = '_AssociationProxy_%s_%s' % (self.targetcollection, self.attr)
- if self.uselist:
+ self.owning_class = class_
+ return
+ elif self.scalar is None:
+ self.scalar = not self._get_property().uselist
+
+ if self.scalar:
+ return getattr(getattr(obj, self.target_collection), self.value_attr)
+ else:
try:
- return getattr(obj, storage_key)
+ return getattr(obj, self.key)
except AttributeError:
- a = _AssociationList(self, obj)
- setattr(obj, storage_key, a)
- return a
+ proxy = self._new(getattr(obj, self.target_collection))
+ setattr(obj, self.key, proxy)
+ return proxy
+
+ def __set__(self, obj, values):
+ if self.scalar:
+ creator = self.creator and self.creator or self.target_class
+ target = getattr(obj, self.target_collection)
+ if target is None:
+ setattr(obj, self.target_collection, creator(values))
+ else:
+ setattr(target, self.value_attr, values)
+ else:
+ proxy = self.__get__(obj, None)
+ proxy.clear()
+ self._set(proxy, values)
+
+ def __delete__(self, obj):
+ delattr(obj, self.key)
+
+ def _new(self, collection):
+ creator = self.creator and self.creator or self.target_class
+
+ # Prefer class typing here to spot dicts with the required append()
+ # method.
+ if isinstance(collection.data, dict):
+ self.collection_class = dict
else:
- return getattr(getattr(obj, self.targetcollection), self.attr)
+ self.collection_class = util.duck_type_collection(collection.data)
+
+ if self.proxy_factory:
+ return self.proxy_factory(collection, creator, self.value_attr)
- def __set__(self, obj, value):
- if self.uselist:
- setattr(obj, self.targetcollection, [self.create(x) for x in value])
+ value_attr = self.value_attr
+ getter = lambda o: getattr(o, value_attr)
+ setter = lambda o, v: setattr(o, value_attr, v)
+
+ if self.collection_class is list:
+ return _AssociationList(collection, creator, getter, setter)
+ elif self.collection_class is dict:
+ kv_setter = lambda o, k, v: setattr(o, value_attr, v)
+ return _AssociationDict(collection, creator, getter, setter)
+ elif self.collection_class is util.Set:
+ return _AssociationSet(collection, creator, getter, setter)
else:
- setattr(obj, self.targetcollection, self.create(value))
+ raise exceptions.ArgumentError(
+ 'could not guess which interface to use for '
+ 'collection_class "%s" backing "%s"; specify a '
+ 'proxy_factory and proxy_bulk_set manually' %
+ (self.collection_class.__name__, self.target_collection))
- def __del__(self, obj):
- delattr(obj, self.targetcollection)
+ def _set(self, proxy, values):
+ if self.proxy_bulk_set:
+ self.proxy_bulk_set(proxy, values)
+ elif self.collection_class is list:
+ proxy.extend(values)
+ elif self.collection_class is dict:
+ proxy.update(values)
+ elif self.collection_class is util.Set:
+ proxy.update(values)
+ else:
+ raise exceptions.ArgumentError(
+ 'no proxy_bulk_set supplied for custom '
+ 'collection_class implementation')
class _AssociationList(object):
- """Generic proxying list which proxies list operations to a
- different list-holding attribute of the parent object, converting
- Association objects to and from a target attribute on each
- Association object.
+ """Generic proxying list which proxies list operations to a another list,
+ converting association objects to and from a simplified value.
"""
- def __init__(self, proxy, parent):
- """Create a new ``AssociationList``."""
- self.proxy = proxy
- self.parent = parent
+ def __init__(self, collection, creator, getter, setter):
+ """
+ collection
+ A list-based collection of entities (usually an object attribute
+ managed by a SQLAlchemy relation())
+
+ creator
+ A function that creates new target entities. Given one parameter:
+ value. The assertion is assumed:
+ obj = creator(somevalue)
+ assert getter(obj) == somevalue
+
+ getter
+ A function. Given an associated object, return the 'value'.
+
+ setter
+ A function. Given an associated object and a value, store
+ that value on the object.
+ """
+
+ self.col = collection
+ self.creator = creator
+ self.getter = getter
+ self.setter = setter
+
+ # For compatibility with 0.3.1 through 0.3.7- pass kw through to creator.
+ # (see append() below)
+ def _create(self, value, **kw):
+ return self.creator(value, **kw)
+
+ def _get(self, object):
+ return self.getter(object)
+
+ def _set(self, object, value):
+ return self.setter(object, value)
- def append(self, item, **kw):
- a = self.proxy.create(item, **kw)
- getattr(self.parent, self.proxy.targetcollection).append(a)
+ def __len__(self):
+ return len(self.col)
+
+ def __nonzero__(self):
+ return True if self.col else False
+
+ def __getitem__(self, index):
+ return self._get(self.col[index])
+
+ def __setitem__(self, index, value):
+ self._set(self.col[index], value)
+
+ def __delitem__(self, index):
+ del self.col[index]
+
+ def __contains__(self, value):
+ for member in self.col:
+ if self._get(member) == value:
+ return True
+ return False
+
+ def __getslice__(self, start, end):
+ return [self._get(member) for member in self.col[start:end]]
+
+ def __setslice__(self, start, end, values):
+ members = [self._create(v) for v in values]
+ self.col[start:end] = members
+
+ def __delslice__(self, start, end):
+ del self.col[start:end]
def __iter__(self):
- return iter([getattr(x, self.proxy.attr) for x in getattr(self.parent, self.proxy.targetcollection)])
+ """Iterate over proxied values. For the actual domain objects,
+ iterate over .col instead or just use the underlying collection
+ directly from its property on the parent."""
+ for member in self.col:
+ yield self._get(member)
+ raise StopIteration
+
+ # For compatibility with 0.3.1 through 0.3.7- pass kw through to creator
+ # on append() only. (Can't on __setitem__, __contains__, etc., obviously.)
+ def append(self, value, **kw):
+ item = self._create(value, **kw)
+ self.col.append(item)
+
+ def extend(self, values):
+ for v in values:
+ self.append(v)
+
+ def insert(self, index, value):
+ self.col[index:index] = [self._create(value)]
+
+ def clear(self):
+ del self.col[0:len(self.col)]
+
+ def __eq__(self, other): return list(self) == other
+ def __ne__(self, other): return list(self) != other
+ def __lt__(self, other): return list(self) < other
+ def __le__(self, other): return list(self) <= other
+ def __gt__(self, other): return list(self) > other
+ def __ge__(self, other): return list(self) >= other
+ def __cmp__(self, other): return cmp(list(self), other)
+
+ def copy(self):
+ return list(self)
def __repr__(self):
- return repr([getattr(x, self.proxy.attr) for x in getattr(self.parent, self.proxy.targetcollection)])
+ return repr(list(self))
+
+ def hash(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+_NotProvided = object()
+class _AssociationDict(object):
+ """Generic proxying list which proxies dict operations to a another dict,
+ converting association objects to and from a simplified value.
+ """
+
+ def __init__(self, collection, creator, getter, setter):
+ """
+ collection
+ A list-based collection of entities (usually an object attribute
+ managed by a SQLAlchemy relation())
+
+ creator
+ A function that creates new target entities. Given two parameters:
+ key and value. The assertion is assumed:
+ obj = creator(somekey, somevalue)
+ assert getter(somekey) == somevalue
+
+ getter
+ A function. Given an associated object and a key, return the 'value'.
+
+ setter
+ A function. Given an associated object, a key and a value, store
+ that value on the object.
+ """
+
+ self.col = collection
+ self.creator = creator
+ self.getter = getter
+ self.setter = setter
+
+ def _create(self, key, value):
+ return self.creator(key, value)
+
+ def _get(self, object):
+ return self.getter(object)
+
+ def _set(self, object, key, value):
+ return self.setter(object, key, value)
def __len__(self):
- return len(getattr(self.parent, self.proxy.targetcollection))
+ return len(self.col)
- def __getitem__(self, index):
- return getattr(getattr(self.parent, self.proxy.targetcollection)[index], self.proxy.attr)
+ def __nonzero__(self):
+ return True if self.col else False
- def __setitem__(self, index, value):
- a = self.proxy.create(item)
- getattr(self.parent, self.proxy.targetcollection)[index] = a
+ def __getitem__(self, key):
+ return self._get(self.col[key])
+
+ def __setitem__(self, key, value):
+ if key in self.col:
+ self._set(self.col[key], key, value)
+ else:
+ self.col[key] = self._create(key, value)
+
+ def __delitem__(self, key):
+ del self.col[key]
+
+ def __contains__(self, key):
+ return key in self.col
+ has_key = __contains__
+
+ def __iter__(self):
+ return iter(self.col)
+
+ def clear(self):
+ self.col.clear()
+
+ def __eq__(self, other): return dict(self) == other
+ def __ne__(self, other): return dict(self) != other
+ def __lt__(self, other): return dict(self) < other
+ def __le__(self, other): return dict(self) <= other
+ def __gt__(self, other): return dict(self) > other
+ def __ge__(self, other): return dict(self) >= other
+ def __cmp__(self, other): return cmp(dict(self), other)
+
+ def __repr__(self):
+ return repr(dict(self.items()))
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def setdefault(self, key, default=None):
+ if key not in self.col:
+ self.col[key] = self._create(key, default)
+ return default
+ else:
+ return self[key]
+
+ def keys(self):
+ return self.col.keys()
+ def iterkeys(self):
+ return self.col.iterkeys()
+
+ def values(self):
+ return [ self._get(member) for member in self.col.values() ]
+ def itervalues(self):
+ for key in self.col:
+ yield self._get(self.col[key])
+ raise StopIteration
+
+ def items(self):
+ return [(k, self._get(self.col[k])) for k in self]
+ def iteritems(self):
+ for key in self.col:
+ yield (key, self._get(self.col[key]))
+ raise StopIteration
+
+ def pop(self, key, default=_NotProvided):
+ if default is _NotProvided:
+ member = self.col.pop(key)
+ else:
+ member = self.col.pop(key, default)
+ return self._get(member)
+
+ def popitem(self):
+ item = self.col.popitem()
+ return (item[0], self._get(item[1]))
+
+ def update(self, *a, **kw):
+ if len(a) > 1:
+ raise TypeError('update expected at most 1 arguments, got %i' %
+ len(a))
+ elif len(a) == 1:
+ seq_or_map = a[0]
+ for item in seq_or_map:
+ if isinstance(item, tuple):
+ self[item[0]] = item[1]
+ else:
+ self[item] = seq_or_map[item]
+
+ for key, value in kw:
+ self[key] = value
+
+ def copy(self):
+ return dict(self.items())
+
+ def hash(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
+
+class _AssociationSet(object):
+ """Generic proxying list which proxies set operations to a another set,
+ converting association objects to and from a simplified value.
+ """
+
+ def __init__(self, collection, creator, getter, setter):
+ """
+ collection
+ A list-based collection of entities (usually an object attribute
+ managed by a SQLAlchemy relation())
+
+ creator
+ A function that creates new target entities. Given one parameter:
+ value. The assertion is assumed:
+ obj = creator(somevalue)
+ assert getter(obj) == somevalue
+
+ getter
+ A function. Given an associated object, return the 'value'.
+
+ setter
+ A function. Given an associated object and a value, store
+ that value on the object.
+ """
+
+ self.col = collection
+ self.creator = creator
+ self.getter = getter
+ self.setter = setter
+
+ def _create(self, value):
+ return self.creator(value)
+
+ def _get(self, object):
+ return self.getter(object)
+
+ def _set(self, object, value):
+ return self.setter(object, value)
+
+ def __len__(self):
+ return len(self.col)
+
+ def __nonzero__(self):
+ return True if self.col else False
+
+ def __contains__(self, value):
+ for member in self.col:
+ if self._get(member) == value:
+ return True
+ return False
+
+ def __iter__(self):
+ """Iterate over proxied values. For the actual domain objects,
+ iterate over .col instead or just use the underlying collection
+ directly from its property on the parent."""
+ for member in self.col:
+ yield self._get(member)
+ raise StopIteration
+
+ def add(self, value):
+ if value not in self:
+ # must shove this through InstrumentedList.append() which will
+ # eventually call the collection_class .add()
+ self.col.append(self._create(value))
+
+ # for discard and remove, choosing a more expensive check strategy rather
+ # than call self.creator()
+ def discard(self, value):
+ for member in self.col:
+ if self._get(member) == value:
+ self.col.discard(member)
+ break
+
+ def remove(self, value):
+ for member in self.col:
+ if self._get(member) == value:
+ self.col.discard(member)
+ return
+ raise KeyError(value)
+
+ def pop(self):
+ if not self.col:
+ raise KeyError('pop from an empty set')
+ # grumble, pop() is borked on InstrumentedList (#548)
+ if isinstance(self.col, InstrumentedList):
+ member = list(self.col)[0]
+ self.col.remove(member)
+ else:
+ member = self.col.pop()
+ return self._get(member)
+
+ def update(self, other):
+ for value in other:
+ self.add(value)
+
+ __ior__ = update
+
+ def _set(self):
+ return util.Set(iter(self))
+
+ def union(self, other):
+ return util.Set(self).union(other)
+
+ __or__ = union
+
+ def difference(self, other):
+ return util.Set(self).difference(other)
+
+ __sub__ = difference
+
+ def difference_update(self, other):
+ for value in other:
+ self.discard(value)
+
+ __isub__ = difference_update
+
+ def intersection(self, other):
+ return util.Set(self).intersection(other)
+
+ __and__ = intersection
+
+ def intersection_update(self, other):
+ want, have = self.intersection(other), util.Set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+
+ __iand__ = intersection_update
+
+ def symmetric_difference(self, other):
+ return util.Set(self).symmetric_difference(other)
+
+ __xor__ = symmetric_difference
+
+ def symmetric_difference_update(self, other):
+ want, have = self.symmetric_difference(other), util.Set(self)
+
+ remove, add = have - want, want - have
+
+ for value in remove:
+ self.remove(value)
+ for value in add:
+ self.add(value)
+
+ __ixor__ = symmetric_difference_update
+
+ def issubset(self, other):
+ return util.Set(self).issubset(other)
+
+ def issuperset(self, other):
+ return util.Set(self).issuperset(other)
+
+ def clear(self):
+ self.col.clear()
+
+ def copy(self):
+ return util.Set(self)
+
+ def __eq__(self, other): return util.Set(self) == other
+ def __ne__(self, other): return util.Set(self) != other
+ def __lt__(self, other): return util.Set(self) < other
+ def __le__(self, other): return util.Set(self) <= other
+ def __gt__(self, other): return util.Set(self) > other
+ def __ge__(self, other): return util.Set(self) >= other
+
+ def __repr__(self):
+ return repr(util.Set(self))
+
+ def hash(self):
+ raise TypeError("%s objects are unhashable" % type(self).__name__)
--- /dev/null
+from testbase import PersistTest
+import sqlalchemy.util as util
+import unittest
+import testbase
+from sqlalchemy import *
+from sqlalchemy.ext.associationproxy import *
+
+db = testbase.db
+
+class DictCollection(dict):
+ def append(self, obj):
+ self[obj.foo] = obj
+ def __iter__(self):
+ return self.itervalues()
+
+class SetCollection(set):
+ pass
+
+class ListCollection(list):
+ pass
+
+class ObjectCollection(object):
+ def __init__(self):
+ self.values = list()
+ def append(self, obj):
+ self.values.append(obj)
+ def __iter__(self):
+ return iter(self.values)
+ def clear(self):
+ self.values.clear()
+
+class _CollectionOperations(PersistTest):
+ def setUp(self):
+ collection_class = self.collection_class
+
+ metadata = BoundMetaData(db)
+
+ parents_table = Table('Parent', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String))
+ children_table = Table('Children', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('parent_id', Integer,
+ ForeignKey('Parent.id')),
+ Column('foo', String),
+ Column('name', String))
+
+ class Parent(object):
+ children = association_proxy('_children', 'name')
+
+ def __init__(self, name):
+ self.name = name
+
+ class Child(object):
+ if collection_class and issubclass(collection_class, dict):
+ def __init__(self, foo, name):
+ self.foo = foo
+ self.name = name
+ else:
+ def __init__(self, name):
+ self.name = name
+
+ mapper(Parent, parents_table, properties={
+ '_children': relation(Child, lazy=False,
+ collection_class=collection_class)})
+ mapper(Child, children_table)
+
+ metadata.create_all()
+
+ self.metadata = metadata
+ self.session = create_session()
+ self.Parent, self.Child = Parent, Child
+
+ def tearDown(self):
+ self.metadata.drop_all()
+
+ def roundtrip(self, obj):
+ self.session.save(obj)
+ self.session.flush()
+ id, type_ = obj.id, type(obj)
+ self.session.clear()
+ return self.session.query(type_).get(id)
+
+ def _test_sequence_ops(self):
+ Parent, Child = self.Parent, self.Child
+
+ p1 = Parent('P1')
+
+ self.assert_(not p1._children)
+ self.assert_(not p1.children)
+
+ ch = Child('regular')
+ p1._children.append(ch)
+
+ self.assert_(ch in p1._children)
+ self.assert_(len(p1._children) == 1)
+
+ self.assert_(p1.children)
+ self.assert_(len(p1.children) == 1)
+ self.assert_(ch not in p1.children)
+ self.assert_('regular' in p1.children)
+
+ p1.children.append('proxied')
+
+ self.assert_('proxied' in p1.children)
+ self.assert_('proxied' not in p1._children)
+ self.assert_(len(p1.children) == 2)
+ self.assert_(len(p1._children) == 2)
+
+ self.assert_(p1._children[0].name == 'regular')
+ self.assert_(p1._children[1].name == 'proxied')
+
+ del p1._children[1]
+
+ self.assert_(len(p1._children) == 1)
+ self.assert_(len(p1.children) == 1)
+ self.assert_(p1._children[0] == ch)
+
+ del p1.children[0]
+
+ self.assert_(len(p1._children) == 0)
+ self.assert_(len(p1.children) == 0)
+
+ p1.children = ['a','b','c']
+ self.assert_(len(p1._children) == 3)
+ self.assert_(len(p1.children) == 3)
+
+ del ch
+ p1 = self.roundtrip(p1)
+
+ self.assert_(len(p1._children) == 3)
+ self.assert_(len(p1.children) == 3)
+
+class DefaultTest(_CollectionOperations):
+ def __init__(self, *args, **kw):
+ super(DefaultTest, self).__init__(*args, **kw)
+ self.collection_class = None
+
+ def test_sequence_ops(self):
+ self._test_sequence_ops()
+
+class ListTest(_CollectionOperations):
+ def __init__(self, *args, **kw):
+ super(ListTest, self).__init__(*args, **kw)
+ self.collection_class = list
+
+ def test_sequence_ops(self):
+ self._test_sequence_ops()
+
+class CustomListTest(ListTest):
+ def __init__(self, *args, **kw):
+ super(CustomListTest, self).__init__(*args, **kw)
+ self.collection_class = list
+
+# No-can-do until ticket #213
+class DictTest(_CollectionOperations):
+ pass
+
+class CustomDictTest(DictTest):
+ def __init__(self, *args, **kw):
+ super(DictTest, self).__init__(*args, **kw)
+ self.collection_class = DictCollection
+
+ def test_mapping_ops(self):
+ Parent, Child = self.Parent, self.Child
+
+ p1 = Parent('P1')
+
+ self.assert_(not p1._children)
+ self.assert_(not p1.children)
+
+ ch = Child('a', 'regular')
+ p1._children.append(ch)
+
+ print repr(p1._children)
+ self.assert_(ch in p1._children.values())
+ self.assert_(len(p1._children) == 1)
+
+ self.assert_(p1.children)
+ self.assert_(len(p1.children) == 1)
+ self.assert_(ch not in p1.children)
+ self.assert_('a' in p1.children)
+ self.assert_(p1.children['a'] == 'regular')
+ self.assert_(p1._children['a'] == ch)
+
+ p1.children['b'] = 'proxied'
+
+ self.assert_('proxied' in p1.children.values())
+ self.assert_('b' in p1.children)
+ self.assert_('proxied' not in p1._children)
+ self.assert_(len(p1.children) == 2)
+ self.assert_(len(p1._children) == 2)
+
+ self.assert_(p1._children['a'].name == 'regular')
+ self.assert_(p1._children['b'].name == 'proxied')
+
+ del p1._children['b']
+
+ self.assert_(len(p1._children) == 1)
+ self.assert_(len(p1.children) == 1)
+ self.assert_(p1._children['a'] == ch)
+
+ del p1.children['a']
+
+ self.assert_(len(p1._children) == 0)
+ self.assert_(len(p1.children) == 0)
+
+ p1.children = {'d': 'v d', 'e': 'v e', 'f': 'v f'}
+ self.assert_(len(p1._children) == 3)
+ self.assert_(len(p1.children) == 3)
+
+ del ch
+ p1 = self.roundtrip(p1)
+ self.assert_(len(p1._children) == 3)
+ self.assert_(len(p1.children) == 3)
+
+
+class SetTest(_CollectionOperations):
+ def __init__(self, *args, **kw):
+ super(SetTest, self).__init__(*args, **kw)
+ self.collection_class = set
+
+ def test_set_operations(self):
+ Parent, Child = self.Parent, self.Child
+
+ p1 = Parent('P1')
+
+ self.assert_(not p1._children)
+ self.assert_(not p1.children)
+
+ ch1 = Child('regular')
+ p1._children.append(ch1)
+
+ self.assert_(ch1 in p1._children)
+ self.assert_(len(p1._children) == 1)
+
+ self.assert_(p1.children)
+ self.assert_(len(p1.children) == 1)
+ self.assert_(ch1 not in p1.children)
+ self.assert_('regular' in p1.children)
+
+ p1.children.add('proxied')
+
+ self.assert_('proxied' in p1.children)
+ self.assert_('proxied' not in p1._children)
+ self.assert_(len(p1.children) == 2)
+ self.assert_(len(p1._children) == 2)
+
+ self.assert_(set([o.name for o in p1._children]) == set(['regular', 'proxied']))
+
+ ch2 = None
+ for o in p1._children:
+ if o.name == 'proxied':
+ ch2 = o
+ break
+
+ p1._children.remove(ch2)
+
+ self.assert_(len(p1._children) == 1)
+ self.assert_(len(p1.children) == 1)
+ self.assert_(p1._children == set([ch1]))
+
+ p1.children.remove('regular')
+
+ self.assert_(len(p1._children) == 0)
+ self.assert_(len(p1.children) == 0)
+
+ p1.children = ['a','b','c']
+ self.assert_(len(p1._children) == 3)
+ self.assert_(len(p1.children) == 3)
+
+ del ch1
+ p1 = self.roundtrip(p1)
+
+ self.assert_(len(p1._children) == 3)
+ self.assert_(len(p1.children) == 3)
+
+ self.assert_('a' in p1.children)
+ self.assert_('b' in p1.children)
+ self.assert_('d' not in p1.children)
+
+ self.assert_(p1.children == set(['a','b','c']))
+
+ try:
+ p1.children.remove('d')
+ self.fail()
+ except KeyError:
+ pass
+
+ self.assert_(len(p1.children) == 3)
+ p1.children.discard('d')
+ self.assert_(len(p1.children) == 3)
+ p1 = self.roundtrip(p1)
+ self.assert_(len(p1.children) == 3)
+
+ popped = p1.children.pop()
+ self.assert_(len(p1.children) == 2)
+ self.assert_(popped not in p1.children)
+ p1 = self.roundtrip(p1)
+ self.assert_(len(p1.children) == 2)
+ self.assert_(popped not in p1.children)
+
+ p1.children = ['a','b','c']
+ p1 = self.roundtrip(p1)
+ self.assert_(p1.children == set(['a','b','c']))
+
+ p1.children.discard('b')
+ p1 = self.roundtrip(p1)
+ self.assert_(p1.children == set(['a', 'c']))
+
+ p1.children.remove('a')
+ p1 = self.roundtrip(p1)
+ self.assert_(p1.children == set(['c']))
+
+ def test_set_comparisons(self):
+ Parent, Child = self.Parent, self.Child
+
+ p1 = Parent('P1')
+ p1.children = ['a','b','c']
+ control = set(['a','b','c'])
+
+ for other in (set(['a','b','c']), set(['a','b','c','d']),
+ set(['a']), set(['a','b']),
+ set(['c','d']), set(['e', 'f', 'g']),
+ set()):
+
+ self.assertEqual(p1.children.union(other),
+ control.union(other))
+ self.assertEqual(p1.children.difference(other),
+ control.difference(other))
+ self.assertEqual((p1.children - other),
+ (control - other))
+ self.assertEqual(p1.children.intersection(other),
+ control.intersection(other))
+ self.assertEqual(p1.children.symmetric_difference(other),
+ control.symmetric_difference(other))
+ self.assertEqual(p1.children.issubset(other),
+ control.issubset(other))
+ self.assertEqual(p1.children.issuperset(other),
+ control.issuperset(other))
+
+ self.assert_((p1.children == other) == (control == other))
+ self.assert_((p1.children != other) == (control != other))
+ self.assert_((p1.children < other) == (control < other))
+ self.assert_((p1.children <= other) == (control <= other))
+ self.assert_((p1.children > other) == (control > other))
+ self.assert_((p1.children >= other) == (control >= other))
+
+ def test_set_mutation(self):
+ Parent, Child = self.Parent, self.Child
+
+ # mutations
+ for op in ('update', 'intersection_update',
+ 'difference_update', 'symmetric_difference_update'):
+ for base in (['a', 'b', 'c'], []):
+ for other in (set(['a','b','c']), set(['a','b','c','d']),
+ set(['a']), set(['a','b']),
+ set(['c','d']), set(['e', 'f', 'g']),
+ set()):
+ p = Parent('p')
+ p.children = base[:]
+ control = set(base[:])
+
+ getattr(p.children, op)(other)
+ getattr(control, op)(other)
+ try:
+ self.assert_(p.children == control)
+ except:
+ print 'Test %s.%s(%s):' % (set(base), op, other)
+ print 'want', repr(control)
+ print 'got', repr(p.children)
+ raise
+
+ p = self.roundtrip(p)
+
+ try:
+ self.assert_(p.children == control)
+ except:
+ print 'Test %s.%s(%s):' % (base, op, other)
+ print 'want', repr(control)
+ print 'got', repr(p.children)
+ raise
+
+ # workaround for bug #548
+ def test_set_pop(self):
+ Parent, Child = self.Parent, self.Child
+ p = Parent('p1')
+ p.children.add('a')
+ p.children.pop()
+ self.assert_(True)
+
+class CustomSetTest(SetTest):
+ def __init__(self, *args, **kw):
+ super(CustomSetTest, self).__init__(*args, **kw)
+ self.collection_class = SetCollection
+
+class CustomObjectTest(_CollectionOperations):
+ def __init__(self, *args, **kw):
+ super(CustomObjectTest, self).__init__(*args, **kw)
+ self.collection_class = ObjectCollection
+
+ def test_basic(self):
+ Parent, Child = self.Parent, self.Child
+
+ p = Parent('p1')
+ self.assert_(len(list(p.children)) == 0)
+
+ p.children.append('child')
+ self.assert_(len(list(p.children)) == 1)
+
+ p = self.roundtrip(p)
+ self.assert_(len(list(p.children)) == 1)
+
+ # We didn't provide an alternate _AssociationList implementation for
+ # our ObjectCollection, so indexing will fail.
+ try:
+ v = p.children[1]
+ self.fail()
+ except TypeError:
+ pass
+
+class ScalarTest(PersistTest):
+ def test_scalar_proxy(self):
+ metadata = BoundMetaData(db)
+
+ parents_table = Table('Parent', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('name', String))
+ children_table = Table('Children', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('parent_id', Integer,
+ ForeignKey('Parent.id')),
+ Column('foo', String),
+ Column('bar', String),
+ Column('baz', String))
+
+ class Parent(object):
+ foo = association_proxy('child', 'foo')
+ bar = association_proxy('child', 'bar',
+ creator=lambda v: Child(bar=v))
+ baz = association_proxy('child', 'baz',
+ creator=lambda v: Child(baz=v))
+
+ def __init__(self, name):
+ self.name = name
+
+ class Child(object):
+ def __init__(self, **kw):
+ for attr in kw:
+ setattr(self, attr, kw[attr])
+
+ mapper(Parent, parents_table, properties={
+ 'child': relation(Child, lazy=False,
+ backref='parent', uselist=False)})
+ mapper(Child, children_table)
+
+ metadata.create_all()
+ session = create_session()
+
+ def roundtrip(obj):
+ session.save(obj)
+ session.flush()
+ id, type_ = obj.id, type(obj)
+ session.clear()
+ return session.query(type_).get(id)
+
+ p = Parent('p')
+
+ # No child
+ try:
+ v = p.foo
+ self.fail()
+ except:
+ pass
+
+ p.child = Child(foo='a', bar='b', baz='c')
+
+ self.assert_(p.foo == 'a')
+ self.assert_(p.bar == 'b')
+ self.assert_(p.baz == 'c')
+
+ p.bar = 'x'
+ self.assert_(p.foo == 'a')
+ self.assert_(p.bar == 'x')
+ self.assert_(p.baz == 'c')
+
+ p = roundtrip(p)
+
+ self.assert_(p.foo == 'a')
+ self.assert_(p.bar == 'x')
+ self.assert_(p.baz == 'c')
+
+ p.child = None
+
+ # No child again
+ try:
+ v = p.foo
+ self.fail()
+ except:
+ pass
+
+ # Bogus creator for this scalar type
+ try:
+ p.foo = 'zzz'
+ self.fail()
+ except TypeError:
+ pass
+
+ p.bar = 'yyy'
+
+ self.assert_(p.foo is None)
+ self.assert_(p.bar == 'yyy')
+ self.assert_(p.baz is None)
+
+ del p.child
+
+ p = roundtrip(p)
+
+ self.assert_(p.child is None)
+
+ p.baz = 'xxx'
+
+ self.assert_(p.foo is None)
+ self.assert_(p.bar is None)
+ self.assert_(p.baz == 'xxx')
+
+ p = roundtrip(p)
+
+ self.assert_(p.foo is None)
+ self.assert_(p.bar is None)
+ self.assert_(p.baz == 'xxx')
+
+if __name__ == "__main__":
+ testbase.main()