From: Jason Kirtland Date: Wed, 27 Jun 2007 21:08:14 +0000 (+0000) Subject: - Replaced collection api: The "InstrumentedList" proxy is replaced with X-Git-Tag: rel_0_4_6~171 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7b87fcecd6652187fa789066c20b67a5154f44e1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - Replaced collection api: The "InstrumentedList" proxy is replaced with a proxy-free, decorator-based approach for user-space instrumentation and a "view" adapter for interaction with the user's collection within the orm. Fixes [ticket:213], [ticket:548], [ticket:563]. - This needs many more unit tests. There is significant indirect coverage through association proxy, but direct tests are needed, specifically in the decorators and add/remove event firing. - Collections are now instrumented via decorations rather than proxying. You can now have collections that manage their own membership, and your class instance will be directly exposed on the relation property. The changes are transparent for most users. - InstrumentedList (as it was) is removed, and relation properties no longer have 'clear()', '.data', or any other added methods beyond those provided by the collection type. You are free, of course, to add them to a custom class. - __setitem__-like assignments now fire remove events for the existing value, if any. - dict-likes used as collection classes no longer need to change __iter__ semantics- itervalues() is used by default instead. This is a backwards incompatible change. - subclassing dict for a mapped collection is no longer needed in most cases. orm.collections provides canned implementations that key objects by a specified column or a custom function of your choice. - collection assignment now requires a compatible type- assigning None to clear a collection or assinging a list to a dict collection will now raise an argument error. - AttributeExtension moved to interfaces, and .delete is now .remove The event method signature has also been swapped around. --- diff --git a/CHANGES b/CHANGES index 1e9386551d..68e7b2d5ef 100644 --- a/CHANGES +++ b/CHANGES @@ -1,5 +1,27 @@ 0.4.0 - orm + - new collection_class api and implementation [ticket:213] + collections are now instrumented via decorations rather than + proxying. you can now have collections that manage their own + membership, and your class instance will be directly exposed on the + relation property. the changes are transparent for most users. + - InstrumentedList (as it was) is removed, and relation properties no + longer have 'clear()', '.data', or any other added methods beyond those + provided by the collection type. you are free, of course, to add them + to a custom class. + - __setitem__-like assignments now fire remove events for the existing + value, if any. + - dict-likes used as collection classes no longer need to change __iter__ + semantics- itervalues() is used by default instead. this is a backwards + incompatible change. + - subclassing dict for a mapped collection is no longer needed in most cases. + orm.collections provides canned implementations that key objects by a + specified column or a custom function of your choice. + - collection assignment now requires a compatible type- assigning None + to clear a collection or assinging a list to a dict collection will now + raise an argument error. + - AttributeExtension moved to interfaces, and .delete is now .remove + The event method signature has also been swapped around. - major interface pare-down for Query: all selectXXX methods are deprecated. generative methods are now the standard way to do things, i.e. filter(), filter_by(), all(), one(), diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index cdb8147027..c55b547611 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -6,10 +6,10 @@ transparent proxied access to the endpoint of an association object. See the example ``examples/association/proxied_association.py``. """ -from sqlalchemy.orm.attributes import InstrumentedList import sqlalchemy.exceptions as exceptions import sqlalchemy.orm as orm import sqlalchemy.util as util +from sqlalchemy.orm import collections import weakref def association_proxy(targetcollection, attr, **kw): @@ -168,15 +168,7 @@ class AssociationProxy(object): def _new(self, lazy_collection): creator = self.creator and self.creator or self.target_class - - # Prefer class typing here to spot dicts with the required append() - # method. - collection = lazy_collection() - if isinstance(collection.data, dict): - self.collection_class = dict - else: - self.collection_class = util.duck_type_collection(collection.data) - del collection + self.collection_class = util.duck_type_collection(lazy_collection()) if self.proxy_factory: return self.proxy_factory(lazy_collection, creator, self.value_attr) @@ -545,9 +537,7 @@ class _AssociationSet(object): def add(self, value): if value not in self: - # must shove this through InstrumentedList.append() which will - # eventually call the collection_class .add() - self.col.append(self._create(value)) + self.col.add(self._create(value)) # for discard and remove, choosing a more expensive check strategy rather # than call self.creator() @@ -567,12 +557,7 @@ class _AssociationSet(object): def pop(self): if not self.col: raise KeyError('pop from an empty set') - # grumble, pop() is borked on InstrumentedList (#548) - if isinstance(self.col, InstrumentedList): - member = list(self.col)[0] - self.col.remove(member) - else: - member = self.col.pop() + member = self.col.pop() return self._get(member) def update(self, other): diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index f64360766f..64e2499393 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -5,37 +5,25 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php from sqlalchemy import util -from sqlalchemy.orm import util as orm_util +from sqlalchemy.orm import util as orm_util, interfaces, collections +from sqlalchemy.orm.mapper import class_mapper from sqlalchemy import logging, exceptions import weakref -class InstrumentedAttribute(object): - """A property object that instruments attribute access on object instances. - All methods correspond to a single attribute on a particular - class. - """ +PASSIVE_NORESULT = object() +ATTR_WAS_SET = object() - PASSIVE_NORESULT = object() - ATTR_WAS_SET = object() - - def __init__(self, manager, key, uselist, callable_, typecallable, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs): +class InstrumentedAttribute(object): + def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, compare_function=None, mutable_scalars=False, **kwargs): + self.class_ = class_ self.manager = manager self.key = key - self.uselist = uselist self.callable_ = callable_ - self.typecallable= typecallable self.trackparent = trackparent self.mutable_scalars = mutable_scalars - if copy_function is None: - if uselist: - self.copy = lambda x:[y for y in x] - else: - # scalar values are assumed to be immutable unless a copy function - # is passed - self.copy = lambda x:x - else: - self.copy = lambda x:copy_function(x) + + self.copy = None if compare_function is None: self.is_equal = lambda x,y: x == y else: @@ -43,7 +31,7 @@ class InstrumentedAttribute(object): self.extensions = util.to_list(extension or []) def __set__(self, obj, value): - self.set(None, obj, value) + self.set(obj, value, None) def __delete__(self, obj): self.delete(None, obj) @@ -53,21 +41,6 @@ class InstrumentedAttribute(object): return self return self.get(obj) - def get_instrument(cls, obj, key): - return getattr(obj.__class__, key) - get_instrument = classmethod(get_instrument) - - def check_mutable_modified(self, obj): - if self.mutable_scalars: - h = self.get_history(obj, passive=True) - if h is not None and h.is_modified(): - obj._state['modified'] = True - return True - else: - return False - else: - return False - def hasparent(self, item, optimistic=False): """Return the boolean value of a `hasparent` flag attached to the given item. @@ -103,8 +76,8 @@ class InstrumentedAttribute(object): # get the current state. this may trigger a lazy load if # passive is False. - current = self.get(obj, passive=passive, raiseerr=False) - if current is InstrumentedAttribute.PASSIVE_NORESULT: + current = self.get(obj, passive=passive) + if current is PASSIVE_NORESULT: return None return AttributeHistory(self, obj, current, passive=passive) @@ -128,6 +101,14 @@ class InstrumentedAttribute(object): else: obj._state[('callable', self)] = callable_ + def _get_callable(self, obj): + if ('callable', self) in obj._state: + return obj._state[('callable', self)] + elif self.callable_ is not None: + return self.callable_(obj) + else: + return None + def reset(self, obj): """Remove any per-instance callable functions corresponding to this ``InstrumentedAttribute``'s attribute from the given @@ -153,65 +134,21 @@ class InstrumentedAttribute(object): except KeyError: pass - def _get_callable(self, obj): - if ('callable', self) in obj._state: - return obj._state[('callable', self)] - elif self.callable_ is not None: - return self.callable_(obj) - else: - return None - - def _blank_list(self): - if self.typecallable is not None: - return self.typecallable() - else: - return [] + def check_mutable_modified(self, obj): + return False def initialize(self, obj): - """Initialize this attribute on the given object instance. + """Initialize this attribute on the given object instance with an empty value.""" - If this is a list-based attribute, a new, blank list will be - created. if a scalar attribute, the value will be initialized - to None. - """ + obj.__dict__[self.key] = None + return None - if self.uselist: - l = InstrumentedList(self, obj, self._blank_list()) - obj.__dict__[self.key] = l - return l - else: - obj.__dict__[self.key] = None - return None - - def set_committed_value(self, obj, value): - """set an attribute value on the given instance and 'commit' it. - - this indicates that the given value is the "persisted" value, - and history will be logged only if a newly set value is not - equal to this value. - - this is typically used by deferred/lazy attribute loaders - to set object attributes after the initial load. - """ - - state = obj._state - orig = state.get('original', None) - if self.uselist: - value = InstrumentedList(self, obj, value, init=False) - if orig is not None: - orig.commit_attribute(self, obj, value) - # remove per-instance callable, if any - state.pop(('callable', self), None) - obj.__dict__[self.key] = value - return value - - def get(self, obj, passive=False, raiseerr=True): + def get(self, obj, passive=False): """Retrieve a value from the given object. If a callable is assembled on this object's attribute, and passive is False, the callable will be executed and the - resulting value will be set as the new value for this - attribute. + resulting value will be set as the new value for this attribute. """ try: @@ -224,417 +161,289 @@ class InstrumentedAttribute(object): trig = state['trigger'] del state['trigger'] trig() - return self.get(obj, passive=passive, raiseerr=raiseerr) + return self.get(obj, passive=passive) callable_ = self._get_callable(obj) if callable_ is not None: if passive: - return InstrumentedAttribute.PASSIVE_NORESULT - self.logger.debug("Executing lazy callable on %s.%s" % (orm_util.instance_str(obj), self.key)) + return PASSIVE_NORESULT + self.logger.debug("Executing lazy callable on %s.%s" % + (orm_util.instance_str(obj), self.key)) value = callable_() - if value is not InstrumentedAttribute.ATTR_WAS_SET: + if value is not ATTR_WAS_SET: return self.set_committed_value(obj, value) else: return obj.__dict__[self.key] else: - if self.uselist: - # note that we arent raising AttributeErrors, just creating a new - # blank list and setting it. - # this might be a good thing to be changeable by options. - return self.set_committed_value(obj, self._blank_list()) - else: - # note that we arent raising AttributeErrors, just returning None. - # this might be a good thing to be changeable by options. - value = None - return value + # Return a new, empty value + return self.initialize(obj) - def set(self, event, obj, value): - """Set a value on the given object. + def append(self, obj, value, initiator): + self.set(obj, value, initiator) - `event` is the ``InstrumentedAttribute`` that initiated the - ``set()` operation and is used to control the depth of a - circular setter operation. - """ + def remove(self, obj, value, initiator): + self.set(obj, None, initiator) - if event is not self: - state = obj._state - # if an instance-wide "trigger" was set, call that - if 'trigger' in state: - trig = state['trigger'] - del state['trigger'] - trig() - if self.uselist: - value = InstrumentedList(self, obj, value) - old = self.get(obj) - obj.__dict__[self.key] = value - state['modified'] = True - if not self.uselist: - if self.trackparent: - if value is not None: - self.sethasparent(value, True) - if old is not None: - self.sethasparent(old, False) - for ext in self.extensions: - ext.set(event or self, obj, value, old) - else: - # mark all the old elements as detached from the parent - old.list_replaced() - - def delete(self, event, obj): - """Delete a value from the given object. - - `event` is the ``InstrumentedAttribute`` that initiated the - ``delete()`` operation and is used to control the depth of a - circular delete operation. - """ - - if event is not self: - try: - if not self.uselist and (self.trackparent or len(self.extensions)): - old = self.get(obj) - del obj.__dict__[self.key] - except KeyError: - # TODO: raise this? not consistent with get() ? - raise AttributeError(self.key) - obj._state['modified'] = True - if not self.uselist: - if self.trackparent: - if old is not None: - self.sethasparent(old, False) - for ext in self.extensions: - ext.delete(event or self, obj, old) - - def append(self, event, obj, value): - """Append an element to a list based element or sets a scalar - based element to the given value. - - Used by ``GenericBackrefExtension`` to *append* an item - independent of list/scalar semantics. - - `event` is the ``InstrumentedAttribute`` that initiated the - ``append()`` operation and is used to control the depth of a - circular append operation. - """ - - if self.uselist: - if event is not self: - self.get(obj).append_with_event(value, event) - else: - self.set(event, obj, value) - - def remove(self, event, obj, value): - """Remove an element from a list based element or sets a - scalar based element to None. - - Used by ``GenericBackrefExtension`` to *remove* an item - independent of list/scalar semantics. + def set(self, obj, value, initiator): + raise NotImplementedError() - `event` is the ``InstrumentedAttribute`` that initiated the - ``remove()`` operation and is used to control the depth of a - circular remove operation. + def set_committed_value(self, obj, value): + """set an attribute value on the given instance and 'commit' it. + + this indicates that the given value is the "persisted" value, + and history will be logged only if a newly set value is not + equal to this value. + + this is typically used by deferred/lazy attribute loaders + to set object attributes after the initial load. """ - if self.uselist: - if event is not self: - self.get(obj).remove_with_event(value, event) - else: - self.set(event, obj, None) + state = obj._state + orig = state.get('original', None) + if orig is not None: + orig.commit_attribute(self, obj, value) + # remove per-instance callable, if any + state.pop(('callable', self), None) + obj.__dict__[self.key] = value + return value - def append_event(self, event, obj, value): - """Called by ``InstrumentedList`` when an item is appended.""" + def set_raw_value(self, obj, value): + obj.__dict__[self.key] = value + return value + def fire_append_event(self, obj, value, initiator): obj._state['modified'] = True if self.trackparent and value is not None: self.sethasparent(value, True) for ext in self.extensions: - ext.append(event or self, obj, value) - - def remove_event(self, event, obj, value): - """Called by ``InstrumentedList`` when an item is removed.""" + ext.append(obj, value, initiator or self) + def fire_remove_event(self, obj, value, initiator): obj._state['modified'] = True if self.trackparent and value is not None: self.sethasparent(value, False) for ext in self.extensions: - ext.delete(event or self, obj, value) + ext.remove(obj, value, initiator or self) -InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute) - - -class InstrumentedList(object): - """Instrument a list-based attribute. - - All mutator operations (i.e. append, remove, etc.) will fire off - events to the ``InstrumentedAttribute`` that manages the object's - attribute. Those events in turn trigger things like backref - operations and whatever is implemented by - ``do_list_value_changed`` on ``InstrumentedAttribute``. - - Note that this list does a lot less than earlier versions of SA - list-based attributes, which used ``HistoryArraySet``. This list - wrapper does **not** maintain setlike semantics, meaning you can add - as many duplicates as you want (which can break a lot of SQL), and - also does not do anything related to history tracking. - - Please see ticket #213 for information on the future of this - class, where it will be broken out into more collection-specific - subtypes. - """ + def fire_replace_event(self, obj, value, initiator, previous): + obj._state['modified'] = True + if self.trackparent: + if value is not None: + self.sethasparent(value, True) + if previous is not None: + self.sethasparent(previous, False) + for ext in self.extensions: + ext.set(obj, value, previous, initiator or self) - def __init__(self, attr, obj, data, init=True): - self.attr = attr - # this weakref is to prevent circular references between the parent object - # and the list attribute, which interferes with immediate garbage collection. - self.__obj = weakref.ref(obj) - self.key = attr.key - - # adapt to lists or sets - # TODO: make three subclasses of InstrumentedList that come off from a - # metaclass, based on the type of data sent in - if attr.typecallable is not None: - self.data = attr.typecallable() - else: - self.data = data or attr._blank_list() - - if isinstance(self.data, list): - self._data_appender = self.data.append - self._clear_data = self._clear_list - elif isinstance(self.data, util.Set): - self._data_appender = self.data.add - self._clear_data = self._clear_set - elif isinstance(self.data, dict): - if hasattr(self.data, 'append'): - self._data_appender = self.data.append - else: - raise exceptions.ArgumentError("Dictionary collection class '%s' must implement an append() method" % type(self.data).__name__) - self._clear_data = self._clear_dict - else: - if hasattr(self.data, 'append'): - self._data_appender = self.data.append - elif hasattr(self.data, 'add'): - self._data_appender = self.data.add - else: - raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no append() or add() method" % type(self.data).__name__) + property = property(lambda s: class_mapper(s.class_).props[s.key], + doc="the MapperProperty object associated with this attribute") - if hasattr(self.data, 'clear'): - self._clear_data = self._clear_set - else: - raise exceptions.ArgumentError("Collection class '%s' is not of type 'list', 'set', or 'dict' and has no clear() method" % type(self.data).__name__) - - if data is not None and data is not self.data: - for elem in data: - self._data_appender(elem) - - - if init: - for x in self.data: - self.__setrecord(x) - - def list_replaced(self): - """Fire off delete event handlers for each item in the list - but doesnt affect the original data list. - """ - - [self.__delrecord(x) for x in self.data] +InstrumentedAttribute.logger = logging.class_logger(InstrumentedAttribute) - def clear(self): - """Clear all items in this InstrumentedList and fires off - delete event handlers for each item. - """ + +class InstrumentedScalarAttribute(InstrumentedAttribute): + def __init__(self, class_, manager, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, mutable_scalars=False, **kwargs): + super(InstrumentedScalarAttribute, self).__init__(class_, manager, key, + callable_, trackparent=trackparent, extension=extension, + compare_function=compare_function, **kwargs) + self.mutable_scalars = mutable_scalars - self._clear_data() + if copy_function is None: + # scalar values are assumed to be immutable unless a copy function + # is passed + self.copy = lambda x:x + else: + self.copy = lambda x:copy_function(x) - def _clear_dict(self): - [self.__delrecord(x) for x in self.data.values()] - self.data.clear() + def __delete__(self, obj): + old = self.get(obj) + del obj.__dict__[self.key] + self.fire_remove_event(obj, old, self) - def _clear_set(self): - [self.__delrecord(x) for x in self.data] - self.data.clear() + def check_mutable_modified(self, obj): + if self.mutable_scalars: + h = self.get_history(obj, passive=True) + if h is not None and h.is_modified(): + obj._state['modified'] = True + return True + else: + return False + else: + return False - def _clear_list(self): - self[:] = [] + def set(self, obj, value, initiator): + """Set a value on the given object. - def __getstate__(self): - """Implemented to allow pickling, since `__obj` is a weakref, - also the ``InstrumentedAttribute`` has callables attached to - it. + `initiator` is the ``InstrumentedAttribute`` that initiated the + ``set()` operation and is used to control the depth of a circular + setter operation. """ - return {'key':self.key, 'obj':self.obj, 'data':self.data} - - def __setstate__(self, d): - """Implemented to allow pickling, since `__obj` is a weakref, - also the ``InstrumentedAttribute`` has callables attached to it. - """ + if initiator is self: + return - self.key = d['key'] - self.__obj = weakref.ref(d['obj']) - self.data = d['data'] - self.attr = getattr(d['obj'].__class__, self.key) + state = obj._state + # if an instance-wide "trigger" was set, call that + if 'trigger' in state: + trig = state['trigger'] + del state['trigger'] + trig() - obj = property(lambda s:s.__obj()) + old = self.get(obj) + obj.__dict__[self.key] = value + self.fire_replace_event(obj, value, initiator, old) - def unchanged_items(self): - """Deprecated.""" +class InstrumentedCollectionAttribute(InstrumentedAttribute): + """A collection-holding attribute that instruments changes in membership. - return self.attr.get_history(self.obj).unchanged_items + InstrumentedCollectionAttribute holds an arbitrary, user-specified + container object (defaulting to a list) and brokers access to the + CollectionAdapter, a "view" onto that object that presents consistent + bag semantics to the orm layer independent of the user data implementation. + """ + + def __init__(self, class_, manager, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): + super(InstrumentedCollectionAttribute, self).__init__(class_, manager, + key, callable_, trackparent=trackparent, extension=extension, + compare_function=compare_function, **kwargs) - def added_items(self): - """Deprecated.""" + if copy_function is None: + self.copy = lambda x:[y for y in + list(collections.collection_adapter(x))] + else: + self.copy = lambda x:copy_function(x) - return self.attr.get_history(self.obj).added_items + if typecallable is None: + typecallable = list + self.collection_factory = \ + collections._prepare_instrumentation(typecallable) + self.collection_interface = \ + util.duck_type_collection(self.collection_factory()) - def deleted_items(self): - """Deprecated.""" + def __set__(self, obj, value): + """Replace the current collection with a new one.""" - return self.attr.get_history(self.obj).deleted_items + setting_type = util.duck_type_collection(value) - def __iter__(self): - return iter(self.data) + if value is None or setting_type != self.collection_interface: + raise exceptions.ArgumentError( + "Incompatible collection type on assignment: %s is not %s-like" % + (type(value).__name__, self.collection_interface.__name__)) - def __repr__(self): - return repr(self.data) + if hasattr(value, '_sa_adapter'): + self.set(obj, list(getattr(value, '_sa_adapter')), None) + elif setting_type == dict: + self.set(obj, value.values(), None) + else: + self.set(obj, value, None) - def __getattr__(self, attr): - """Proxy unknown methods and attributes to the underlying - data array. This allows custom list classes to be used. - """ + def __delete__(self, obj): + if self.key not in obj.__dict__: + return - return getattr(self.data, attr) + obj._state['modified'] = True - def __setrecord(self, item, event=None): - self.attr.append_event(event, self.obj, item) - return True + collection = self._get_collection(obj) + collection.clear_with_event() + del obj.__dict__[self.key] - def __delrecord(self, item, event=None): - self.attr.remove_event(event, self.obj, item) - return True + def initialize(self, obj): + """Initialize this attribute on the given object instance with an empty collection.""" - def append_with_event(self, item, event): - self.__setrecord(item, event) - self._data_appender(item) + _, user_data = self._build_collection(obj) + obj.__dict__[self.key] = user_data + return user_data - def append_without_event(self, item): - self._data_appender(item) + def append(self, obj, value, initiator): + if initiator is self: + return + collection = self._get_collection(obj) + collection.append_with_event(value, initiator) - def remove_with_event(self, item, event): - self.__delrecord(item, event) - self.data.remove(item) + def remove(self, obj, value, initiator): + if initiator is self: + return + collection = self._get_collection(obj) + collection.remove_with_event(value, initiator) - def append(self, item, _mapper_nohistory=False): - """Fire off dependent events, and appends the given item to the underlying list. + def set(self, obj, value, initiator): + """Set a value on the given object. - `_mapper_nohistory` is a backwards compatibility hack; call - ``append_without_event`` instead. + `initiator` is the ``InstrumentedAttribute`` that initiated the + ``set()` operation and is used to control the depth of a circular + setter operation. """ - if _mapper_nohistory: - self.append_without_event(item) - else: - self.__setrecord(item) - self._data_appender(item) - - def __getitem__(self, i): - return self.data[i] - - def __setitem__(self, i, item): - if isinstance(i, slice): - self.__setslice__(i.start, i.stop, item) - else: - self.__setrecord(item) - self.data[i] = item - - def __delitem__(self, i): - if isinstance(i, slice): - self.__delslice__(i.start, i.stop) - else: - self.__delrecord(self.data[i], None) - del self.data[i] + if initiator is self: + return - def __lt__(self, other): return self.data < self.__cast(other) + state = obj._state + # if an instance-wide "trigger" was set, call that + if 'trigger' in state: + trig = state['trigger'] + del state['trigger'] + trig() - def __le__(self, other): return self.data <= self.__cast(other) + old = self.get(obj) + old_collection = self._get_collection(obj, old) - def __eq__(self, other): return self.data == self.__cast(other) + new_collection, user_data = self._build_collection(obj) + self._load_collection(obj, value or [], emit_events=True, + collection=new_collection) - def __ne__(self, other): return self.data != self.__cast(other) + obj.__dict__[self.key] = user_data + state['modified'] = True - def __gt__(self, other): return self.data > self.__cast(other) + # mark all the old elements as detached from the parent + if old_collection: + old_collection.clear_with_event() + old_collection.unlink(old) - def __ge__(self, other): return self.data >= self.__cast(other) + def set_committed_value(self, obj, value): + """Set an attribute value on the given instance and 'commit' it.""" + + state = obj._state + orig = state.get('original', None) - def __cast(self, other): - if isinstance(other, InstrumentedList): return other.data - else: return other + collection, user_data = self._build_collection(obj) + self._load_collection(obj, value or [], emit_events=False, + collection=collection) + value = user_data - def __cmp__(self, other): - return cmp(self.data, self.__cast(other)) + if orig is not None: + orig.commit_attribute(self, obj, value) + # remove per-instance callable, if any + state.pop(('callable', self), None) + obj.__dict__[self.key] = value + return value - def __contains__(self, item): return item in self.data + def _build_collection(self, obj): + user_data = self.collection_factory() + collection = collections.CollectionAdaptor(self, obj, user_data) + return collection, user_data - def __len__(self): + def _load_collection(self, obj, values, emit_events=True, collection=None): + collection = collection or self._get_collection(obj) + if values is None: + return + elif emit_events: + for item in values: + collection.append_with_event(item) + else: + for item in values: + collection.append_without_event(item) + + def _get_collection(self, obj, user_data=None): + if user_data is None: + user_data = self.get(obj) try: - return len(self.data) - except TypeError: - return len(list(self.data)) - - def __setslice__(self, i, j, other): - [self.__delrecord(x) for x in self.data[i:j]] - g = [a for a in list(other) if self.__setrecord(a)] - self.data[i:j] = g - - def __delslice__(self, i, j): - for a in self.data[i:j]: - self.__delrecord(a) - del self.data[i:j] - - def insert(self, i, item): - if self.__setrecord(item): - self.data.insert(i, item) - - def pop(self, i=-1): - item = self.data[i] - self.__delrecord(item) - return self.data.pop(i) - - def remove(self, item): - self.__delrecord(item) - self.data.remove(item) - - def discard(self, item): - if item in self.data: - self.__delrecord(item) - self.data.remove(item) - - def extend(self, item_list): - for item in item_list: - self.append(item) - - def __add__(self, other): - raise NotImplementedError() - - def __radd__(self, other): - raise NotImplementedError() - - def __iadd__(self, other): - raise NotImplementedError() - -class AttributeExtension(object): - """An abstract class which specifies `append`, `delete`, and `set` - event handlers to be attached to an object property. - """ + return getattr(user_data, '_sa_adapter') + except AttributeError: + collections.CollectionAdaptor(self, obj, user_data) + return getattr(user_data, '_sa_adapter') - def append(self, event, obj, child): - pass - def delete(self, event, obj, child): - pass - - def set(self, event, obj, child, oldchild): - pass - -class GenericBackrefExtension(AttributeExtension): +class GenericBackrefExtension(interfaces.AttributeExtension): """An extension which synchronizes a two-way relationship. A typical two-way relationship is a parent object containing a @@ -646,19 +455,19 @@ class GenericBackrefExtension(AttributeExtension): def __init__(self, key): self.key = key - def set(self, event, obj, child, oldchild): + def set(self, obj, child, oldchild, initiator): if oldchild is child: return if oldchild is not None: - getattr(oldchild.__class__, self.key).remove(event, oldchild, obj) + getattr(oldchild.__class__, self.key).remove(oldchild, obj, initiator) if child is not None: - getattr(child.__class__, self.key).append(event, child, obj) + getattr(child.__class__, self.key).append(child, obj, initiator) - def append(self, event, obj, child): - getattr(child.__class__, self.key).append(event, child, obj) + def append(self, obj, child, initiator): + getattr(child.__class__, self.key).append(child, obj, initiator) - def delete(self, event, obj, child): - getattr(child.__class__, self.key).remove(event, child, obj) + def remove(self, obj, child, initiator): + getattr(child.__class__, self.key).remove(child, obj, initiator) class CommittedState(object): """Store the original state of an object when the ``commit()` @@ -697,10 +506,13 @@ class CommittedState(object): def rollback(self, manager, obj): for attr in manager.managed_attributes(obj.__class__): if self.data.has_key(attr.key): - if attr.uselist: - obj.__dict__[attr.key][:] = self.data[attr.key] - else: + if not isinstance(attr, InstrumentedCollectionAttribute): obj.__dict__[attr.key] = self.data[attr.key] + else: + collection = attr._get_collection(obj) + collection.clear_without_event() + for item in self.data[attr.key]: + collection.append_without_event(item) else: del obj.__dict__[attr.key] @@ -725,17 +537,15 @@ class AttributeHistory(object): else: original = None - if attr.uselist: + if isinstance(attr, InstrumentedCollectionAttribute): self._current = current - else: - self._current = [current] - if attr.uselist: s = util.Set(original or []) self._added_items = [] self._unchanged_items = [] self._deleted_items = [] if current: - for a in current: + collection = attr._get_collection(obj, current) + for a in collection: if a in s: self._unchanged_items.append(a) else: @@ -744,6 +554,7 @@ class AttributeHistory(object): if a not in self._unchanged_items: self._deleted_items.append(a) else: + self._current = [current] if attr.is_equal(current, original): self._unchanged_items = [current] self._added_items = [] @@ -755,7 +566,6 @@ class AttributeHistory(object): else: self._deleted_items = [] self._unchanged_items = [] - #print "key", attr.key, "orig", original, "current", current, "added", self._added_items, "unchanged", self._unchanged_items, "deleted", self._deleted_items def __iter__(self): return iter(self._current) @@ -773,8 +583,7 @@ class AttributeHistory(object): return self._deleted_items def hasparent(self, obj): - """Deprecated. This should be called directly from the - appropriate ``InstrumentedAttribute`` object. + """Deprecated. This should be called directly from the appropriate ``InstrumentedAttribute`` object. """ return self.attr.hasparent(obj) @@ -834,7 +643,7 @@ class AttributeManager(object): o._state['modified'] = False def managed_attributes(self, class_): - """Return an iterator of all ``InstrumentedAttribute`` objects + """Return a list of all ``InstrumentedAttribute`` objects associated with the given class. """ @@ -885,7 +694,7 @@ class AttributeManager(object): """Return an attribute of the given name from the given object. If the attribute is a scalar, return it as a single-item list, - otherwise return the list based attribute. + otherwise return a collection based attribute. If the attribute's value is to be produced by an unexecuted callable, the callable will only be executed if the given @@ -894,10 +703,10 @@ class AttributeManager(object): attr = getattr(obj.__class__, key) x = attr.get(obj, passive=passive) - if x is InstrumentedAttribute.PASSIVE_NORESULT: + if x is PASSIVE_NORESULT: return [] - elif attr.uselist: - return x + elif isinstance(attr, InstrumentedCollectionAttribute): + return list(attr._get_collection(obj, x)) else: return [x] @@ -953,10 +762,9 @@ class AttributeManager(object): """Return True if the given `key` correponds to an instrumented property on the given class. """ - return hasattr(class_, key) and isinstance(getattr(class_, key), InstrumentedAttribute) - def init_instance_attribute(self, obj, key, uselist, callable_=None, **kwargs): + def init_instance_attribute(self, obj, key, callable_=None): """Initialize an attribute on an instance to either a blank value, cancelling out any class- or instance-level callables that were present, or if a `callable` is supplied set the @@ -971,7 +779,24 @@ class AttributeManager(object): events back to this ``AttributeManager``. """ - return InstrumentedAttribute(self, key, uselist, callable_, typecallable, **kwargs) + if uselist: + return InstrumentedCollectionAttribute(class_, self, key, + callable_, + typecallable, + **kwargs) + else: + return InstrumentedScalarAttribute(class_, self, key, callable_, + **kwargs) + + def get_attribute(self, obj_or_cls, key): + """Register an attribute at the class level to be instrumented + for all instances of the class. + """ + + if isinstance(obj_or_cls, type): + return getattr(obj_or_cls, key) + else: + return getattr(obj_or_cls.__class__, key) def register_attribute(self, class_, key, uselist, callable_=None, **kwargs): """Register an attribute at the class level to be instrumented @@ -980,10 +805,9 @@ class AttributeManager(object): # firt invalidate the cache for the given class # (will be reconstituted as needed, while getting managed attributes) - self._inherited_attribute_cache.pop(class_,None) - self._noninherited_attribute_cache.pop(class_,None) + self._inherited_attribute_cache.pop(class_, None) + self._noninherited_attribute_cache.pop(class_, None) - #print self, "register attribute", key, "for class", class_ if not hasattr(class_, '_state'): def _get_state(self): if not hasattr(self, '_sa_attr_state'): @@ -994,4 +818,12 @@ class AttributeManager(object): typecallable = kwargs.pop('typecallable', None) if isinstance(typecallable, InstrumentedAttribute): typecallable = None - setattr(class_, key, self.create_prop(class_, key, uselist, callable_, typecallable=typecallable, **kwargs)) + setattr(class_, key, self.create_prop(class_, key, uselist, callable_, + typecallable=typecallable, **kwargs)) + + def init_collection(self, instance, key): + """Initialize a collection attribute and return the collection adapter.""" + + attr = self.get_attribute(instance, key) + user_data = attr.initialize(instance) + return attr._get_collection(instance, user_data) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py new file mode 100644 index 0000000000..68a16fa0f2 --- /dev/null +++ b/lib/sqlalchemy/orm/collections.py @@ -0,0 +1,1146 @@ +"""Support for attributes that hold collections of objects. + +Mapping a one-to-many or many-to-many relationship results in a collection +of values accessible through an attribute on the class, by default presented +as a list: + + mapper(Parent, properties={ + children = relation(Child) + }) + + parent.children.append(Child()) + print parent.children[0] + +These attributes are not limited to lists- any Python object that can implement +a bag-like interface can be used in the place of a basic list. Custom +collection classes are specified with the collection_class option to relation(), +and instances (at a minimum) must only be able to append, remove and iterate +over objects in the collection. + + mapper(Parent, properties={ + children = relation(Child, collection_class=set) + }) + + child = Child() + parent.children.add(child) + assert child in parent.children + +The collection is watched by the orm, which notes all objects entering and +leaving the collection much in the same way that it watches regular scalar +attributes for changes, setting up backrefs, parents, etc. + +Both 'list' and 'set' can be used directly as a collection_class. Dictionaries +can be used for mapping semantics too, but a little more work is needed to +support the required "value-only" interface the orm needs to add instances +to the collection. The 'column_mapped_collection' is a dict subclass that +uses a column from the member object as the key: + + mapper(Item, properties={ + notes = relation(Note, + collection_class=column_mapped_collection(kw_table.c.keyword)) + }) + + item.notes['color'] = Note('color', 'blue') + print item.notes['color'] + +You can create your own collection classes too. In the simple case, +simply inherit from 'list' or 'set' and add the custom behavior. All of the +basic collection operations are instrumented for you via transparent function +decoration, so a call to, say, 'MyList.pop()' will notify the orm that the +returned object should be deleted. + +Automatic instrumentation isn't restricted to subclasses of built-in types. +The collection package understands the abstract base types of the three +primary collection types and can apply the appropriate instrumentation based +on the duck-typing of your class: + + class ListLike(object): + def __init__(self): + self.data = [] + def append(self, item): + self.data.append(item) + def remove(self, item): + self.data.remove(item) + def extend(self, items): + self.data.extend(items) + def __iter__(self): + return iter(self.data) + def foo(self): + return 'foo' + +'append', 'remove', and 'extend' are known list-like methods, and will be +instrumented automatically. '__iter__' is not a mutator method and won't +be touched, and 'foo' won't be either. + +Duck-typing (aka guesswork) of object-derived classes isn't rock-solid, of +course, so you can be explicit about the interface you are implementing with +the '__emulates__' class attribute: + + class DictLike(object): + __emulates__ = dict + + def __init__(self): + self.data = {} + def append(self, item): + self.data[item.keyword] = item + def remove(self, item): + del self[item.keyword] + def __setitem__(self, key, value): + self.data[key] = value + def __delitem__(self, key): + del self.data[key] + def values(self): + return self.data.itervalues() + +The class looks list-like because of 'append', but __emulates__ forces it to +dict-like. '__setitem__' and '__delitem__' are known to be dict-like and are +instrumented. This class won't quite work as-is yet- a little glue is needed +to adapt it for use by SQLAlchemy- the basic interface of 'append', 'remove' +and 'iterate' needs to be mapped onto the class. A set of decorators is +provided for this. + + from collections import collection + + class DictLike(object): + __emulates__ = dict + + def __init__(self): + self.data = {} + + @collection.appender + def append(self, item): + self.data[item.keyword] = item + + @collection.remover + def remove(self, item): + del self[item.keyword] + + def __setitem__(self, key, value): + self.data[key] = value + def __delitem__(self, key): + del self.data[key] + + @collection.iterator + def values(self): + return self.data.itervalues() + +And that's all that's needed. The SQLAlchemy orm will interact with your +dict-like class through the methods you've tagged. Both 'list' and 'set' +have SQLAlchemy-compatible methods in their base interface and don't need +to be annotated if you have the basic methods in your implementation. If +you don't, or you want to direct through a different method, you can +decorate: + + from collections import collection + + class MyList(list): + @collection.appender + def hey_use_this_instead_for_append(self, item): + # do something special ... + +There is no requirement to be list-, set- or dict-like at all. Collection +classes can be any shape, so long as they have the append, remove and iterate +interface marked for SQLAlchemy's use. + +You can add instrumentation to methods outside of the basic collection +interface as well. Decorators are supplied that can wrap your methods +and fire off SQLAlchemy events based on the arguments passed to the method +and/or the method's return value. + + class MyCollection(object): + ... # need append, remove, and iterate, exercise to the reader + + @collection.adds(2) + def insert(self, where, item): + ... + + @collection.removes_return() + def prune(self, where): + ... + +Tight control over events is also possible by implementing the instrumentation +internally in your methods. The basic instrumentation package works under the +general assumption that collection mutation events will not raise exceptions. +If you want tight control over add and remove events with exception management, +internal instrumentation may be the answer. Within your method, +'collection_adapter(self)' will retrieve an object that you can use for +explicit control over triggering append and remove events. + +There are some caveats: + +A collection class will be modified behind the scenes- decorators will be +applied around methods. Built-ins can't (and shouldn't) be modified, so +a request for, say, a 'list' will actually net an 'InstrumentedList' instance +on the property- a trivial subclass that holds and isolates decorations rather +than interfere with all 'list' instances in the process. + +The decorations are light-weight and no-op outside of their intended context, +but they are unavoidable and will always be applied. When using a library +class as a collection, it can be good practice to use the "trivial subclass" +trick to restrict the decorations to just your usage in mapping. For example: + + class MyAwesomeList(some.great.library.AwesomeList): + pass + + # ... relation(..., collection_class=MyAwesomeList) + +In custom classes, keep in mind that you can fire duplicate events if you +delegate one instrumented method to another. When subclassing a built-in +type, the instrumentation is implicit on mutator methods so you'll need to +be mindful. +""" + +from sqlalchemy import exceptions, schema, util +from sqlalchemy.orm import mapper +import copy, sys, warnings, weakref +import new + +try: + from threading import Lock +except: + from dummy_threading import Lock + + +__all__ = ['collection', 'mapped_collection', 'column_mapped_collection', + 'collection_adapter'] + +def column_mapped_collection(mapping_spec): + """A dictionary-based collection type with column-based keying. + + Returns a MappedCollection factory with a keying function generated + from mapping_spec, which may be a Column or a sequence of Columns. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, say from None to a database-assigned integer + after a session flush. + """ + + if isinstance(mapping_spec, schema.Column): + mapping_spec = mapping_spec, + else: + cols = [] + for c in mapping_spec: + if not isinstance(c, schema.Column): + raise exceptions.ArgumentError( + "mapping_spec tuple may only contain columns") + cols.append(c) + mapping_spec = tuple(cols) + + def keyfunc(value): + m = mapper.object_mapper(value) + return tuple([m.get_attr_by_column(value, c) for c in mapping_spec]) + return lambda: MappedCollection(keyfunc) + + +def mapped_collection(keyfunc): + """A dictionary-based collection type with arbitrary keying. + + Returns a MappedCollection factory with a keying function generated + from keyfunc, a callable that takes an object and returns a key value. + + The key value must be immutable for the lifetime of the object. You + can not, for example, map on foreign key values if those key values will + change during the session, say from None to a database-assigned integer + after a session flush. + """ + + return lambda: MappedCollection(keyfunc) + +class collection(object): + """Decorators for custom collection classes. + + The decorators fall into two groups: annotations and interception recipes. + + The annotating decorators (appender, remover, iterator, + internally_instrumented) indicate the method's purpose and take no + arguments. They are not written with parens: + + @collection.appender + def append(self, append): ... + + The recipe decorators all require parens, even those that take no + arguments: + + @collection.adds('entity'): + def insert(self, position, entity): ... + + @collection.removes_return() + def popitem(self): ... + + Decorators can be specified in long-hand for Python 2.3, or with + the class-level dict attribute '__instrumentation__'- see the source + for details. + """ + + # Bundled as a class solely for ease of use: packaging, doc strings, + # importability. + + def appender(cls, fn): + """Tag the method as the collection appender. + + The appender method is called with one positional argument: the value + to append. The method will be automatically decorated with 'adds(1)' + if not already decorated. + + @collection.appender + def add(self, append): ... + + # or, equivalently + @collection.appender + @collection.adds(1) + def add(self, append): ... + + # for mapping type, an 'append' may kick out a previous value + # that occupies that slot. consider d['a'] = 'foo'- any previous + # value in d['a'] is discarded. + @collection.appender + @collection.replaces(1) + def add(self, entity): + key = some_key_func(entity) + previous = None + if key in self: + previous = self[key] + self[key] = entity + return previous + + If the value to append is not allowed in the collection, you may + raise an exception. Something to remember is that the appender + will be called for each object mapped by a database query. If the + database contains rows that violate your collection semantics, you + will need to get creative to fix the problem, as access via the + collection will not work. + + If the appender method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + """ + + setattr(fn, '_sa_instrument_role', 'appender') + return fn + appender = classmethod(appender) + + def remover(cls, fn): + """Tag the method as the collection remover. + + The remover method is called with one positional argument: the value + to remove. The method will be automatically decorated with + 'removes_return()' if not already decorated. + + @collection.remover + def zap(self, entity): ... + + # or, equivalently + @collection.remover + @collection.removes_return() + def zap(self, ): ... + + If the value to remove is not present in the collection, you may + raise an exception or return None to ignore the error. + + If the remove method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + """ + + setattr(fn, '_sa_instrument_role', 'remover') + return fn + remover = classmethod(remover) + + def iterator(cls, fn): + """Tag the method as the collection remover. + + The iterator method is called with no arguments. It is expected to + return an iterator over all collection members. + + @collection.iterator + def __iter__(self): ... + """ + + setattr(fn, '_sa_instrument_role', 'iterator') + return fn + iterator = classmethod(iterator) + + def internally_instrumented(cls, fn): + """Tag the method as instrumented. + + This tag will prevent any decoration from being applied to the method. + Use this if you are orchestrating your own calls to collection_adapter + in one of the basic SQLAlchemy interface methods, or to prevent + an automatic ABC method decoration from wrapping your implementation. + + # normally an 'extend' method on a list-like class would be + # automatically intercepted and re-implemented in terms of + # SQLAlchemy events and append(). your implementation will + # never be called, unless: + @collection.internally_instrumented + def extend(self, items): ... + """ + + setattr(fn, '_sa_instrumented', True) + return fn + internally_instrumented = classmethod(internally_instrumented) + + def on_link(cls, fn): + """Tag the method as a the "linked to attribute" event handler. + + This optional event handler will be called when the collection class + is linked to or unlinked from the InstrumentedAttribute. It is + invoked immediately after the '_sa_adapter' property is set on + the instance. A single argument is passed: the collection adapter + that has been linked, or None if unlinking. + """ + + setattr(fn, '_sa_instrument_role', 'on_link') + return fn + on_link = classmethod(on_link) + + def adds(cls, arg): + """Mark the method as adding an entity to the collection. + + Adds "add to collection" handling to the method. The decorator argument + indicates which method argument holds the SQLAlchemy-relevant value. + Arguments can be specified positionally (i.e. integer) or by name. + + @collection.adds(1) + def push(self, item): ... + + @collection.adds('entity') + def do_stuff(self, thing, entity=None): ... + """ + + def decorator(fn): + setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) + return fn + return decorator + adds = classmethod(adds) + + def replaces(cls, arg): + """Mark the method as replacing an entity in the collection. + + Adds "add to collection" and "remove from collection" handling to + the method. The decorator argument indicates which method argument + holds the SQLAlchemy-relevant value to be added, and return value, if + any will be considered the value to remove. + + Arguments can be specified positionally (i.e. integer) or by name. + + @collection.replaces(2) + def __setitem__(self, index, item): ... + """ + + def decorator(fn): + setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) + setattr(fn, '_sa_instrument_after', 'fire_remove_event') + return fn + return decorator + replaces = classmethod(replaces) + + def removes(cls, arg): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The decorator + argument indicates which method argument holds the SQLAlchemy-relevant + value to be removed. Arguments can be specified positionally (i.e. + integer) or by name. + + @collection.removes(1) + def zap(self, item): ... + + For methods where the value to remove is not known at call-time, use + collection.removes_return. + """ + + def decorator(fn): + setattr(fn, '_sa_instrument_before', ('fire_remove_event', arg)) + return fn + return decorator + removes = classmethod(removes) + + def removes_return(cls): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The return value + of the method, if any, is considered the value to remove. The method + arguments are not inspected. + + @collection.removes_return() + def pop(self): ... + + For methods where the value to remove is known at call-time, use + collection.remove. + """ + + def decorator(fn): + setattr(fn, '_sa_instrument_after', 'fire_remove_event') + return fn + return decorator + removes_return = classmethod(removes_return) + + +# public instrumentation interface for 'internally instrumented' +# implementations +def collection_adapter(collection): + return getattr(collection, '_sa_adapter', None) + +class CollectionAdaptor(object): + """Bridges between the orm and arbitrary Python collections. + + Proxies base-level collection operations (append, remove, iterate) + to the underlying Python collection, and emits add/remove events for + entities entering or leaving the collection. + """ + + def __init__(self, attr, owner, data): + self.attr = attr + self._owner = weakref.ref(owner) + self._data = weakref.ref(data) + self.link_to_self(data) + + owner = property(lambda s: s._owner()) + data = property(lambda s: s._data()) + + def link_to_self(self, data): + setattr(data, '_sa_adapter', self) + if hasattr(data, '_sa_on_link'): + getattr(data, '_sa_on_link')(self) + + def unlink(self, data): + setattr(data, '_sa_adapter', None) + if hasattr(data, '_sa_on_link'): + getattr(data, '_sa_on_link')(None) + + def append_with_event(self, item, initiator=None): + getattr(self.data, '_sa_appender')(item, _sa_initiator=initiator) + + def append_without_event(self, item): + getattr(self.data, '_sa_appender')(item, _sa_initiator=False) + + def remove_with_event(self, item, initiator=None): + getattr(self.data, '_sa_remover')(item, _sa_initiator=initiator) + + def remove_without_event(self, item): + getattr(self.data, '_sa_remover')(item, _sa_initiator=False) + + def clear_with_event(self, initiator=None): + for item in list(self): + self.remove_with_event(item, initiator) + + def clear_without_event(self): + for item in list(self): + self.remove_without_event(item) + + def __iter__(self): + return getattr(self.data, '_sa_iterator')() + + def __len__(self): + return len(list(getattr(self.data, '_sa_iterator')())) + + def __nonzero__(self): + return True + + def fire_append_event(self, item, event=None): + if event is not False: + self.attr.fire_append_event(self.owner, item, event) + + def fire_remove_event(self, item, event=None): + if event is not False: + self.attr.fire_remove_event(self.owner, item, event) + + def __getstate__(self): + return { 'key':self.attr.key, + 'owner': self.owner, + 'data': self.data } + + def __setstate__(self, d): + self.attr = getattr(d['owner'].__class__, d['key']) + self._owner = weakref.ref(d['owner']) + self._data = weakref.ref(d['data']) + + +__instrumentation_mutex = Lock() +def _prepare_instrumentation(factory): + """Prepare a callable for future use as a collection class factory. + + Given a collection class factory (either a type or no-arg callable), + return another factory that will produce compatible instances when + called. + + This function is responsible for converting collection_class=list + into the run-time behavior of collection_class=InstrumentedList. + """ + + # Convert a builtin to 'Instrumented*' + if factory in __canned_instrumentation: + factory = __canned_instrumentation[factory] + + # Create a specimen + cls = type(factory()) + + # Did factory callable return a builtin? + if cls in __canned_instrumentation: + # Wrap it so that it returns our 'Instrumented*' + factory = __converting_factory(factory) + cls = factory() + + # Instrument the class if needed. + if __instrumentation_mutex.acquire(): + try: + if not hasattr(cls, '_sa_appender'): + _instrument_class(cls) + finally: + __instrumentation_mutex.release() + + return factory + +def __converting_factory(original_factory): + """Convert the type returned by collection factories on the fly. + + Given a collection factory that returns a builtin type (e.g. a list), + return a wrapped function that converts that type to one of our + instrumented types. + """ + + def wrapper(): + collection = original_factory() + type_ = type(collection) + if type_ in __canned_instrumentation: + # return an instrumented type initialized from the factory's + # collection + return __canned_instrumentation[type_](collection) + else: + raise exceptions.InvalidRequestError( + "Collection class factories must produce instances of a " + "single class.") + try: + # often flawed but better than nothing + wrapper.__name__ = "%sWrapper" % original_factory.__name__ + wrapper.__doc__ = original_factory.__doc__ + except: + pass + return wrapper + +def _instrument_class(cls): + # FIXME: more formally document this as a decoratorless/Python 2.3 + # option for specifying instrumentation. (likely doc'd here in code only, + # not in online docs.) + # + # __instrumentation__ = { + # 'rolename': 'methodname', # ... + # 'methods': { + # 'methodname': ('fire_{append,remove}_event', argspec, + # 'fire_{append,remove}_event'), + # 'append': ('fire_append_event', 1, None), + # '__setitem__': ('fire_append_event', 1, 'fire_remove_event'), + # 'pop': (None, None, 'fire_remove_event'), + # } + # } + + # In the normal call flow, a request for any of the 3 basic collection + # types is transformed into one of our trivial subclasses + # (e.g. InstrumentedList). Catch anything else that sneaks in here... + if cls.__module__ == '__builtin__': + raise exceptions.ArgumentError( + "Can not instrument a built-in type. Use a " + "subclass, even a trivial one.") + + collection_type = util.duck_type_collection(cls) + if collection_type in __interfaces: + roles = __interfaces[collection_type].copy() + decorators = roles.pop('_decorators', {}) + else: + roles, decorators = {}, {} + + if hasattr(cls, '__instrumentation__'): + roles.update(copy.deepcopy(getattr(cls, '__instrumentation__'))) + + methods = roles.pop('methods', {}) + + for name in dir(cls): + method = getattr(cls, name) + if not callable(method): + continue + + # note role declarations + if hasattr(method, '_sa_instrument_role'): + role = method._sa_instrument_role + assert role in ('appender', 'remover', 'iterator', 'on_link') + roles[role] = name + + # transfer instrumentation requests from decorated function + # to the combined queue + before, after = None, None + if hasattr(method, '_sa_instrument_before'): + op, argument = method._sa_instrument_before + assert op in ('fire_append_event', 'fire_remove_event') + before = op, argument + if hasattr(method, '_sa_instrument_after'): + op = method._sa_instrument_after + assert op in ('fire_append_event', 'fire_remove_event') + after = op + if before or after: + methods[name] = before[0], before[1], after + + # apply ABC auto-decoration to methods that need it + for method, decorator in decorators.items(): + fn = getattr(cls, method, None) + if fn and method not in methods and not hasattr(fn, '_sa_instrumented'): + setattr(cls, method, decorator(fn)) + + # ensure all roles are present, and apply implicit instrumentation if + # needed + if 'appender' not in roles or not hasattr(cls, roles['appender']): + raise exceptions.ArgumentError( + "Type %s must elect an appender method to be " + "a collection class" % cls.__name__) + elif (roles['appender'] not in methods and + not hasattr(getattr(cls, roles['appender']), '_sa_instrumented')): + methods[roles['appender']] = ('fire_append_event', 1, None) + + if 'remover' not in roles or not hasattr(cls, roles['remover']): + raise exceptions.ArgumentError( + "Type %s must elect a remover method to be " + "a collection class" % cls.__name__) + elif (roles['remover'] not in methods and + not hasattr(getattr(cls, roles['remover']), '_sa_instrumented')): + methods[roles['remover']] = ('fire_remove_event', 1, None) + + if 'iterator' not in roles or not hasattr(cls, roles['iterator']): + raise exceptions.ArgumentError( + "Type %s must elect an iterator method to be " + "a collection class" % cls.__name__) + + # apply ad-hoc instrumentation from decorators, class-level defaults + # and implicit role declarations + for method, (before, argument, after) in methods.items(): + setattr(cls, method, + _instrument_membership_mutator(getattr(cls, method), + before, argument, after)) + # intern the role map + for role, method in roles.items(): + setattr(cls, '_sa_%s' % role, getattr(cls, method)) + + +def _instrument_membership_mutator(method, before, argument, after): + """Route method args and/or return value through the collection adapter.""" + + if type(argument) is int: + def wrapper(*args, **kw): + if before and len(args) < argument: + raise exceptions.ArgumentError( + 'Missing argument %i' % argument) + initiator = kw.pop('_sa_initiator', None) + if initiator is False: + executor = None + else: + executor = getattr(args[0], '_sa_adapter', None) + + if before and executor: + getattr(executor, before)(args[argument], initiator) + + if not after or not executor: + return method(*args, **kw) + else: + res = method(*args, **kw) + if res is not None: + getattr(executor, after)(res, initiator) + return res + else: + def wrapper(*args, **kw): + if before: + vals = inspect.getargvalues(inspect.currentframe()) + if argument in kw: + value = kw[argument] + else: + positional = inspect.getargspec(method)[0] + pos = positional.index(argument) + if pos == -1: + raise exceptions.ArgumentError('Missing argument %s' % + argument) + else: + value = args[pos] + + initiator = kw.pop('_sa_initiator', None) + if initiator is False: + executor = None + else: + executor = getattr(args[0], '_sa_adapter', None) + + if before and executor: + getattr(executor, op)(value, initiator) + + if not after or not executor: + return method(*args, **kw) + else: + res = method(*args, **kw) + if res is not None: + getattr(executor, after)(res, initiator) + return res + try: + wrapper._sa_instrumented = True + wrapper.__name__ = method.__name__ + wrapper.__doc__ = method.__doc__ + except: + pass + return wrapper + +def __set(collection, item, _sa_initiator=None): + """Run set events, may eventually be inlined into decorators.""" + + if _sa_initiator is not False: + executor = getattr(collection, '_sa_adapter', None) + if executor: + getattr(executor, 'fire_append_event')(item, _sa_initiator) + +def __del(collection, item, _sa_initiator=None): + """Run del events, may eventually be inlined into decorators.""" + + if _sa_initiator is not False: + executor = getattr(collection, '_sa_adapter', None) + if executor: + getattr(executor, 'fire_remove_event')(item, _sa_initiator) + +def _list_decorators(): + """Hand-turned instrumentation wrappers that can decorate any list-like + class.""" + + def _tidy(fn): + try: + setattr(fn, '_sa_instrumented', True) + fn.__doc__ = getattr(getattr(list, fn.__name__), '__doc__') + except: + raise + + def append(fn): + def append(self, item, _sa_initiator=None): + __set(self, item, _sa_initiator) + fn(self, item) + _tidy(append) + return append + + def remove(fn): + def remove(self, value, _sa_initiator=None): + fn(self, value) + __del(self, value, _sa_initiator) + _tidy(remove) + return remove + + def __setitem__(fn): + def __setitem__(self, index, value): + if not isinstance(index, slice): + __set(self, value) + fn(self, index, value) + else: + rng = range(slice.start or 0, slice.stop or 0, slice.step or 1) + if len(value) != len(rng): + raise ValueError + for i in rng: + __set(self, value[i]) + fn(self, i, value[i]) + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, index): + item = self[index] + __del(self, item) + fn(self, index) + _tidy(__delitem__) + return __delitem__ + + def __setslice__(fn): + def __setslice__(self, start, end, values): + for value in self[start:end]: + __del(self, value) + for value in values: + __set(self, value) + fn(self, start, end, values) + _tidy(__setslice__) + return __setslice__ + + def __delslice__(fn): + def __delslice__(self, start, end): + for value in self[start:end]: + __del(self, value) + fn(self, start, end) + _tidy(__delslice__) + return __delslice__ + + def extend(fn): + def extend(self, iterable): + for value in iterable: + self.append(value) + _tidy(extend) + return extend + + def pop(fn): + def pop(self, index=-1): + item = fn(self, index) + __del(self, item) + return item + _tidy(pop) + return pop + + l = locals().copy() + l.pop('_tidy') + return l + +def _dict_decorators(): + """Hand-turned instrumentation wrappers that can decorate any dict-like + mapping class.""" + + def _tidy(fn): + try: + setattr(fn, '_sa_instrumented', True) + fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__') + except: + raise + + Unspecified=object() + + def __setitem__(fn): + def __setitem__(self, key, value, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator) + __set(self, value) + fn(self, key, value) + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, key, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator) + fn(self, key) + _tidy(__delitem__) + return __delitem__ + + def clear(fn): + def clear(self): + for key in self: + __del(self, self[key]) + fn(self) + _tidy(clear) + return clear + + def pop(fn): + def pop(self, key, default=Unspecified): + if key in self: + __del(self, self[key]) + if default is Unspecified: + return fn(self, key) + else: + return fn(self, key, default) + _tidy(pop) + return pop + + def popitem(fn): + def popitem(self): + item = fn(self) + __del(self, item[1]) + return item + _tidy(popitem) + return popitem + + def setdefault(fn): + def setdefault(self, key, default=None): + if key not in self and default is not None: + __set(self, default) + return fn(self, key, default) + _tidy(setdefault) + return setdefault + + if sys.version_info < (2, 4): + def update(fn): + def update(self, other): + for key in other.keys(): + self[key] = other[key] + _tidy(update) + return update + else: + def update(fn): + def update(self, other=Unspecified, **kw): + if other is not Unspecified: + if hasattr(other, 'keys'): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key in kw: + self[key] = kw[key] + _tidy(update) + return update + + l = locals().copy() + l.pop('_tidy') + l.pop('Unspecified') + return l + +def _set_decorators(): + """Hand-turned instrumentation wrappers that can decorate any set-like + sequence class.""" + + def _tidy(fn): + try: + setattr(fn, '_sa_instrumented', True) + fn.__doc__ = getattr(getattr(set, fn.__name__), '__doc__') + except: + raise + + Unspecified=object() + + def add(fn): + def add(self, value, _sa_initiator=None): + __set(self, value, _sa_initiator) + fn(self, value) + _tidy(add) + return add + + def discard(fn): + def discard(self, value, _sa_initiator=None): + if value in self: + __del(self, value, _sa_initiator) + fn(self, value) + _tidy(discard) + return discard + + def remove(fn): + def remove(self, value, _sa_initiator=None): + if value in self: + __del(self, value, _sa_initiator) + fn(self, value) + _tidy(remove) + return remove + + def pop(fn): + def pop(self): + item = fn(self) + __del(self, item) + return item + _tidy(pop) + return pop + + def update(fn): + def update(self, value): + for item in value: + if item not in self: + self.add(item) + _tidy(update) + return update + __ior__ = update + + def difference_update(fn): + def difference_update(self, value): + for item in value: + self.discard(item) + _tidy(difference_update) + return difference_update + __isub__ = difference_update + + def intersection_update(fn): + def intersection_update(self, other): + want, have = self.intersection(other), util.Set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + _tidy(intersection_update) + return intersection_update + __iand__ = intersection_update + + def symmetric_difference_update(fn): + def symmetric_difference_update(self, other): + want, have = self.symmetric_difference(other), util.Set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + _tidy(symmetric_difference_update) + return symmetric_difference_update + __ixor__ = symmetric_difference_update + + l = locals().copy() + l.pop('_tidy') + l.pop('Unspecified') + return l + + +class InstrumentedList(list): + __instrumentation__ = { + 'appender': 'append', + 'remover': 'remove', + 'iterator': '__iter__', } + +class InstrumentedSet(util.Set): + __instrumentation__ = { + 'appender': 'add', + 'remover': 'remove', + 'iterator': '__iter__', } + +class InstrumentedDict(dict): + __instrumentation__ = { + 'iterator': 'itervalues', } + +__canned_instrumentation = { + list: InstrumentedList, + util.Set: InstrumentedSet, + dict: InstrumentedDict, + } + +__interfaces = { + list: { 'appender': 'append', + 'remover': 'remove', + 'iterator': '__iter__', + '_decorators': _list_decorators(), }, + util.Set: { 'appender': 'add', + 'remover': 'remove', + 'iterator': '__iter__', + '_decorators': _set_decorators(), }, + # < 0.4 compatible naming (almost), deprecated- use decorators instead. + dict: { 'appender': 'append', + 'remover': 'remove', + 'iterator': 'itervalues', + '_decorators': _dict_decorators(), }, + # < 0.4 compatible naming, deprecated- use decorators instead. + None: { 'appender': 'append', + 'remover': 'remove', + 'iterator': 'values', } + } + + +class MappedCollection(dict): + """A basic dictionary-based collection class. + + Extends dict with the minimal bag semantics that collection classes require. + "append" and "remove" are implemented in terms of a keying function: any + callable that takes an object and returns an object for use as a dictionary + key. + """ + + def __init__(self, keyfunc): + self.keyfunc = keyfunc + + def append(self, value, _sa_initiator=None): + key = self.keyfunc(value) + self.__setitem__(key, value, _sa_initiator) + append = collection.internally_instrumented(append) + append = collection.appender(append) + + def remove(self, value, _sa_initiator=None): + key = self.keyfunc(value) + # Let self[key] raise if key is not in this collection + if self[key] != value: + raise exceptions.InvalidRequestError( + "Can not remove '%s': collection holds '%s' for key '%s'. " + "Possible cause: is the MappedCollection key function " + "based on mutable properties or properties that only obtain " + "values after flush?" % + (value, self[key], key)) + self.__delitem__(self, key, _sa_initiator) + remove = collection.internally_instrumented(remove) + remove = collection.remover(remove) diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 29695fb1f4..105a070e6d 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -475,6 +475,22 @@ class PropertyOption(MapperOption): PropertyOption.logger = logging.class_logger(PropertyOption) + +class AttributeExtension(object): + """An abstract class which specifies `append`, `delete`, and `set` + event handlers to be attached to an object property. + """ + + def append(self, obj, child, initiator): + pass + + def remove(self, obj, child, initiator): + pass + + def set(self, obj, child, oldchild, initiator): + pass + + class StrategizedOption(PropertyOption): """A MapperOption that affects which LoaderStrategy will be used for an operation by a StrategizedProperty. diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 78830b7409..c183556766 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -129,12 +129,13 @@ class PropertyLoader(StrategizedProperty): if childlist is None: return if self.uselist: - # sets a blank list according to the correct list class - dest_list = getattr(self.parent.class_, self.key).initialize(dest) + # sets a blank collection according to the correct list class + dest_list = sessionlib.attribute_manager.init_collection(dest, self.key) for current in list(childlist): obj = session.merge(current, entity_name=self.mapper.entity_name, _recursive=_recursive) if obj is not None: - dest_list.append(obj) + #dest_list.append_without_event(obj) + dest_list.append_with_event(obj) else: current = list(childlist)[0] if current is not None: diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 9d78c0e7cc..bff67efbcc 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -7,9 +7,8 @@ """sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions.""" from sqlalchemy import sql, schema, util, exceptions, sql_util, logging -from sqlalchemy.orm import mapper +from sqlalchemy.orm import mapper, attributes from sqlalchemy.orm.interfaces import * -from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil import random @@ -50,7 +49,7 @@ class ColumnLoader(LoaderStrategy): if hosted_mapper.polymorphic_fetch == 'deferred': def execute(instance, row, isnew, **flags): if isnew: - sessionlib.attribute_manager.init_instance_attribute(instance, self.key, False, callable_=self._get_deferred_loader(instance, mapper, needs_tables)) + sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self._get_deferred_loader(instance, mapper, needs_tables)) self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key)) return (execute, None) else: @@ -78,8 +77,8 @@ class ColumnLoader(LoaderStrategy): try: row = result.fetchone() for prop in group: - InstrumentedAttribute.get_instrument(instance, prop.key).set_committed_value(instance, row[prop.columns[0]]) - return InstrumentedAttribute.ATTR_WAS_SET + sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]]) + return attributes.ATTR_WAS_SET finally: result.close() @@ -102,7 +101,7 @@ class DeferredColumnLoader(LoaderStrategy): if isnew: if self._should_log_debug: self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) - sessionlib.attribute_manager.init_instance_attribute(instance, self.key, False, callable_=self.setup_loader(instance)) + sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=self.setup_loader(instance)) return (execute, None) else: def execute(instance, row, isnew, **flags): @@ -166,8 +165,8 @@ class DeferredColumnLoader(LoaderStrategy): try: row = result.fetchone() for prop in group: - InstrumentedAttribute.get_instrument(instance, prop.key).set_committed_value(instance, row[prop.columns[0]]) - return InstrumentedAttribute.ATTR_WAS_SET + sessionlib.attribute_manager.get_attribute(instance, prop.key).set_committed_value(instance, row[prop.columns[0]]) + return attributes.ATTR_WAS_SET finally: result.close() else: @@ -205,7 +204,7 @@ class AbstractRelationLoader(LoaderStrategy): self._should_log_debug = logging.is_debug_enabled(self.logger) def _init_instance_attribute(self, instance, callable_=None): - return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, self.uselist, cascade=self.cascade, trackparent=True, callable_=callable_) + return sessionlib.attribute_manager.init_instance_attribute(instance, self.key, callable_=callable_) def _register_attribute(self, class_, callable_=None): self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__)) @@ -658,9 +657,10 @@ class EagerLoader(AbstractRelationLoader): if self._should_log_debug: self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key)) if isnew: - # set a scalar object instance directly on the parent object, - # bypassing InstrumentedAttribute event handlers. - instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None) + # set a scalar object instance directly on the + # parent object, bypassing InstrumentedAttribute + # event handlers. + sessionlib.attribute_manager.get_attribute(instance, self.key).set_raw_value(instance, self.mapper._instance(selectcontext, decorated_row, None)) else: # call _instance on the row, even though the object has been created, # so that we further descend into properties @@ -670,11 +670,8 @@ class EagerLoader(AbstractRelationLoader): if self._should_log_debug: self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key)) - # call the InstrumentedAttribute's initialize() method to create a new, blank list - l = InstrumentedAttribute.get_instrument(instance, self.key).initialize(instance) - - # create an appender object which will add set-like semantics to the list - appender = util.UniqueAppender(l.data) + collection = sessionlib.attribute_manager.init_collection(instance, self.key) + appender = util.UniqueAppender(collection, 'append_without_event') # store it in the "scratch" area, which is local to this load operation. selectcontext.attributes[(instance, self.key)] = appender diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index c6b0b2689c..c9b8f01519 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -20,14 +20,14 @@ changes at once. """ from sqlalchemy import util, logging, topological -from sqlalchemy.orm import attributes +from sqlalchemy.orm import attributes, interfaces from sqlalchemy.orm import util as mapperutil from sqlalchemy.orm.mapper import object_mapper, class_mapper from sqlalchemy.exceptions import * import StringIO import weakref -class UOWEventHandler(attributes.AttributeExtension): +class UOWEventHandler(interfaces.AttributeExtension): """An event handler added to all class attributes which handles session operations. """ @@ -37,7 +37,7 @@ class UOWEventHandler(attributes.AttributeExtension): self.class_ = class_ self.cascade = cascade - def append(self, event, obj, item): + def append(self, obj, item, initiator): # process "save_update" cascade rules for when an instance is appended to the list of another instance sess = object_session(obj) if sess is not None: @@ -47,12 +47,12 @@ class UOWEventHandler(attributes.AttributeExtension): ename = prop.mapper.entity_name sess.save_or_update(item, entity_name=ename) - def delete(self, event, obj, item): + def remove(self, obj, item, initiator): # currently no cascade rules for removing an item from a list # (i.e. it stays in the Session) pass - def set(self, event, obj, newvalue, oldvalue): + def set(self, obj, newvalue, oldvalue, initiator): # process "save_update" cascade rules for when an instance is attached to another instance sess = object_session(obj) if sess is not None: @@ -62,27 +62,21 @@ class UOWEventHandler(attributes.AttributeExtension): ename = prop.mapper.entity_name sess.save_or_update(newvalue, entity_name=ename) -class UOWProperty(attributes.InstrumentedAttribute): - """Override ``InstrumentedAttribute`` to provide an extra - ``AttributeExtension`` to all managed attributes as well as the - `property` property. - """ - - def __init__(self, manager, class_, key, uselist, callable_, typecallable, cascade=None, extension=None, **kwargs): - extension = util.to_list(extension or []) - extension.insert(0, UOWEventHandler(key, class_, cascade=cascade)) - super(UOWProperty, self).__init__(manager, key, uselist, callable_, typecallable, extension=extension,**kwargs) - self.class_ = class_ - - property = property(lambda s:class_mapper(s.class_).props[s.key], doc="returns the MapperProperty object associated with this property") class UOWAttributeManager(attributes.AttributeManager): """Override ``AttributeManager`` to provide the ``UOWProperty`` instance for all ``InstrumentedAttributes``. """ - def create_prop(self, class_, key, uselist, callable_, typecallable, **kwargs): - return UOWProperty(self, class_, key, uselist, callable_, typecallable, **kwargs) + def create_prop(self, class_, key, uselist, callable_, typecallable, + cascade=None, extension=None, **kwargs): + extension = util.to_list(extension or []) + extension.insert(0, UOWEventHandler(key, class_, cascade=cascade)) + + return super(UOWAttributeManager, self).create_prop( + class_, key, uselist, callable_, typecallable, + extension=extension, **kwargs) + class UnitOfWork(object): """Main UOW object which stores lists of dirty/new/deleted objects. diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index a0088f1366..ee51c076e5 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -135,19 +135,25 @@ def coerce_kw_type(kw, key, type_, flexi_bool=True): else: kw[key] = type_(kw[key]) -def duck_type_collection(col, default=None): +def duck_type_collection(specimen, default=None): """Given an instance or class, guess if it is or is acting as one of the basic collection types: list, set and dict. If the __emulates__ property is present, return that preferentially. """ - if hasattr(col, '__emulates__'): - return getattr(col, '__emulates__') - elif hasattr(col, 'append'): + if hasattr(specimen, '__emulates__'): + return specimen.__emulates__ + + isa = isinstance(specimen, type) and issubclass or isinstance + if isa(specimen, list): return list + if isa(specimen, Set): return Set + if isa(specimen, dict): return dict + + if hasattr(specimen, 'append'): return list - elif hasattr(col, 'add'): + elif hasattr(specimen, 'add'): return Set - elif hasattr(col, 'set'): + elif hasattr(specimen, 'set'): return dict else: return default @@ -445,10 +451,12 @@ class UniqueAppender(object): """appends items to a collection such that only unique items are added.""" - def __init__(self, data): + def __init__(self, data, via=None): self.data = data self._unique = Set() - if hasattr(data, 'append'): + if via: + self._data_appender = getattr(data, via) + elif hasattr(data, 'append'): self._data_appender = data.append elif hasattr(data, 'add'): # TODO: we think its a set here. bypass unneeded uniquing logic ? diff --git a/test/ext/associationproxy.py b/test/ext/associationproxy.py index 0d6329fd60..68ad5da6ed 100644 --- a/test/ext/associationproxy.py +++ b/test/ext/associationproxy.py @@ -4,16 +4,19 @@ import unittest import testbase from sqlalchemy import * from sqlalchemy.orm import * +from sqlalchemy.orm.collections import collection from sqlalchemy.ext.associationproxy import * from testbase import Table, Column db = testbase.db class DictCollection(dict): + @collection.appender def append(self, obj): self[obj.foo] = obj - def __iter__(self): - return self.itervalues() + @collection.remover + def remove(self, obj): + del self[obj.foo] class SetCollection(set): pass @@ -24,12 +27,14 @@ class ListCollection(list): class ObjectCollection(object): def __init__(self): self.values = list() + @collection.appender def append(self, obj): self.values.append(obj) + @collection.remover + def remove(self, obj): + self.values.remove(obj) def __iter__(self): return iter(self.values) - def clear(self): - self.values.clear() class _CollectionOperations(PersistTest): def setUp(self): @@ -237,6 +242,17 @@ class CustomDictTest(DictTest): p1._children = {} self.assert_(len(p1.children) == 0) + try: + p1._children = [] + self.assert_(False) + except exceptions.ArgumentError: + self.assert_(True) + + try: + p1._children = None + self.assert_(False) + except exceptions.ArgumentError: + self.assert_(True) class SetTest(_CollectionOperations): def __init__(self, *args, **kw): @@ -252,7 +268,7 @@ class SetTest(_CollectionOperations): self.assert_(not p1.children) ch1 = Child('regular') - p1._children.append(ch1) + p1._children.add(ch1) self.assert_(ch1 in p1._children) self.assert_(len(p1._children) == 1) @@ -335,9 +351,22 @@ class SetTest(_CollectionOperations): p1 = self.roundtrip(p1) self.assert_(p1.children == set(['c'])) - p1._children = [] + p1._children = set() self.assert_(len(p1.children) == 0) + try: + p1._children = [] + self.assert_(False) + except exceptions.ArgumentError: + self.assert_(True) + + try: + p1._children = None + self.assert_(False) + except exceptions.ArgumentError: + self.assert_(True) + + def test_set_comparisons(self): Parent, Child = self.Parent, self.Child @@ -619,7 +648,7 @@ class LazyLoadTest(PersistTest): # Is there a better way to ensure that the association_proxy # didn't convert a lazy load to an eager load? This does work though. self.assert_('_children' not in p.__dict__) - self.assert_(len(p._children.data) == 3) + self.assert_(len(p._children) == 3) self.assert_('_children' in p.__dict__) def test_eager_list(self): @@ -635,7 +664,7 @@ class LazyLoadTest(PersistTest): p = self.roundtrip(p) self.assert_('_children' in p.__dict__) - self.assert_(len(p._children.data) == 3) + self.assert_(len(p._children) == 3) def test_lazy_scalar(self): Parent, Child = self.Parent, self.Child diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 7e0a22aff2..7aeddeff8f 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -1,6 +1,7 @@ from testbase import PersistTest import sqlalchemy.util as util import sqlalchemy.orm.attributes as attributes +from sqlalchemy.orm.collections import collection from sqlalchemy import exceptions import unittest, sys, os import pickle @@ -110,13 +111,12 @@ class AttributesTest(PersistTest): s = Student() c = Course() s.courses.append(c) - print c.students - print [s] self.assert_(c.students == [s]) s.courses.remove(c) self.assert_(c.students == []) (s1, s2, s3) = (Student(), Student(), Student()) + c.students = [s1, s2, s3] self.assert_(s2.courses == [c]) self.assert_(s1.courses == [c]) @@ -126,9 +126,7 @@ class AttributesTest(PersistTest): print c print c.students s1.courses.remove(c) - self.assert_(c.students == [s2,s3]) - - + self.assert_(c.students == [s2,s3]) class Post(object):pass class Blog(object):pass @@ -334,44 +332,47 @@ class AttributesTest(PersistTest): manager = attributes.AttributeManager() class Foo(object):pass manager.register_attribute(Foo, "collection", uselist=True, typecallable=set) - assert isinstance(Foo().collection.data, set) + assert isinstance(Foo().collection, set) - manager.register_attribute(Foo, "collection", uselist=True, typecallable=dict) try: - Foo().collection + manager.register_attribute(Foo, "collection", uselist=True, typecallable=dict) assert False except exceptions.ArgumentError, e: - assert str(e) == "Dictionary collection class 'dict' must implement an append() method" - + assert str(e) == "Type InstrumentedDict must elect an appender method to be a collection class" + class MyDict(dict): + @collection.appender def append(self, item): self[item.foo] = item + @collection.remover + def remove(self, item): + del self[item.foo] manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyDict) - assert isinstance(Foo().collection.data, MyDict) + assert isinstance(Foo().collection, MyDict) class MyColl(object):pass - manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl) try: - Foo().collection + manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl) assert False except exceptions.ArgumentError, e: - assert str(e) == "Collection class 'MyColl' is not of type 'list', 'set', or 'dict' and has no append() or add() method" + assert str(e) == "Type MyColl must elect an appender method to be a collection class" class MyColl(object): + @collection.iterator def __iter__(self): return iter([]) + @collection.appender def append(self, item): pass + @collection.remover + def remove(self, item): + pass manager.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl) try: Foo().collection - assert False + assert True except exceptions.ArgumentError, e: - assert str(e) == "Collection class 'MyColl' is not of type 'list', 'set', or 'dict' and has no clear() method" - - def foo(self):pass - MyColl.clear = foo - assert isinstance(Foo().collection.data, MyColl) + assert False if __name__ == "__main__": testbase.main() diff --git a/test/orm/relationships.py b/test/orm/relationships.py index e1197e8f74..3da1097946 100644 --- a/test/orm/relationships.py +++ b/test/orm/relationships.py @@ -2,6 +2,8 @@ import testbase import unittest, sys, datetime from sqlalchemy import * from sqlalchemy.orm import * +from sqlalchemy.orm import collections +from sqlalchemy.orm.collections import collection from testbase import Table, Column db = testbase.db @@ -745,7 +747,7 @@ class CustomCollectionsTest(testbase.ORMTest): }) mapper(Bar, someothertable) f = Foo() - assert isinstance(f.bars.data, MyList) + assert isinstance(f.bars, MyList) def testlazyload(self): """test that a 'set' can be used as a collection and can lazyload.""" class Foo(object): @@ -769,6 +771,7 @@ class CustomCollectionsTest(testbase.ORMTest): def testdict(self): """test that a 'dict' can be used as a collection and can lazyload.""" + class Foo(object): pass class Bar(object): @@ -776,8 +779,11 @@ class CustomCollectionsTest(testbase.ORMTest): class AppenderDict(dict): def append(self, item): self[id(item)] = item + def remove(self, item): + if id(item) in self: + del self[id(item)] def __iter__(self): - return iter(self.values()) + return dict.__iter__(self) mapper(Foo, sometable, properties={ 'bars':relation(Bar, collection_class=AppenderDict) @@ -794,6 +800,44 @@ class CustomCollectionsTest(testbase.ORMTest): assert len(list(f.bars)) == 2 f.bars.clear() + def testdictwrapper(self): + """test that the supplied 'dict' wrapper can be used as a collection and can lazyload.""" + + class Foo(object): + pass + class Bar(object): + def __init__(self, data): self.data = data + + mapper(Foo, sometable, properties={ + 'bars':relation(Bar, + collection_class=collections.column_mapped_collection(someothertable.c.data)) + }) + mapper(Bar, someothertable) + + f = Foo() + col = collections.collection_adapter(f.bars) + col.append_with_event(Bar('a')) + col.append_with_event(Bar('b')) + sess = create_session() + sess.save(f) + sess.flush() + sess.clear() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + + existing = set([id(b) for b in f.bars.values()]) + + col = collections.collection_adapter(f.bars) + col.append_with_event(Bar('b')) + f.bars['a'] = Bar('a') + sess.flush() + sess.clear() + f = sess.query(Foo).get(f.col1) + assert len(list(f.bars)) == 2 + + replaced = set([id(b) for b in f.bars.values()]) + self.assert_(existing != replaced) + def testlist(self): class Parent(object): pass @@ -811,13 +855,13 @@ class CustomCollectionsTest(testbase.ORMTest): o = Child() control.append(o) p.children.append(o) - assert control == p.children.data + assert control == p.children assert control == list(p.children) o = [Child(), Child(), Child(), Child()] control.extend(o) p.children.extend(o) - assert control == p.children.data + assert control == p.children assert control == list(p.children) assert control[0] == p.children[0] @@ -826,92 +870,92 @@ class CustomCollectionsTest(testbase.ORMTest): del control[1] del p.children[1] - assert control == p.children.data + assert control == p.children assert control == list(p.children) o = [Child()] control[1:3] = o p.children[1:3] = o - assert control == p.children.data + assert control == p.children assert control == list(p.children) o = [Child(), Child(), Child(), Child()] control[1:3] = o p.children[1:3] = o - assert control == p.children.data + assert control == p.children assert control == list(p.children) o = [Child(), Child(), Child(), Child()] control[-1:-2] = o p.children[-1:-2] = o - assert control == p.children.data + assert control == p.children assert control == list(p.children) o = [Child(), Child(), Child(), Child()] control[4:] = o p.children[4:] = o - assert control == p.children.data + assert control == p.children assert control == list(p.children) o = Child() control.insert(0, o) p.children.insert(0, o) - assert control == p.children.data + assert control == p.children assert control == list(p.children) o = Child() control.insert(3, o) p.children.insert(3, o) - assert control == p.children.data + assert control == p.children assert control == list(p.children) o = Child() control.insert(999, o) p.children.insert(999, o) - assert control == p.children.data + assert control == p.children assert control == list(p.children) del control[0:1] del p.children[0:1] - assert control == p.children.data + assert control == p.children assert control == list(p.children) del control[1:1] del p.children[1:1] - assert control == p.children.data + assert control == p.children assert control == list(p.children) del control[1:3] del p.children[1:3] - assert control == p.children.data + assert control == p.children assert control == list(p.children) del control[7:] del p.children[7:] - assert control == p.children.data + assert control == p.children assert control == list(p.children) assert control.pop() == p.children.pop() - assert control == p.children.data + assert control == p.children assert control == list(p.children) assert control.pop(0) == p.children.pop(0) - assert control == p.children.data + assert control == p.children assert control == list(p.children) assert control.pop(2) == p.children.pop(2) - assert control == p.children.data + assert control == p.children assert control == list(p.children) o = Child() control.insert(2, o) p.children.insert(2, o) - assert control == p.children.data + assert control == p.children assert control == list(p.children) control.remove(o) p.children.remove(o) - assert control == p.children.data + assert control == p.children assert control == list(p.children) def testobj(self): @@ -922,9 +966,12 @@ class CustomCollectionsTest(testbase.ORMTest): class MyCollection(object): def __init__(self): self.data = [] + @collection.appender def append(self, value): self.data.append(value) + @collection.remover + def remove(self, value): self.data.remove(value) + @collection.iterator def __iter__(self): return iter(self.data) - def clear(self): self.data.clear() mapper(Parent, sometable, properties={ 'children':relation(Child, collection_class=MyCollection) diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 11a540184d..00b609206d 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -611,7 +611,6 @@ class OneToManyTest(UnitOfWorkTest): a2.email_address = 'lala@test.org' u.addresses.append(a2) self.echo( repr(u.addresses)) - self.echo( repr(u.addresses.added_items())) ctx.current.flush() usertable = users.select(users.c.user_id.in_(u.user_id)).execute().fetchall()