From 2be867ffac8881a4a20ca5387063ed207ac876dc Mon Sep 17 00:00:00 2001
From: Mike Bayer
Date: Sun, 17 May 2009 18:17:46 +0000
Subject: [PATCH] - Significant performance enhancements regarding
 Sessions/flush() in conjunction with large mapper graphs, large numbers of
 objects:

- The Session's "weak referencing" behavior is now *full* - no strong
  references whatsoever are made to a mapped object or related
  items/collections in its __dict__.  Backrefs and other cycles in objects
  no longer affect the Session's ability to lose all references to
  unmodified objects.  Objects with pending changes are still maintained
  strongly until flush. [ticket:1398]

  The implementation also improves performance by moving the "resurrection"
  process of garbage collected items to only be relevant for mappings that
  map "mutable" attributes (e.g. PickleType, composite attrs).  This removes
  overhead from the gc process and simplifies internal behavior.

  If a "mutable" attribute change is the sole change on an object which is
  then dereferenced, the mapper will not have access to other attribute
  state when the UPDATE is issued.  This may present itself differently to
  some MapperExtensions.

  The change also affects the internal attribute API, but not the
  AttributeExtension interface nor any of the publicly documented attribute
  functions.

- The unit of work no longer generates a graph of "dependency" processors
  for the full graph of mappers during flush(), instead creating such
  processors only for those mappers which represent objects with pending
  changes.  This saves a tremendous number of method calls in the context
  of a large interconnected graph of mappers.

- Cached a wasteful "table sort" operation that previously occurred
  multiple times per flush, also removing a significant number of method
  calls from flush().

- Other redundant behaviors have been simplified in mapper._save_obj().
---
 CHANGES                           |  42 ++
 lib/sqlalchemy/orm/attributes.py  | 672 +++++++++---------
 lib/sqlalchemy/orm/collections.py |   7 +-
 lib/sqlalchemy/orm/dependency.py  |  39 +-
 lib/sqlalchemy/orm/dynamic.py     |  46 +-
 lib/sqlalchemy/orm/identity.py    |  89 ++--
 lib/sqlalchemy/orm/interfaces.py  |  24 +-
 lib/sqlalchemy/orm/mapper.py      | 144 +++++--
 lib/sqlalchemy/orm/properties.py  |  18 +-
 lib/sqlalchemy/orm/query.py       |  16 +-
 lib/sqlalchemy/orm/session.py     |  33 +-
 lib/sqlalchemy/orm/state.py       | 429 +++++++++++++++++++
 lib/sqlalchemy/orm/strategies.py  |  34 +-
 lib/sqlalchemy/orm/unitofwork.py  |  20 +-
 test/orm/attributes.py            |  76 ++--
 test/orm/extendedattr.py          |  10 +-
 test/orm/instrumentation.py       |  25 +-
 test/orm/mapper.py                |  10 +-
 test/orm/query.py                 |   2 +-
 test/orm/session.py               |  61 ++-
 test/orm/unitofwork.py            |  23 +
 test/profiling/zoomark_orm.py     |   6 +-
 22 files changed, 1074 insertions(+), 752 deletions(-)
 create mode 100644 lib/sqlalchemy/orm/state.py

diff --git a/CHANGES b/CHANGES
index 6161ae6de2..5130e886fe 100644
--- a/CHANGES
+++ b/CHANGES
@@ -8,6 +8,48 @@ CHANGES
 =====
 
 - orm
+    - Significant performance enhancements regarding Sessions/flush()
+      in conjunction with large mapper graphs, large numbers of
+      objects:
+
+        - The Session's "weak referencing" behavior is now *full* -
+          no strong references whatsoever are made to a mapped object
+          or related items/collections in its __dict__.  Backrefs and
+          other cycles in objects no longer affect the Session's ability
+          to lose all references to unmodified objects.  Objects with
+          pending changes are still maintained strongly until flush.
+          [ticket:1398]
+
+          The implementation also improves performance by moving
+          the "resurrection" process of garbage collected items
+          to only be relevant for mappings that map "mutable"
+          attributes (e.g. PickleType, composite attrs).  This removes
+          overhead from the gc process and simplifies internal
+          behavior.
+
+          If a "mutable" attribute change is the sole change on an object
+          which is then dereferenced, the mapper will not have access to
+          other attribute state when the UPDATE is issued.  This may present
+          itself differently to some MapperExtensions.
+
+          The change also affects the internal attribute API, but not
+          the AttributeExtension interface nor any of the publicly
+          documented attribute functions.
+
+        - The unit of work no longer generates a graph of "dependency"
+          processors for the full graph of mappers during flush(), instead
+          creating such processors only for those mappers which represent
+          objects with pending changes.  This saves a tremendous number
+          of method calls in the context of a large interconnected
+          graph of mappers.
+
+        - Cached a wasteful "table sort" operation that previously
+          occurred multiple times per flush, also removing a significant
+          number of method calls from flush().
+
+        - Other redundant behaviors have been simplified in
+          mapper._save_obj().
+
     - Modified query_cls on DynamicAttributeImpl to accept a full
       mixin version of the AppenderQuery, which allows subclassing
       the AppenderMixin.
diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py
index 68aa0d93ae..4fa41ff3b5 100644
--- a/lib/sqlalchemy/orm/attributes.py
+++ b/lib/sqlalchemy/orm/attributes.py
@@ -20,14 +20,13 @@ import types
 import weakref
 
 from sqlalchemy import util
-from sqlalchemy.util import EMPTY_SET
 from sqlalchemy.orm import interfaces, collections, exc
 import sqlalchemy.exceptions as sa_exc
 
 # lazy imports
 _entity_info = None
 identity_equal = None
-
+state = None
 
 PASSIVE_NORESULT = util.symbol('PASSIVE_NORESULT')
 ATTR_WAS_SET = util.symbol('ATTR_WAS_SET')
@@ -105,7 +104,7 @@ class QueryableAttribute(interfaces.PropComparator):
         self.parententity = parententity
 
     def get_history(self, instance, **kwargs):
-        return self.impl.get_history(instance_state(instance), **kwargs)
+        return self.impl.get_history(instance_state(instance), instance_dict(instance), **kwargs)
 
     def __selectable__(self):
         # TODO: conditionally attach this method based on clause_element ?
@@ -148,15 +147,15 @@ class InstrumentedAttribute(QueryableAttribute): """Public-facing descriptor, placed in the mapped class dictionary.""" def __set__(self, instance, value): - self.impl.set(instance_state(instance), value, None) + self.impl.set(instance_state(instance), instance_dict(instance), value, None) def __delete__(self, instance): - self.impl.delete(instance_state(instance)) + self.impl.delete(instance_state(instance), instance_dict(instance)) def __get__(self, instance, owner): if instance is None: return self - return self.impl.get(instance_state(instance)) + return self.impl.get(instance_state(instance), instance_dict(instance)) class _ProxyImpl(object): accepts_scalar_loader = False @@ -335,7 +334,7 @@ class AttributeImpl(object): else: state.callables[self.key] = callable_ - def get_history(self, state, passive=PASSIVE_OFF): + def get_history(self, state, dict_, passive=PASSIVE_OFF): raise NotImplementedError() def _get_callable(self, state): @@ -346,13 +345,13 @@ class AttributeImpl(object): else: return None - def initialize(self, state): + def initialize(self, state, dict_): """Initialize this attribute on the given object instance with an empty value.""" - state.dict[self.key] = None + dict_[self.key] = None return None - def get(self, state, passive=PASSIVE_OFF): + def get(self, state, dict_, passive=PASSIVE_OFF): """Retrieve a value from the given object. If a callable is assembled on this object's attribute, and @@ -361,7 +360,7 @@ class AttributeImpl(object): """ try: - return state.dict[self.key] + return dict_[self.key] except KeyError: # if no history, check for lazy callables, etc. if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET: @@ -374,25 +373,25 @@ class AttributeImpl(object): return PASSIVE_NORESULT value = callable_() if value is not ATTR_WAS_SET: - return self.set_committed_value(state, value) + return self.set_committed_value(state, dict_, value) else: - if self.key not in state.dict: + if self.key not in dict_: return self.get(state, passive=passive) - return state.dict[self.key] + return dict_[self.key] # Return a new, empty value - return self.initialize(state) + return self.initialize(state, dict_) - def append(self, state, value, initiator, passive=PASSIVE_OFF): - self.set(state, value, initiator) + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, value, initiator) - def remove(self, state, value, initiator, passive=PASSIVE_OFF): - self.set(state, None, initiator) + def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): + self.set(state, dict_, None, initiator) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): raise NotImplementedError() - def get_committed_value(self, state, passive=PASSIVE_OFF): + def get_committed_value(self, state, dict_, passive=PASSIVE_OFF): """return the unchanged value of this attribute""" if self.key in state.committed_state: @@ -401,12 +400,12 @@ class AttributeImpl(object): else: return state.committed_state.get(self.key) else: - return self.get(state, passive=passive) + return self.get(state, dict_, passive=passive) - def set_committed_value(self, state, value): + def set_committed_value(self, state, dict_, value): """set an attribute value on the given instance and 'commit' it.""" - state.commit([self.key]) + state.commit(dict_, [self.key]) state.callables.pop(self.key, None) state.dict[self.key] = value @@ -419,45 +418,45 @@ class ScalarAttributeImpl(AttributeImpl): accepts_scalar_loader = True 
uses_objects = False - def delete(self, state): + def delete(self, state, dict_): # TODO: catch key errors, convert to attributeerror? if self.active_history or self.extensions: - old = self.get(state) + old = self.get(state, dict_) else: - old = state.dict.get(self.key, NO_VALUE) + old = dict_.get(self.key, NO_VALUE) - state.modified_event(self, False, old) + state.modified_event(dict_, self, False, old) if self.extensions: - self.fire_remove_event(state, old, None) - del state.dict[self.key] + self.fire_remove_event(state, dict_, old, None) + del dict_[self.key] - def get_history(self, state, passive=PASSIVE_OFF): + def get_history(self, state, dict_, passive=PASSIVE_OFF): return History.from_attribute( - self, state, state.dict.get(self.key, NO_VALUE)) + self, state, dict_.get(self.key, NO_VALUE)) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): if initiator is self: return if self.active_history or self.extensions: - old = self.get(state) + old = self.get(state, dict_) else: - old = state.dict.get(self.key, NO_VALUE) + old = dict_.get(self.key, NO_VALUE) - state.modified_event(self, False, old) + state.modified_event(dict_, self, False, old) if self.extensions: - value = self.fire_replace_event(state, value, old, initiator) - state.dict[self.key] = value + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value - def fire_replace_event(self, state, value, previous, initiator): + def fire_replace_event(self, state, dict_, value, previous, initiator): for ext in self.extensions: value = ext.set(state, value, previous, initiator or self) return value - def fire_remove_event(self, state, value, initiator): + def fire_remove_event(self, state, dict_, value, initiator): for ext in self.extensions: ext.remove(state, value, initiator or self) @@ -483,29 +482,48 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl): raise sa_exc.ArgumentError("MutableScalarAttributeImpl requires a copy function") self.copy = copy_function - def get_history(self, state, passive=PASSIVE_OFF): + def get_history(self, state, dict_, passive=PASSIVE_OFF): + if not dict_: + v = state.committed_state.get(self.key, NO_VALUE) + else: + v = dict_.get(self.key, NO_VALUE) + return History.from_attribute( - self, state, state.dict.get(self.key, NO_VALUE)) + self, state, v) - def commit_to_state(self, state, dest): - dest[self.key] = self.copy(state.dict[self.key]) + def commit_to_state(self, state, dict_, dest): + dest[self.key] = self.copy(dict_[self.key]) - def check_mutable_modified(self, state): - (added, unchanged, deleted) = self.get_history(state, passive=PASSIVE_NO_INITIALIZE) + def check_mutable_modified(self, state, dict_): + (added, unchanged, deleted) = self.get_history(state, dict_, passive=PASSIVE_NO_INITIALIZE) return bool(added or deleted) - def set(self, state, value, initiator): + def get(self, state, dict_, passive=PASSIVE_OFF): + if self.key not in state.mutable_dict: + ret = ScalarAttributeImpl.get(self, state, dict_, passive=passive) + if ret is not PASSIVE_NORESULT: + state.mutable_dict[self.key] = ret + return ret + else: + return state.mutable_dict[self.key] + + def delete(self, state, dict_): + ScalarAttributeImpl.delete(self, state, dict_) + state.mutable_dict.pop(self.key) + + def set(self, state, dict_, value, initiator): if initiator is self: return - state.modified_event(self, True, NEVER_SET) - + state.modified_event(dict_, self, True, NEVER_SET) + if self.extensions: - old = self.get(state) - value = 
self.fire_replace_event(state, value, old, initiator) - state.dict[self.key] = value + old = self.get(state, dict_) + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value else: - state.dict[self.key] = value + dict_[self.key] = value + state.mutable_dict[self.key] = value class ScalarObjectAttributeImpl(ScalarAttributeImpl): @@ -526,22 +544,22 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): if compare_function is None: self.is_equal = identity_equal - def delete(self, state): - old = self.get(state) - self.fire_remove_event(state, old, self) - del state.dict[self.key] + def delete(self, state, dict_): + old = self.get(state, dict_) + self.fire_remove_event(state, dict_, old, self) + del dict_[self.key] - def get_history(self, state, passive=PASSIVE_OFF): - if self.key in state.dict: - return History.from_attribute(self, state, state.dict[self.key]) + def get_history(self, state, dict_, passive=PASSIVE_OFF): + if self.key in dict_: + return History.from_attribute(self, state, dict_[self.key]) else: - current = self.get(state, passive=passive) + current = self.get(state, dict_, passive=passive) if current is PASSIVE_NORESULT: return HISTORY_BLANK else: return History.from_attribute(self, state, current) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): """Set a value on the given InstanceState. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -553,12 +571,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): return # may want to add options to allow the get() here to be passive - old = self.get(state) - value = self.fire_replace_event(state, value, old, initiator) - state.dict[self.key] = value + old = self.get(state, dict_) + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value - def fire_remove_event(self, state, value, initiator): - state.modified_event(self, False, value) + def fire_remove_event(self, state, dict_, value, initiator): + state.modified_event(dict_, self, False, value) if self.trackparent and value is not None: self.sethasparent(instance_state(value), False) @@ -566,8 +584,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): for ext in self.extensions: ext.remove(state, value, initiator or self) - def fire_replace_event(self, state, value, previous, initiator): - state.modified_event(self, False, previous) + def fire_replace_event(self, state, dict_, value, previous, initiator): + state.modified_event(dict_, self, False, previous) if self.trackparent: if previous is not value and previous is not None: @@ -615,15 +633,15 @@ class CollectionAttributeImpl(AttributeImpl): def __copy(self, item): return [y for y in list(collections.collection_adapter(item))] - def get_history(self, state, passive=PASSIVE_OFF): - current = self.get(state, passive=passive) + def get_history(self, state, dict_, passive=PASSIVE_OFF): + current = self.get(state, dict_, passive=passive) if current is PASSIVE_NORESULT: return HISTORY_BLANK else: return History.from_attribute(self, state, current) - def fire_append_event(self, state, value, initiator): - state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + def fire_append_event(self, state, dict_, value, initiator): + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) for ext in self.extensions: value = ext.append(state, value, initiator or self) @@ -633,11 +651,11 @@ class CollectionAttributeImpl(AttributeImpl): return value - def 
fire_pre_remove_event(self, state, initiator): - state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + def fire_pre_remove_event(self, state, dict_, initiator): + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) - def fire_remove_event(self, state, value, initiator): - state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) + def fire_remove_event(self, state, dict_, value, initiator): + state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE) if self.trackparent and value is not None: self.sethasparent(instance_state(value), False) @@ -645,51 +663,51 @@ class CollectionAttributeImpl(AttributeImpl): for ext in self.extensions: ext.remove(state, value, initiator or self) - def delete(self, state): - if self.key not in state.dict: + def delete(self, state, dict_): + if self.key not in dict_: return - state.modified_event(self, True, NEVER_SET) + state.modified_event(dict_, self, True, NEVER_SET) - collection = self.get_collection(state) + collection = self.get_collection(state, state.dict) collection.clear_with_event() # TODO: catch key errors, convert to attributeerror? - del state.dict[self.key] + del dict_[self.key] - def initialize(self, state): + def initialize(self, state, dict_): """Initialize this attribute with an empty collection.""" _, user_data = self._initialize_collection(state) - state.dict[self.key] = user_data + dict_[self.key] = user_data return user_data def _initialize_collection(self, state): return state.manager.initialize_collection( self.key, state, self.collection_factory) - def append(self, state, value, initiator, passive=PASSIVE_OFF): + def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF): if initiator is self: return - collection = self.get_collection(state, passive=passive) + collection = self.get_collection(state, dict_, passive=passive) if collection is PASSIVE_NORESULT: - value = self.fire_append_event(state, value, initiator) + value = self.fire_append_event(state, dict_, value, initiator) state.get_pending(self.key).append(value) else: collection.append_with_event(value, initiator) - def remove(self, state, value, initiator, passive=PASSIVE_OFF): + def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF): if initiator is self: return - collection = self.get_collection(state, passive=passive) + collection = self.get_collection(state, state.dict, passive=passive) if collection is PASSIVE_NORESULT: - self.fire_remove_event(state, value, initiator) + self.fire_remove_event(state, dict_, value, initiator) state.get_pending(self.key).remove(value) else: collection.remove_with_event(value, initiator) - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): """Set a value on the given object. `initiator` is the ``InstrumentedAttribute`` that initiated the @@ -701,10 +719,10 @@ class CollectionAttributeImpl(AttributeImpl): return self._set_iterable( - state, value, + state, dict_, value, lambda adapter, i: adapter.adapt_like_to_iterable(i)) - def _set_iterable(self, state, iterable, adapter=None): + def _set_iterable(self, state, dict_, iterable, adapter=None): """Set a collection value from an iterable of state-bearers. 
``adapter`` is an optional callable invoked with a CollectionAdapter @@ -722,24 +740,24 @@ class CollectionAttributeImpl(AttributeImpl): else: new_values = list(iterable) - old = self.get(state) + old = self.get(state, dict_) # ignore re-assignment of the current collection, as happens # implicitly with in-place operators (foo.collection |= other) if old is iterable: return - state.modified_event(self, True, old) + state.modified_event(dict_, self, True, old) - old_collection = self.get_collection(state, old) + old_collection = self.get_collection(state, dict_, old) - state.dict[self.key] = user_data + dict_[self.key] = user_data collections.bulk_replace(new_values, old_collection, new_collection) old_collection.unlink(old) - def set_committed_value(self, state, value): + def set_committed_value(self, state, dict_, value): """Set an attribute value on the given instance and 'commit' it.""" collection, user_data = self._initialize_collection(state) @@ -751,13 +769,13 @@ class CollectionAttributeImpl(AttributeImpl): state.callables.pop(self.key, None) state.dict[self.key] = user_data - state.commit([self.key]) + state.commit(dict_, [self.key]) if self.key in state.pending: # pending items exist. issue a modified event, # add/remove new items. - state.modified_event(self, True, user_data) + state.modified_event(dict_, self, True, user_data) pending = state.pending.pop(self.key) added = pending.added_items @@ -769,14 +787,14 @@ class CollectionAttributeImpl(AttributeImpl): return user_data - def get_collection(self, state, user_data=None, passive=PASSIVE_OFF): + def get_collection(self, state, dict_, user_data=None, passive=PASSIVE_OFF): """Retrieve the CollectionAdapter associated with the given state. Creates a new CollectionAdapter if one does not exist. """ if user_data is None: - user_data = self.get(state, passive=passive) + user_data = self.get(state, dict_, passive=passive) if user_data is PASSIVE_NORESULT: return user_data @@ -799,320 +817,26 @@ class GenericBackrefExtension(interfaces.AttributeExtension): if oldchild is not None: # With lazy=None, there's no guarantee that the full collection is # present when updating via a backref. 
- old_state = instance_state(oldchild) + old_state, old_dict = instance_state(oldchild), instance_dict(oldchild) impl = old_state.get_impl(self.key) try: - impl.remove(old_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) + impl.remove(old_state, old_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) except (ValueError, KeyError, IndexError): pass if child is not None: - new_state = instance_state(child) - new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) + new_state, new_dict = instance_state(child), instance_dict(child) + new_state.get_impl(self.key).append(new_state, new_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) return child def append(self, state, child, initiator): - child_state = instance_state(child) - child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) + child_state, child_dict = instance_state(child), instance_dict(child) + child_state.get_impl(self.key).append(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) return child def remove(self, state, child, initiator): if child is not None: - child_state = instance_state(child) - child_state.get_impl(self.key).remove(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) - - -class InstanceState(object): - """tracks state information at the instance level.""" - - session_id = None - key = None - runid = None - expired_attributes = EMPTY_SET - load_options = EMPTY_SET - load_path = () - insert_order = None - - def __init__(self, obj, manager): - self.class_ = obj.__class__ - self.manager = manager - self.obj = weakref.ref(obj, self._cleanup) - self.dict = obj.__dict__ - self.modified = False - self.callables = {} - self.expired = False - self.committed_state = {} - self.pending = {} - self.parents = {} - - def detach(self): - if self.session_id: - del self.session_id - - def dispose(self): - if self.session_id: - del self.session_id - del self.obj - del self.dict - - def _cleanup(self, ref): - self.dispose() - - def obj(self): - return None - - @util.memoized_property - def dict(self): - # return a blank dict - # if none is available, so that asynchronous gc - # doesn't blow up expiration operations in progress - # (usually expire_attributes) - return {} - - @property - def sort_key(self): - return self.key and self.key[1] or (self.insert_order, ) - - def check_modified(self): - if self.modified: - return True - else: - for key in self.manager.mutable_attributes: - if self.manager[key].impl.check_mutable_modified(self): - return True - else: - return False - - def initialize_instance(*mixed, **kwargs): - self, instance, args = mixed[0], mixed[1], mixed[2:] - manager = self.manager - - for fn in manager.events.on_init: - fn(self, instance, args, kwargs) - try: - return manager.events.original_init(*mixed[1:], **kwargs) - except: - for fn in manager.events.on_init_failure: - fn(self, instance, args, kwargs) - raise - - def get_history(self, key, **kwargs): - return self.manager.get_impl(key).get_history(self, **kwargs) - - def get_impl(self, key): - return self.manager.get_impl(key) - - def get_pending(self, key): - if key not in self.pending: - self.pending[key] = PendingCollection() - return self.pending[key] - - def value_as_iterable(self, key, passive=PASSIVE_OFF): - """return an InstanceState attribute as a list, - regardless of it being a scalar or collection-based - attribute. 
- - returns None if passive is not PASSIVE_OFF and the getter returns - PASSIVE_NORESULT. - """ - - impl = self.get_impl(key) - x = impl.get(self, passive=passive) - if x is PASSIVE_NORESULT: - - return None - elif hasattr(impl, 'get_collection'): - return impl.get_collection(self, x, passive=passive) - elif isinstance(x, list): - return x - else: - return [x] - - def _run_on_load(self, instance=None): - if instance is None: - instance = self.obj() - self.manager.events.run('on_load', instance) - - def __getstate__(self): - return {'key': self.key, - 'committed_state': self.committed_state, - 'pending': self.pending, - 'parents': self.parents, - 'modified': self.modified, - 'expired':self.expired, - 'load_options':self.load_options, - 'load_path':interfaces.serialize_path(self.load_path), - 'instance': self.obj(), - 'expired_attributes':self.expired_attributes, - 'callables': self.callables} - - def __setstate__(self, state): - self.committed_state = state['committed_state'] - self.parents = state['parents'] - self.key = state['key'] - self.session_id = None - self.pending = state['pending'] - self.modified = state['modified'] - self.obj = weakref.ref(state['instance']) - self.load_options = state['load_options'] or EMPTY_SET - self.load_path = interfaces.deserialize_path(state['load_path']) - self.class_ = self.obj().__class__ - self.manager = manager_of_class(self.class_) - self.dict = self.obj().__dict__ - self.callables = state['callables'] - self.runid = None - self.expired = state['expired'] - self.expired_attributes = state['expired_attributes'] - - def initialize(self, key): - self.manager.get_impl(key).initialize(self) - - def set_callable(self, key, callable_): - self.dict.pop(key, None) - self.callables[key] = callable_ - - def __call__(self): - """__call__ allows the InstanceState to act as a deferred - callable for loading expired attributes, which is also - serializable (picklable). - - """ - unmodified = self.unmodified - class_manager = self.manager - class_manager.deferred_scalar_loader(self, [ - attr.impl.key for attr in class_manager.attributes if - attr.impl.accepts_scalar_loader and - attr.impl.key in self.expired_attributes and - attr.impl.key in unmodified - ]) - for k in self.expired_attributes: - self.callables.pop(k, None) - del self.expired_attributes - return ATTR_WAS_SET - - @property - def unmodified(self): - """a set of keys which have no uncommitted changes""" - - return set( - key for key in self.manager.iterkeys() - if (key not in self.committed_state or - (key in self.manager.mutable_attributes and - not self.manager[key].impl.check_mutable_modified(self)))) - - @property - def unloaded(self): - """a set of keys which do not have a loaded value. - - This includes expired attributes and any other attribute that - was never populated or modified. 
- - """ - return set( - key for key in self.manager.iterkeys() - if key not in self.committed_state and key not in self.dict) - - def expire_attributes(self, attribute_names): - self.expired_attributes = set(self.expired_attributes) - - if attribute_names is None: - attribute_names = self.manager.keys() - self.expired = True - self.modified = False - filter_deferred = True - else: - filter_deferred = False - for key in attribute_names: - impl = self.manager[key].impl - if not filter_deferred or \ - not impl.dont_expire_missing or \ - key in self.dict: - self.expired_attributes.add(key) - if impl.accepts_scalar_loader: - self.callables[key] = self - self.dict.pop(key, None) - self.pending.pop(key, None) - self.committed_state.pop(key, None) - - def reset(self, key): - """remove the given attribute and any callables associated with it.""" - - self.dict.pop(key, None) - self.callables.pop(key, None) - - def modified_event(self, attr, should_copy, previous, passive=PASSIVE_OFF): - needs_committed = attr.key not in self.committed_state - - if needs_committed: - if previous is NEVER_SET: - if passive: - if attr.key in self.dict: - previous = self.dict[attr.key] - else: - previous = attr.get(self) - - if should_copy and previous not in (None, NO_VALUE, NEVER_SET): - previous = attr.copy(previous) - - if needs_committed: - self.committed_state[attr.key] = previous - - self.modified = True - - def commit(self, keys): - """Commit attributes. - - This is used by a partial-attribute load operation to mark committed - those attributes which were refreshed from the database. - - Attributes marked as "expired" can potentially remain "expired" after - this step if a value was not populated in state.dict. - - """ - class_manager = self.manager - for key in keys: - if key in self.dict and key in class_manager.mutable_attributes: - class_manager[key].impl.commit_to_state(self, self.committed_state) - else: - self.committed_state.pop(key, None) - - self.expired = False - # unexpire attributes which have loaded - for key in self.expired_attributes.intersection(keys): - if key in self.dict: - self.expired_attributes.remove(key) - self.callables.pop(key, None) - - def commit_all(self): - """commit all attributes unconditionally. - - This is used after a flush() or a full load/refresh - to remove all pending state from the instance. - - - all attributes are marked as "committed" - - the "strong dirty reference" is removed - - the "modified" flag is set to False - - any "expired" markers/callables are removed. - - Attributes marked as "expired" can potentially remain "expired" after this step - if a value was not populated in state.dict. 
- - """ - - self.committed_state = {} - self.pending = {} - - # unexpire attributes which have loaded - if self.expired_attributes: - for key in self.expired_attributes.intersection(self.dict): - self.callables.pop(key, None) - self.expired_attributes.difference_update(self.dict) - - for key in self.manager.mutable_attributes: - if key in self.dict: - self.manager[key].impl.commit_to_state(self, self.committed_state) - - self.modified = self.expired = False - self._strong_obj = None + child_state, child_dict = instance_state(child), instance_dict(child) + child_state.get_impl(self.key).remove(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES) class Events(object): @@ -1121,6 +845,7 @@ class Events(object): self.on_init = () self.on_init_failure = () self.on_load = () + self.on_resurrect = () def run(self, event, *args, **kwargs): for fn in getattr(self, event): @@ -1146,7 +871,6 @@ class ClassManager(dict): STATE_ATTR = '_sa_instance_state' event_registry_factory = Events - instance_state_factory = InstanceState deferred_scalar_loader = None def __init__(self, class_): @@ -1170,7 +894,6 @@ class ClassManager(dict): def _configure_create_arguments(self, _source=None, - instance_state_factory=None, deferred_scalar_loader=None): """Accept extra **kw arguments passed to create_manager_for_cls. @@ -1185,11 +908,8 @@ class ClassManager(dict): """ if _source: - instance_state_factory = _source.instance_state_factory deferred_scalar_loader = _source.deferred_scalar_loader - if instance_state_factory: - self.instance_state_factory = instance_state_factory if deferred_scalar_loader: self.deferred_scalar_loader = deferred_scalar_loader @@ -1222,7 +942,16 @@ class ClassManager(dict): if self.new_init: self.uninstall_member('__init__') self.new_init = None - + + def _create_instance_state(self, instance): + global state + if state is None: + from sqlalchemy.orm import state + if self.mutable_attributes: + return state.MutableAttrInstanceState(instance, self) + else: + return state.InstanceState(instance, self) + def manage(self): """Mark this instance as the manager for its class.""" @@ -1330,11 +1059,11 @@ class ClassManager(dict): def new_instance(self, state=None): instance = self.class_.__new__(self.class_) - setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self)) + setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance)) return instance def setup_instance(self, instance, state=None): - setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self)) + setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance)) def teardown_instance(self, instance): delattr(instance, self.STATE_ATTR) @@ -1348,13 +1077,10 @@ class ClassManager(dict): if hasattr(instance, self.STATE_ATTR): return False else: - state = self.instance_state_factory(instance, self) + state = self._create_instance_state(instance) setattr(instance, self.STATE_ATTR, state) return state - def state_of(self, instance): - return getattr(instance, self.STATE_ATTR) - def state_getter(self): """Return a (instance) -> InstanceState callable. 
@@ -1365,6 +1091,9 @@ class ClassManager(dict): return attrgetter(self.STATE_ATTR) + def dict_getter(self): + return attrgetter('__dict__') + def has_state(self, instance): return hasattr(instance, self.STATE_ATTR) @@ -1385,6 +1114,9 @@ class _ClassInstrumentationAdapter(ClassManager): def __init__(self, class_, override, **kw): self._adapted = override + self._get_state = self._adapted.state_getter(class_) + self._get_dict = self._adapted.dict_getter(class_) + ClassManager.__init__(self, class_, **kw) def manage(self): @@ -1446,36 +1178,27 @@ class _ClassInstrumentationAdapter(ClassManager): self._adapted.initialize_instance_dict(self.class_, instance) if state is None: - state = self.instance_state_factory(instance, self) + state = self._create_instance_state(instance) # the given instance is assumed to have no state self._adapted.install_state(self.class_, instance, state) - state.dict = self._adapted.get_instance_dict(self.class_, instance) return state def teardown_instance(self, instance): self._adapted.remove_state(self.class_, instance) - def state_of(self, instance): - if hasattr(self._adapted, 'state_of'): - return self._adapted.state_of(self.class_, instance) - else: - getter = self._adapted.state_getter(self.class_) - return getter(instance) - def has_state(self, instance): - if hasattr(self._adapted, 'has_state'): - return self._adapted.has_state(self.class_, instance) - else: - try: - state = self.state_of(instance) - return True - except exc.NO_STATE: - return False + try: + state = self._get_state(instance) + return True + except exc.NO_STATE: + return False def state_getter(self): - return self._adapted.state_getter(self.class_) + return self._get_state + def dict_getter(self): + return self._get_dict class History(tuple): """A 3-tuple of added, unchanged and deleted values. @@ -1520,7 +1243,7 @@ class History(tuple): original = state.committed_state.get(attribute.key, NEVER_SET) if hasattr(attribute, 'get_collection'): - current = attribute.get_collection(state, current) + current = attribute.get_collection(state, state.dict, current) if original is NO_VALUE: return cls(list(current), (), ()) elif original is NEVER_SET: @@ -1557,30 +1280,8 @@ class History(tuple): HISTORY_BLANK = History(None, None, None) -class PendingCollection(object): - """A writable placeholder for an unloaded collection. - - Stores items appended to and removed from a collection that has not yet - been loaded. When the collection is loaded, the changes stored in - PendingCollection are applied to it to produce the final result. - - """ - def __init__(self): - self.deleted_items = util.IdentitySet() - self.added_items = util.OrderedIdentitySet() - - def append(self, value): - if value in self.deleted_items: - self.deleted_items.remove(value) - self.added_items.add(value) - - def remove(self, value): - if value in self.added_items: - self.added_items.remove(value) - self.deleted_items.add(value) - def _conditional_instance_state(obj): - if not isinstance(obj, InstanceState): + if not isinstance(obj, state.InstanceState): obj = instance_state(obj) return obj @@ -1690,15 +1391,16 @@ def init_collection(obj, key): this usage is deprecated. 
""" - - return init_state_collection(_conditional_instance_state(obj), key) + state = _conditional_instance_state(obj) + dict_ = state.dict + return init_state_collection(state, dict_, key) -def init_state_collection(state, key): +def init_state_collection(state, dict_, key): """Initialize a collection attribute and return the collection adapter.""" attr = state.get_impl(key) - user_data = attr.initialize(state) - return attr.get_collection(state, user_data) + user_data = attr.initialize(state, dict_) + return attr.get_collection(state, dict_, user_data) def set_committed_value(instance, key, value): """Set the value of an attribute with no history events. @@ -1715,8 +1417,8 @@ def set_committed_value(instance, key, value): as though it were part of its original loaded state. """ - state = instance_state(instance) - state.get_impl(key).set_committed_value(instance, key, value) + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).set_committed_value(state, dict_, key, value) def set_attribute(instance, key, value): """Set the value of an attribute, firing history events. @@ -1728,8 +1430,8 @@ def set_attribute(instance, key, value): by SQLAlchemy. """ - state = instance_state(instance) - state.get_impl(key).set(state, value, None) + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).set(state, dict_, value, None) def get_attribute(instance, key): """Get the value of an attribute, firing any callables required. @@ -1741,8 +1443,8 @@ def get_attribute(instance, key): by SQLAlchemy. """ - state = instance_state(instance) - return state.get_impl(key).get(state) + state, dict_ = instance_state(instance), instance_dict(instance) + return state.get_impl(key).get(state, dict_) def del_attribute(instance, key): """Delete the value of an attribute, firing history events. @@ -1754,8 +1456,8 @@ def del_attribute(instance, key): by SQLAlchemy. 
""" - state = instance_state(instance) - state.get_impl(key).delete(state) + state, dict_ = instance_state(instance), instance_dict(instance) + state.get_impl(key).delete(state, dict_) def is_instrumented(instance, key): """Return True if the given attribute on the given instance is instrumented @@ -1772,6 +1474,7 @@ class InstrumentationRegistry(object): _manager_finders = weakref.WeakKeyDictionary() _state_finders = util.WeakIdentityMapping() + _dict_finders = util.WeakIdentityMapping() _extended = False def create_manager_for_cls(self, class_, **kw): @@ -1806,6 +1509,7 @@ class InstrumentationRegistry(object): manager.factory = factory self._manager_finders[class_] = manager.manager_getter() self._state_finders[class_] = manager.state_getter() + self._dict_finders[class_] = manager.dict_getter() return manager def _collect_management_factories_for(self, cls): @@ -1845,6 +1549,7 @@ class InstrumentationRegistry(object): return finder(cls) def state_of(self, instance): + # this is only called when alternate instrumentation has been established if instance is None: raise AttributeError("None has no persistent state.") try: @@ -1852,21 +1557,15 @@ class InstrumentationRegistry(object): except KeyError: raise AttributeError("%r is not instrumented" % instance.__class__) - def state_or_default(self, instance, default=None): + def dict_of(self, instance): + # this is only called when alternate instrumentation has been established if instance is None: - return default + raise AttributeError("None has no persistent state.") try: - finder = self._state_finders[instance.__class__] + return self._dict_finders[instance.__class__](instance) except KeyError: - return default - else: - try: - return finder(instance) - except exc.NO_STATE: - return default - except: - raise - + raise AttributeError("%r is not instrumented" % instance.__class__) + def unregister(self, class_): if class_ in self._manager_finders: manager = self.manager_of_class(class_) @@ -1874,6 +1573,7 @@ class InstrumentationRegistry(object): manager.dispose() del self._manager_finders[class_] del self._state_finders[class_] + del self._dict_finders[class_] instrumentation_registry = InstrumentationRegistry() @@ -1887,12 +1587,14 @@ def _install_lookup_strategy(implementation): and unit tests specific to this behavior. 
""" - global instance_state + global instance_state, instance_dict if implementation is util.symbol('native'): instance_state = attrgetter(ClassManager.STATE_ATTR) + instance_dict = attrgetter("__dict__") else: instance_state = instrumentation_registry.state_of - + instance_dict = instrumentation_registry.dict_of + manager_of_class = instrumentation_registry.manager_of_class _create_manager_for_cls = instrumentation_registry.create_manager_for_cls _install_lookup_strategy(util.symbol('native')) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 5638a7e4a5..4ca4c5719e 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -472,6 +472,7 @@ class CollectionAdapter(object): """ def __init__(self, attr, owner_state, data): self.attr = attr + # TODO: figure out what this being a weakref buys us self._data = weakref.ref(data) self.owner_state = owner_state self.link_to_self(data) @@ -578,7 +579,7 @@ class CollectionAdapter(object): """ if initiator is not False and item is not None: - return self.attr.fire_append_event(self.owner_state, item, initiator) + return self.attr.fire_append_event(self.owner_state, self.owner_state.dict, item, initiator) else: return item @@ -591,7 +592,7 @@ class CollectionAdapter(object): """ if initiator is not False and item is not None: - self.attr.fire_remove_event(self.owner_state, item, initiator) + self.attr.fire_remove_event(self.owner_state, self.owner_state.dict, item, initiator) def fire_pre_remove_event(self, initiator=None): """Notify that an entity is about to be removed from the collection. @@ -600,7 +601,7 @@ class CollectionAdapter(object): fire_remove_event(). """ - self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator) + self.attr.fire_pre_remove_event(self.owner_state, self.owner_state.dict, initiator=initiator) def __getstate__(self): return {'key': self.attr.key, diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index a80727b7f2..151c557d71 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -64,17 +64,21 @@ class DependencyProcessor(object): def register_dependencies(self, uowcommit): """Tell a ``UOWTransaction`` what mappers are dependent on which, with regards to the two or three mappers handled by - this ``PropertyLoader``. + this ``DependencyProcessor``. - Also register itself as a *processor* for one of its mappers, - which will be executed after that mapper's objects have been - saved or before they've been deleted. The process operation - manages attributes and dependent operations upon the objects - of one of the involved mappers. """ raise NotImplementedError() + def register_processors(self, uowcommit): + """Tell a ``UOWTransaction`` about this object as a processor, + which will be executed after that mapper's objects have been + saved or before they've been deleted. The process operation + manages attributes and dependent operations between two mappers. 
+ + """ + raise NotImplementedError() + def whose_dependent_on_who(self, state1, state2): """Given an object pair assuming `obj2` is a child of `obj1`, return a tuple with the dependent object second, or None if @@ -181,9 +185,13 @@ class OneToManyDP(DependencyProcessor): if self.post_update: uowcommit.register_dependency(self.mapper, self.dependency_marker) uowcommit.register_dependency(self.parent, self.dependency_marker) - uowcommit.register_processor(self.dependency_marker, self, self.parent) else: uowcommit.register_dependency(self.parent, self.mapper) + + def register_processors(self, uowcommit): + if self.post_update: + uowcommit.register_processor(self.dependency_marker, self, self.parent) + else: uowcommit.register_processor(self.parent, self, self.parent) def process_dependencies(self, task, deplist, uowcommit, delete = False): @@ -285,6 +293,9 @@ class DetectKeySwitch(DependencyProcessor): no_dependencies = True def register_dependencies(self, uowcommit): + pass + + def register_processors(self, uowcommit): uowcommit.register_processor(self.parent, self, self.mapper) def preprocess_dependencies(self, task, deplist, uowcommit, delete=False): @@ -330,12 +341,15 @@ class ManyToOneDP(DependencyProcessor): if self.post_update: uowcommit.register_dependency(self.mapper, self.dependency_marker) uowcommit.register_dependency(self.parent, self.dependency_marker) - uowcommit.register_processor(self.dependency_marker, self, self.parent) else: uowcommit.register_dependency(self.mapper, self.parent) + + def register_processors(self, uowcommit): + if self.post_update: + uowcommit.register_processor(self.dependency_marker, self, self.parent) + else: uowcommit.register_processor(self.mapper, self, self.parent) - def process_dependencies(self, task, deplist, uowcommit, delete=False): if delete: if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all': @@ -408,8 +422,10 @@ class ManyToManyDP(DependencyProcessor): uowcommit.register_dependency(self.parent, self.dependency_marker) uowcommit.register_dependency(self.mapper, self.dependency_marker) - uowcommit.register_processor(self.dependency_marker, self, self.parent) + def register_processors(self, uowcommit): + uowcommit.register_processor(self.dependency_marker, self, self.parent) + def process_dependencies(self, task, deplist, uowcommit, delete = False): connection = uowcommit.transaction.connection(self.mapper) secondary_delete = [] @@ -527,6 +543,9 @@ class MapperStub(object): def _register_dependencies(self, uowcommit): pass + def _register_procesors(self, uowcommit): + pass + def _save_obj(self, *args, **kwargs): pass diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 3d31a686a2..70243291dc 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -55,21 +55,21 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: self.query_class = mixin_user_query(query_class) - def get(self, state, passive=False): + def get(self, state, dict_, passive=False): if passive: return self._get_collection_history(state, passive=True).added_items else: return self.query_class(self, state) - def get_collection(self, state, user_data=None, passive=True): + def get_collection(self, state, dict_, user_data=None, passive=True): if passive: return self._get_collection_history(state, passive=passive).added_items else: history = self._get_collection_history(state, passive=passive) return history.added_items + history.unchanged_items - def fire_append_event(self, state, 
value, initiator): - collection_history = self._modified_event(state) + def fire_append_event(self, state, dict_, value, initiator): + collection_history = self._modified_event(state, dict_) collection_history.added_items.append(value) for ext in self.extensions: @@ -78,8 +78,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl): if self.trackparent and value is not None: self.sethasparent(attributes.instance_state(value), True) - def fire_remove_event(self, state, value, initiator): - collection_history = self._modified_event(state) + def fire_remove_event(self, state, dict_, value, initiator): + collection_history = self._modified_event(state, dict_) collection_history.deleted_items.append(value) if self.trackparent and value is not None: @@ -88,31 +88,31 @@ class DynamicAttributeImpl(attributes.AttributeImpl): for ext in self.extensions: ext.remove(state, value, initiator or self) - def _modified_event(self, state): + def _modified_event(self, state, dict_): if self.key not in state.committed_state: state.committed_state[self.key] = CollectionHistory(self, state) - state.modified_event(self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE) + state.modified_event(dict_, self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE) # this is a hack to allow the _base.ComparableEntity fixture # to work - state.dict[self.key] = True + dict_[self.key] = True return state.committed_state[self.key] - def set(self, state, value, initiator): + def set(self, state, dict_, value, initiator): if initiator is self: return - self._set_iterable(state, value) + self._set_iterable(state, dict_, value) - def _set_iterable(self, state, iterable, adapter=None): + def _set_iterable(self, state, dict_, iterable, adapter=None): - collection_history = self._modified_event(state) + collection_history = self._modified_event(state, dict_) new_values = list(iterable) if _state_has_identity(state): - old_collection = list(self.get(state)) + old_collection = list(self.get(state, dict_)) else: old_collection = [] @@ -121,7 +121,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl): def delete(self, *args, **kwargs): raise NotImplementedError() - def get_history(self, state, passive=False): + def get_history(self, state, dict_, passive=False): c = self._get_collection_history(state, passive) return attributes.History(c.added_items, c.unchanged_items, c.deleted_items) @@ -136,13 +136,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl): else: return c - def append(self, state, value, initiator, passive=False): + def append(self, state, dict_, value, initiator, passive=False): if initiator is not self: - self.fire_append_event(state, value, initiator) + self.fire_append_event(state, dict_, value, initiator) - def remove(self, state, value, initiator, passive=False): + def remove(self, state, dict_, value, initiator, passive=False): if initiator is not self: - self.fire_remove_event(state, value, initiator) + self.fire_remove_event(state, dict_, value, initiator) class DynCollectionAdapter(object): """the dynamic analogue to orm.collections.CollectionAdapter""" @@ -156,10 +156,10 @@ class DynCollectionAdapter(object): return iter(self.data) def append_with_event(self, item, initiator=None): - self.attr.append(self.state, item, initiator) + self.attr.append(self.state, self.state.dict, item, initiator) def remove_with_event(self, item, initiator=None): - self.attr.remove(self.state, item, initiator) + self.attr.remove(self.state, self.state.dict, item, initiator) def 
append_without_event(self, item): pass @@ -240,10 +240,10 @@ class AppenderMixin(object): return query def append(self, item): - self.attr.append(attributes.instance_state(self.instance), item, None) + self.attr.append(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None) def remove(self, item): - self.attr.remove(attributes.instance_state(self.instance), item, None) + self.attr.remove(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None) class AppenderQuery(AppenderMixin, Query): diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py index 0753ea991f..aa041a5855 100644 --- a/lib/sqlalchemy/orm/identity.py +++ b/lib/sqlalchemy/orm/identity.py @@ -15,6 +15,9 @@ class IdentityMap(dict): self._mutable_attrs = {} self.modified = False self._wr = weakref.ref(self) + + def replace(self, state): + raise NotImplementedError() def add(self, state): raise NotImplementedError() @@ -102,6 +105,17 @@ class WeakInstanceDict(IdentityMap): def contains_state(self, state): return dict.get(self, state.key) is state + def replace(self, state): + if dict.__contains__(self, state.key): + existing = dict.__getitem__(self, state.key) + if existing is not state: + self._manage_removed_state(existing) + else: + return + + dict.__setitem__(self, state.key, state) + self._manage_incoming_state(state) + def add(self, state): if state.key in self: if dict.__getitem__(self, state.key) is not state: @@ -161,12 +175,24 @@ class StrongInstanceDict(IdentityMap): def contains_state(self, state): return state.key in self and attributes.instance_state(self[state.key]) is state + def replace(self, state): + if dict.__contains__(self, state.key): + existing = dict.__getitem__(self, state.key) + existing = attributes.instance_state(existing) + if existing is not state: + self._manage_removed_state(existing) + else: + return + + dict.__setitem__(self, state.key, state.obj()) + self._manage_incoming_state(state) + def add(self, state): dict.__setitem__(self, state.key, state.obj()) self._manage_incoming_state(state) def remove(self, state): - if dict.pop(self, state.key) is not state: + if attributes.instance_state(dict.pop(self, state.key)) is not state: raise AssertionError("State %s is not present in this identity map" % state) self._manage_removed_state(state) @@ -176,7 +202,7 @@ class StrongInstanceDict(IdentityMap): self._manage_removed_state(state) def remove_key(self, key): - state = dict.__getitem__(self, key) + state = attributes.instance_state(dict.__getitem__(self, key)) self.remove(state) def prune(self): @@ -190,62 +216,3 @@ class StrongInstanceDict(IdentityMap): self.modified = bool(dirty) return ref_count - len(self) -class IdentityManagedState(attributes.InstanceState): - def _instance_dict(self): - return None - - def modified_event(self, attr, should_copy, previous, passive=False): - attributes.InstanceState.modified_event(self, attr, should_copy, previous, passive) - - instance_dict = self._instance_dict() - if instance_dict: - instance_dict.modified = True - - def _is_really_none(self): - """do a check modified/resurrect. - - This would be called in the extremely rare - race condition that the weakref returned None but - the cleanup handler had not yet established the - __resurrect callable as its replacement. - - """ - if self.check_modified(): - self.obj = self.__resurrect - return self.obj() - else: - return None - - def _cleanup(self, ref): - """weakref callback. 
- - This method may be called by an asynchronous - gc. - - If the state shows pending changes, the weakref - is replaced by the __resurrect callable which will - re-establish an object reference on next access, - else removes this InstanceState from the owning - identity map, if any. - - """ - if self.check_modified(): - self.obj = self.__resurrect - else: - instance_dict = self._instance_dict() - if instance_dict: - instance_dict.remove(self) - self.dispose() - - def __resurrect(self): - """A substitute for the obj() weakref function which resurrects.""" - - # store strong ref'ed version of the object; will revert - # to weakref when changes are persisted - obj = self.manager.new_instance(state=self) - self.obj = weakref.ref(obj, self._cleanup) - self._strong_obj = obj - obj.__dict__.update(self.dict) - self.dict = obj.__dict__ - self._run_on_load(obj) - return obj diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index d36f51194e..0ac7713058 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -359,7 +359,7 @@ class MapperProperty(object): Callables are of the following form:: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # process incoming instance state and given row. the instance is # "new" and was just created upon receipt of this row. # flags is a dictionary containing at least the following @@ -368,7 +368,7 @@ class MapperProperty(object): # result of reading this row # instancekey - identity key of the instance - def existing_execute(state, row, **flags): + def existing_execute(state, dict_, row, **flags): # process incoming instance state and given row. the instance is # "existing" and was created based on a previous row. @@ -427,13 +427,23 @@ class MapperProperty(object): def register_dependencies(self, *args, **kwargs): """Called by the ``Mapper`` in response to the UnitOfWork calling the ``Mapper``'s register_dependencies operation. - Should register with the UnitOfWork all inter-mapper - dependencies as well as dependency processors (see UOW docs - for more details). + Establishes a topological dependency between two mappers + which will affect the order in which mappers persist data. + """ pass + def register_processors(self, *args, **kwargs): + """Called by the ``Mapper`` in response to the UnitOfWork + calling the ``Mapper``'s register_processors operation. + Establishes a processor object between two mappers which + will link data and state between parent/child objects. + + """ + + pass + def is_primary(self): """Return True if this ``MapperProperty``'s mapper is the primary mapper for its class. 
@@ -939,3 +949,7 @@ class InstrumentationManager(object): def state_getter(self, class_): return lambda instance: getattr(instance, '_default_state') + + def dict_getter(self, class_): + return lambda inst: self.get_instance_dict(class_, inst) + \ No newline at end of file diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 8af6153d6b..87c4c8100f 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -23,7 +23,6 @@ deque = __import__('collections').deque from sqlalchemy import sql, util, log, exc as sa_exc from sqlalchemy.sql import expression, visitors, operators, util as sqlutil from sqlalchemy.orm import attributes, exc, sync -from sqlalchemy.orm.identity import IdentityManagedState from sqlalchemy.orm.interfaces import ( MapperProperty, EXT_CONTINUE, PropComparator ) @@ -255,7 +254,8 @@ class Mapper(object): for mapper in self.iterate_to_root(): util.reset_memoized(mapper, '_equivalent_columns') - + util.reset_memoized(mapper, '_sorted_tables') + if self.order_by is False and not self.concrete and self.inherits.order_by is not False: self.order_by = self.inherits.order_by @@ -357,7 +357,6 @@ class Mapper(object): if manager is None: manager = attributes.register_class(self.class_, - instance_state_factory = IdentityManagedState, deferred_scalar_loader = _load_scalar_attributes ) @@ -372,6 +371,8 @@ class Mapper(object): event_registry = manager.events event_registry.add_listener('on_init', _event_on_init) event_registry.add_listener('on_init_failure', _event_on_init_failure) + event_registry.add_listener('on_resurrect', _event_on_resurrect) + for key, method in util.iterate_attributes(self.class_): if isinstance(method, types.FunctionType): if hasattr(method, '__sa_reconstructor__'): @@ -1173,6 +1174,19 @@ class Mapper(object): # persistence + @util.memoized_property + def _sorted_tables(self): + table_to_mapper = {} + for mapper in self.base_mapper.polymorphic_iterator(): + for t in mapper.tables: + table_to_mapper[t] = mapper + + sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys()) + ret = util.OrderedDict() + for t in sorted_: + ret[t] = table_to_mapper[t] + return ret + def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. 
@@ -1198,16 +1212,37 @@ class Mapper(object): # if session has a connection callable, # organize individual states with the connection to use for insert/update + tups = [] if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in _sort_states(states)] + for state in _sort_states(states): + m = _state_mapper(state) + tups.append( + ( + state, + m, + connection_callable(self, state.obj()), + _state_has_identity(state), + state.key or m._identity_key_from_state(state) + ) + ) else: connection = uowtransaction.transaction.connection(self) - tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in _sort_states(states)] + for state in _sort_states(states): + m = _state_mapper(state) + tups.append( + ( + state, + m, + connection, + _state_has_identity(state), + state.key or m._identity_key_from_state(state) + ) + ) if not postupdate: # call before_XXX extensions - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: if not has_identity: if 'before_insert' in mapper.extension: mapper.extension.before_insert(mapper, connection, state.obj()) @@ -1215,39 +1250,44 @@ class Mapper(object): if 'before_update' in mapper.extension: mapper.extension.before_update(mapper, connection, state.obj()) - for state, mapper, connection, has_identity in tups: - # detect if we have a "pending" instance (i.e. has no instance_key attached to it), - # and another instance with the same identity key already exists as persistent. convert to an - # UPDATE if so. - instance_key = mapper._identity_key_from_state(state) - if not postupdate and not has_identity and instance_key in uowtransaction.session.identity_map: - instance = uowtransaction.session.identity_map[instance_key] - existing = attributes.instance_state(instance) - if not uowtransaction.is_deleted(existing): - raise exc.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing))) - if self._should_log_debug: - self._log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing))) - uowtransaction.set_row_switch(existing) - - table_to_mapper = {} - for mapper in self.base_mapper.polymorphic_iterator(): - for t in mapper.tables: - table_to_mapper[t] = mapper + row_switches = set() + if not postupdate: + for state, mapper, connection, has_identity, instance_key in tups: + # detect if we have a "pending" instance (i.e. has no instance_key attached to it), + # and another instance with the same identity key already exists as persistent. convert to an + # UPDATE if so. + if not has_identity and instance_key in uowtransaction.session.identity_map: + instance = uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + if not uowtransaction.is_deleted(existing): + raise exc.FlushError( + "New instance %s with identity key %s conflicts with persistent instance %s" % + (state_str(state), instance_key, state_str(existing))) + if self._should_log_debug: + self._log_debug( + "detected row switch for identity %s. 
will update %s, remove %s from transaction", + instance_key, state_str(state), state_str(existing)) + + # remove the "delete" flag from the existing element + uowtransaction.set_row_switch(existing) + row_switches.add(state) + + table_to_mapper = self._sorted_tables - for table in sqlutil.sort_tables(table_to_mapper.iterkeys()): + for table in table_to_mapper.iterkeys(): insert = [] update = [] - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: if table not in mapper._pks_by_table: continue + pks = mapper._pks_by_table[table] - instance_key = mapper._identity_key_from_state(state) - + if self._should_log_debug: self._log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key))) - isinsert = not instance_key in uowtransaction.session.identity_map and not postupdate and not has_identity + isinsert = not has_identity and not postupdate and state not in row_switches params = {} value_params = {} @@ -1364,7 +1404,7 @@ class Mapper(object): sync.populate(state, m, state, m, m._inherits_equated_pairs) if not postupdate: - for state, mapper, connection, has_identity in tups: + for state, mapper, connection, has_identity, instance_key in tups: # expire readonly attributes readonly = state.unmodified.intersection( @@ -1434,12 +1474,9 @@ class Mapper(object): if 'before_delete' in mapper.extension: mapper.extension.before_delete(mapper, connection, state.obj()) - table_to_mapper = {} - for mapper in self.base_mapper.polymorphic_iterator(): - for t in mapper.tables: - table_to_mapper[t] = mapper + table_to_mapper = self._sorted_tables - for table in reversed(sqlutil.sort_tables(table_to_mapper.iterkeys())): + for table in reversed(table_to_mapper.keys()): delete = {} for state, mapper, connection in tups: if table not in mapper._pks_by_table: @@ -1485,6 +1522,10 @@ class Mapper(object): for dep in self._props.values() + self._dependency_processors: dep.register_dependencies(uowcommit) + def _register_processors(self, uowcommit): + for dep in self._props.values() + self._dependency_processors: + dep.register_processors(uowcommit) + # result set conversion def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None, polymorphic_discriminator=None): @@ -1514,7 +1555,7 @@ class Mapper(object): new_populators = [] existing_populators = [] - def populate_state(state, row, isnew, only_load_props, **flags): + def populate_state(state, dict_, row, isnew, only_load_props, **flags): if isnew: if context.options: state.load_options = context.options @@ -1533,7 +1574,7 @@ class Mapper(object): populators = [p for p in populators if p[0] in only_load_props] for key, populator in populators: - populator(state, row, isnew=isnew, **flags) + populator(state, dict_, row, isnew=isnew, **flags) session_identity_map = context.session.identity_map @@ -1573,9 +1614,11 @@ class Mapper(object): if identitykey in session_identity_map: instance = session_identity_map[identitykey] state = attributes.instance_state(instance) + dict_ = attributes.instance_dict(instance) if self._should_log_debug: - self._log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), identitykey)) + self._log_debug("_instance(): using existing instance %s identity %s", + instance_str(instance), identitykey) isnew = state.runid != context.runid currentload = not isnew @@ -1592,12 +1635,13 @@ class Mapper(object): # when 
eager_defaults is True. state = refresh_state instance = state.obj() + dict_ = attributes.instance_dict(instance) isnew = state.runid != context.runid currentload = True loaded_instance = False else: if self._should_log_debug: - self._log_debug("_instance(): identity key %s not in session" % str(identitykey)) + self._log_debug("_instance(): identity key %s not in session", identitykey) if self.allow_null_pks: for x in identitykey[1]: @@ -1625,8 +1669,10 @@ class Mapper(object): instance = self.class_manager.new_instance() if self._should_log_debug: - self._log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey))) + self._log_debug("_instance(): created new instance %s identity %s", + instance_str(instance), identitykey) + dict_ = attributes.instance_dict(instance) state = attributes.instance_state(instance) state.key = identitykey @@ -1638,12 +1684,12 @@ class Mapper(object): if currentload or populate_existing: if isnew: state.runid = context.runid - context.progress.add(state) + context.progress[state] = dict_ if not populate_instance or \ populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - populate_state(state, row, isnew, only_load_props) + populate_state(state, dict_, row, isnew, only_load_props) else: # populate attributes on non-loading instances which have been expired @@ -1652,16 +1698,16 @@ class Mapper(object): if state in context.partials: isnew = False - attrs = context.partials[state] + (d_, attrs) = context.partials[state] else: isnew = True attrs = state.unloaded - context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs + context.partials[state] = (dict_, attrs) #<-- allow query.instances to commit the subset of attrs if not populate_instance or \ populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - populate_state(state, row, isnew, attrs, instancekey=identitykey) + populate_state(state, dict_, row, isnew, attrs, instancekey=identitykey) if loaded_instance: state._run_on_load(instance) @@ -1759,6 +1805,14 @@ def _event_on_init_failure(state, instance, args, kwargs): instrumenting_mapper, instrumenting_mapper.class_, state.manager.events.original_init, instance, args, kwargs) +def _event_on_resurrect(state, instance): + # re-populate the primary key elements + # of the dict based on the mapping. 
+ instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + for col, val in zip(instrumenting_mapper.primary_key, state.key[1]): + instrumenting_mapper._set_state_attr_by_column(state, col, val) + + def _sort_states(states): return sorted(states, key=operator.attrgetter('sort_key')) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d0cca2dc11..5605cdcd1e 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -96,13 +96,13 @@ class ColumnProperty(StrategizedProperty): return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) def getattr(self, state, column): - return state.get_impl(self.key).get(state) + return state.get_impl(self.key).get(state, state.dict) def getcommitted(self, state, column, passive=False): - return state.get_impl(self.key).get_committed_value(state, passive=passive) + return state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive) def setattr(self, state, value, column): - state.get_impl(self.key).set(state, value, None) + state.get_impl(self.key).set(state, state.dict, value, None) def merge(self, session, source, dest, dont_load, _recursive): value = attributes.instance_state(source).value_as_iterable( @@ -159,7 +159,7 @@ class CompositeProperty(ColumnProperty): super(ColumnProperty, self).do_init() def getattr(self, state, column): - obj = state.get_impl(self.key).get(state) + obj = state.get_impl(self.key).get(state, state.dict) return self.get_col_value(column, obj) def getcommitted(self, state, column, passive=False): @@ -168,7 +168,7 @@ class CompositeProperty(ColumnProperty): def setattr(self, state, value, column): - obj = state.get_impl(self.key).get(state) + obj = state.get_impl(self.key).get(state, state.dict) if obj is None: obj = self.composite_class(*[None for c in self.columns]) state.get_impl(self.key).set(state, obj, None) @@ -635,7 +635,7 @@ class RelationProperty(StrategizedProperty): return source_state = attributes.instance_state(source) - dest_state = attributes.instance_state(dest) + dest_state, dest_dict = attributes.instance_state(dest), attributes.instance_dict(dest) if not "merge" in self.cascade: dest_state.expire_attributes([self.key]) @@ -658,7 +658,7 @@ class RelationProperty(StrategizedProperty): for c in dest_list: coll.append_without_event(c) else: - getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_list) + getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_dict, dest_list) else: current = instances[0] if current is not None: @@ -1119,6 +1119,10 @@ class RelationProperty(StrategizedProperty): if not self.viewonly: self._dependency_processor.register_dependencies(uowcommit) + def register_processors(self, uowcommit): + if not self.viewonly: + self._dependency_processor.register_processors(uowcommit) + PropertyLoader = RelationProperty log.class_logger(RelationProperty) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 28ddcc5eaa..e3cc3c7569 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1330,7 +1330,7 @@ class Query(object): rowtuple.keys = labels.keys while True: - context.progress = set() + context.progress = {} context.partials = {} if self._yield_per: @@ -1354,13 +1354,13 @@ class Query(object): rows = filter(rows) if context.refresh_state and self._only_load_props and context.refresh_state in context.progress: - context.refresh_state.commit(self._only_load_props) - context.progress.remove(context.refresh_state) + 
context.refresh_state.commit(context.refresh_state.dict, self._only_load_props) + context.progress.pop(context.refresh_state) session._finalize_loaded(context.progress) - for ii, attrs in context.partials.items(): - ii.commit(attrs) + for ii, (dict_, attrs) in context.partials.items(): + ii.commit(dict_, attrs) for row in rows: yield row @@ -1683,14 +1683,14 @@ class Query(object): evaluated_keys = value_evaluators.keys() if issubclass(cls, target_cls) and eval_condition(obj): - state = attributes.instance_state(obj) + state, dict_ = attributes.instance_state(obj), attributes.instance_dict(obj) # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: - state.dict[key] = value_evaluators[key](obj) + dict_[key] = value_evaluators[key](obj) - state.commit(list(to_evaluate)) + state.commit(dict_, list(to_evaluate)) # expire attributes with pending changes (there was no autoflush, so they are overwritten) state.expire_attributes(set(evaluated_keys).difference(to_evaluate)) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 1e3a750d95..00a7d55e5e 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -12,7 +12,7 @@ import sqlalchemy.exceptions as sa_exc from sqlalchemy import util, sql, engine, log from sqlalchemy.sql import util as sql_util, expression from sqlalchemy.orm import ( - SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, + SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, state ) from sqlalchemy.orm.util import object_mapper as _object_mapper from sqlalchemy.orm.util import class_mapper as _class_mapper @@ -899,8 +899,8 @@ class Session(object): self.flush() def _finalize_loaded(self, states): - for state in states: - state.commit_all() + for state, dict_ in states.items(): + state.commit_all(dict_) def refresh(self, instance, attribute_names=None): """Refresh the attributes on the given instance. 
@@ -1020,11 +1020,9 @@ class Session(object): # primary key switch self.identity_map.remove(state) state.key = instance_key - - if state.key in self.identity_map and not self.identity_map.contains_state(state): - self.identity_map.remove_key(state.key) - self.identity_map.add(state) - state.commit_all() + + self.identity_map.replace(state) + state.commit_all(state.dict) # remove from new last, might be the last strong ref if state in self._new: @@ -1213,7 +1211,7 @@ class Session(object): prop.merge(self, instance, merged, dont_load, _recursive) if dont_load: - attributes.instance_state(merged).commit_all() # remove any history + attributes.instance_state(merged).commit_all(attributes.instance_dict(merged)) # remove any history if new_instance: merged_state._run_on_load(merged) @@ -1368,7 +1366,7 @@ class Session(object): self.identity_map.modified = False return - flush_context = UOWTransaction(self) + flush_context = UOWTransaction(self) if self.extensions: for ext in self.extensions: @@ -1489,7 +1487,7 @@ class Session(object): return util.IdentitySet( [state for state in self.identity_map.all_states() - if state.check_modified()]) + if state.modified]) @property def dirty(self): @@ -1528,7 +1526,7 @@ class Session(object): return util.IdentitySet(self._new.values()) -_expire_state = attributes.InstanceState.expire_attributes +_expire_state = state.InstanceState.expire_attributes UOWEventHandler = unitofwork.UOWEventHandler @@ -1548,16 +1546,19 @@ def _cascade_unknown_state_iterator(cascade, state, **kwargs): yield _state_for_unknown_persistence_instance(o), m def _state_for_unsaved_instance(instance, create=False): - manager = attributes.manager_of_class(instance.__class__) - if manager is None: + try: + state = attributes.instance_state(instance) + except AttributeError: raise exc.UnmappedInstanceError(instance) - if manager.has_state(instance): - state = manager.state_of(instance) + if state: if state.key is not None: raise sa_exc.InvalidRequestError( "Instance '%s' is already persistent" % mapperutil.state_str(state)) elif create: + manager = attributes.manager_of_class(instance.__class__) + if manager is None: + raise exc.UnmappedInstanceError(instance) state = manager.setup_instance(instance) else: raise exc.UnmappedInstanceError(instance) diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py new file mode 100644 index 0000000000..c99dfe73c7 --- /dev/null +++ b/lib/sqlalchemy/orm/state.py @@ -0,0 +1,429 @@ +from sqlalchemy.util import EMPTY_SET +import weakref +from sqlalchemy import util +from sqlalchemy.orm.attributes import PASSIVE_NORESULT, PASSIVE_OFF, NEVER_SET, NO_VALUE, manager_of_class, ATTR_WAS_SET +from sqlalchemy.orm import attributes +from sqlalchemy.orm import interfaces + +class InstanceState(object): + """tracks state information at the instance level.""" + + session_id = None + key = None + runid = None + expired_attributes = EMPTY_SET + load_options = EMPTY_SET + load_path = () + insert_order = None + mutable_dict = None + + def __init__(self, obj, manager): + self.class_ = obj.__class__ + self.manager = manager + self.obj = weakref.ref(obj, self._cleanup) + self.modified = False + self.callables = {} + self.expired = False + self.committed_state = {} + self.pending = {} + self.parents = {} + + def detach(self): + if self.session_id: + del self.session_id + + def dispose(self): + if self.session_id: + del self.session_id + del self.obj + + def _cleanup(self, ref): + instance_dict = self._instance_dict() + if instance_dict: + 
instance_dict.remove(self) + self.dispose() + + def obj(self): + return None + + @property + def dict(self): + o = self.obj() + if o is not None: + return attributes.instance_dict(o) + else: + return {} + + @property + def sort_key(self): + return self.key and self.key[1] or (self.insert_order, ) + + def check_modified(self): + # TODO: deprecate + return self.modified + + def initialize_instance(*mixed, **kwargs): + self, instance, args = mixed[0], mixed[1], mixed[2:] + manager = self.manager + + for fn in manager.events.on_init: + fn(self, instance, args, kwargs) + + # LESSTHANIDEAL: + # adjust for the case where the InstanceState was created before + # mapper compilation, and this actually needs to be a MutableAttrInstanceState + if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState: + self.__class__ = MutableAttrInstanceState + self.obj = weakref.ref(self.obj(), self._cleanup) + self.mutable_dict = {} + + try: + return manager.events.original_init(*mixed[1:], **kwargs) + except: + for fn in manager.events.on_init_failure: + fn(self, instance, args, kwargs) + raise + + def get_history(self, key, **kwargs): + return self.manager.get_impl(key).get_history(self, self.dict, **kwargs) + + def get_impl(self, key): + return self.manager.get_impl(key) + + def get_pending(self, key): + if key not in self.pending: + self.pending[key] = PendingCollection() + return self.pending[key] + + def value_as_iterable(self, key, passive=PASSIVE_OFF): + """return an InstanceState attribute as a list, + regardless of it being a scalar or collection-based + attribute. + + returns None if passive is not PASSIVE_OFF and the getter returns + PASSIVE_NORESULT. + """ + + impl = self.get_impl(key) + dict_ = self.dict + x = impl.get(self, dict_, passive=passive) + if x is PASSIVE_NORESULT: + return None + elif hasattr(impl, 'get_collection'): + return impl.get_collection(self, dict_, x, passive=passive) + elif isinstance(x, list): + return x + else: + return [x] + + def _run_on_load(self, instance): + self.manager.events.run('on_load', instance) + + def __getstate__(self): + return {'key': self.key, + 'committed_state': self.committed_state, + 'pending': self.pending, + 'parents': self.parents, + 'modified': self.modified, + 'expired':self.expired, + 'load_options':self.load_options, + 'load_path':interfaces.serialize_path(self.load_path), + 'instance': self.obj(), + 'expired_attributes':self.expired_attributes, + 'callables': self.callables} + + def __setstate__(self, state): + self.committed_state = state['committed_state'] + self.parents = state['parents'] + self.key = state['key'] + self.session_id = None + self.pending = state['pending'] + self.modified = state['modified'] + self.obj = weakref.ref(state['instance']) + self.load_options = state['load_options'] or EMPTY_SET + self.load_path = interfaces.deserialize_path(state['load_path']) + self.class_ = self.obj().__class__ + self.manager = manager_of_class(self.class_) + self.callables = state['callables'] + self.runid = None + self.expired = state['expired'] + self.expired_attributes = state['expired_attributes'] + + def initialize(self, key): + self.manager.get_impl(key).initialize(self, self.dict) + + def set_callable(self, key, callable_): + self.dict.pop(key, None) + self.callables[key] = callable_ + + def __call__(self): + """__call__ allows the InstanceState to act as a deferred + callable for loading expired attributes, which is also + serializable (picklable). 
+ + """ + unmodified = self.unmodified + class_manager = self.manager + class_manager.deferred_scalar_loader(self, [ + attr.impl.key for attr in class_manager.attributes if + attr.impl.accepts_scalar_loader and + attr.impl.key in self.expired_attributes and + attr.impl.key in unmodified + ]) + for k in self.expired_attributes: + self.callables.pop(k, None) + del self.expired_attributes + return ATTR_WAS_SET + + @property + def unmodified(self): + """a set of keys which have no uncommitted changes""" + + return set(self.manager).difference(self.committed_state) + + @property + def unloaded(self): + """a set of keys which do not have a loaded value. + + This includes expired attributes and any other attribute that + was never populated or modified. + + """ + return set( + key for key in self.manager.iterkeys() + if key not in self.committed_state and key not in self.dict) + + def expire_attributes(self, attribute_names): + self.expired_attributes = set(self.expired_attributes) + + if attribute_names is None: + attribute_names = self.manager.keys() + self.expired = True + self.modified = False + filter_deferred = True + else: + filter_deferred = False + dict_ = self.dict + + for key in attribute_names: + impl = self.manager[key].impl + if not filter_deferred or \ + not impl.dont_expire_missing or \ + key in dict_: + self.expired_attributes.add(key) + if impl.accepts_scalar_loader: + self.callables[key] = self + dict_.pop(key, None) + self.pending.pop(key, None) + self.committed_state.pop(key, None) + if self.mutable_dict: + self.mutable_dict.pop(key, None) + + def reset(self, key, dict_): + """remove the given attribute and any callables associated with it.""" + + dict_.pop(key, None) + self.callables.pop(key, None) + + def _instance_dict(self): + return None + + def _is_really_none(self): + return self.obj() + + def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF): + needs_committed = attr.key not in self.committed_state + + if needs_committed: + if previous is NEVER_SET: + if passive: + if attr.key in dict_: + previous = dict_[attr.key] + else: + previous = attr.get(self, dict_) + + if should_copy and previous not in (None, NO_VALUE, NEVER_SET): + previous = attr.copy(previous) + + if needs_committed: + self.committed_state[attr.key] = previous + + self.modified = True + self._strong_obj = self.obj() + + instance_dict = self._instance_dict() + if instance_dict: + instance_dict.modified = True + + def commit(self, dict_, keys): + """Commit attributes. + + This is used by a partial-attribute load operation to mark committed + those attributes which were refreshed from the database. + + Attributes marked as "expired" can potentially remain "expired" after + this step if a value was not populated in state.dict. + + """ + class_manager = self.manager + for key in keys: + if key in dict_ and key in class_manager.mutable_attributes: + class_manager[key].impl.commit_to_state(self, dict_, self.committed_state) + else: + self.committed_state.pop(key, None) + + self.expired = False + # unexpire attributes which have loaded + for key in self.expired_attributes.intersection(keys): + if key in dict_: + self.expired_attributes.remove(key) + self.callables.pop(key, None) + + def commit_all(self, dict_): + """commit all attributes unconditionally. + + This is used after a flush() or a full load/refresh + to remove all pending state from the instance. 
+ + - all attributes are marked as "committed" + - the "strong dirty reference" is removed + - the "modified" flag is set to False + - any "expired" markers/callables are removed. + + Attributes marked as "expired" can potentially remain "expired" after this step + if a value was not populated in state.dict. + + """ + + self.committed_state = {} + self.pending = {} + + # unexpire attributes which have loaded + if self.expired_attributes: + for key in self.expired_attributes.intersection(dict_): + self.callables.pop(key, None) + self.expired_attributes.difference_update(dict_) + + for key in self.manager.mutable_attributes: + if key in dict_: + self.manager[key].impl.commit_to_state(self, dict_, self.committed_state) + + self.modified = self.expired = False + self._strong_obj = None + +class MutableAttrInstanceState(InstanceState): + def __init__(self, obj, manager): + self.mutable_dict = {} + InstanceState.__init__(self, obj, manager) + + def _get_modified(self, dict_=None): + if self.__dict__.get('modified', False): + return True + else: + if dict_ is None: + dict_ = self.dict + for key in self.manager.mutable_attributes: + if self.manager[key].impl.check_mutable_modified(self, dict_): + return True + else: + return False + + def _set_modified(self, value): + self.__dict__['modified'] = value + + modified = property(_get_modified, _set_modified) + + @property + def unmodified(self): + """a set of keys which have no uncommitted changes""" + + dict_ = self.dict + return set( + key for key in self.manager.iterkeys() + if (key not in self.committed_state or + (key in self.manager.mutable_attributes and + not self.manager[key].impl.check_mutable_modified(self, dict_)))) + + def _is_really_none(self): + """do a check modified/resurrect. + + This would be called in the extremely rare + race condition that the weakref returned None but + the cleanup handler had not yet established the + __resurrect callable as its replacement. + + """ + if self.modified: + self.obj = self.__resurrect + return self.obj() + else: + return None + + def reset(self, key, dict_): + self.mutable_dict.pop(key, None) + InstanceState.reset(self, key, dict_) + + def _cleanup(self, ref): + """weakref callback. + + This method may be called by an asynchronous + gc. + + If the state shows pending changes, the weakref + is replaced by the __resurrect callable which will + re-establish an object reference on next access, + else removes this InstanceState from the owning + identity map, if any. + + """ + if self._get_modified(self.mutable_dict): + self.obj = self.__resurrect + else: + instance_dict = self._instance_dict() + if instance_dict: + instance_dict.remove(self) + self.dispose() + + def __resurrect(self): + """A substitute for the obj() weakref function which resurrects.""" + + # store strong ref'ed version of the object; will revert + # to weakref when changes are persisted + + obj = self.manager.new_instance(state=self) + self.obj = weakref.ref(obj, self._cleanup) + self._strong_obj = obj + obj.__dict__.update(self.mutable_dict) + + # re-establishes identity attributes from the key + self.manager.events.run('on_resurrect', self, obj) + + # TODO: don't really think we should run this here. + # resurrect is only meant to preserve the minimal state needed to + # do an UPDATE, not to produce a fully usable object + self._run_on_load(obj) + + return obj + +class PendingCollection(object): + """A writable placeholder for an unloaded collection. 
+ + Stores items appended to and removed from a collection that has not yet + been loaded. When the collection is loaded, the changes stored in + PendingCollection are applied to it to produce the final result. + + """ + def __init__(self): + self.deleted_items = util.IdentitySet() + self.added_items = util.OrderedIdentitySet() + + def append(self, value): + if value in self.deleted_items: + self.deleted_items.remove(value) + self.added_items.add(value) + + def remove(self, value): + if value in self.added_items: + self.added_items.remove(value) + self.deleted_items.add(value) + diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 1aeb311e1c..20cbb8f4dc 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -115,8 +115,8 @@ class ColumnLoader(LoaderStrategy): if adapter: col = adapter.columns[col] if col in row: - def new_execute(state, row, **flags): - state.dict[key] = row[col] + def new_execute(state, dict_, row, **flags): + dict_[key] = row[col] if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, @@ -125,7 +125,7 @@ class ColumnLoader(LoaderStrategy): ) return (new_execute, None) else: - def new_execute(state, row, isnew, **flags): + def new_execute(state, dict_, row, isnew, **flags): if isnew: state.expire_attributes([key]) if self._should_log_debug: @@ -171,15 +171,15 @@ class CompositeColumnLoader(ColumnLoader): columns = [adapter.columns[c] for c in columns] for c in columns: if c not in row: - def new_execute(state, row, isnew, **flags): + def new_execute(state, dict_, row, isnew, **flags): if isnew: state.expire_attributes([key]) if self._should_log_debug: self.logger.debug("%s deferring load" % self) return (new_execute, None) else: - def new_execute(state, row, **flags): - state.dict[key] = composite_class(*[row[c] for c in columns]) + def new_execute(state, dict_, row, **flags): + dict_[key] = composite_class(*[row[c] for c in columns]) if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, @@ -202,13 +202,13 @@ class DeferredColumnLoader(LoaderStrategy): return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, path, mapper, row, adapter) elif not self.is_class_level: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): state.set_callable(self.key, LoadDeferredColumns(state, self.key)) else: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # reset state on the key so that deferred callables # fire off on next access. - state.reset(self.key) + state.reset(self.key, dict_) if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, None, @@ -340,7 +340,7 @@ class NoLoader(AbstractRelationLoader): ) def create_row_processor(self, selectcontext, path, mapper, row, adapter): - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): self._init_instance_attribute(state) if self._should_log_debug: @@ -437,7 +437,7 @@ class LazyLoader(AbstractRelationLoader): def create_row_processor(self, selectcontext, path, mapper, row, adapter): if not self.is_class_level: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, # which will override the class-level behavior. 
# this currently only happens when using a "lazyload" option on a "no load" attribute - @@ -451,11 +451,11 @@ class LazyLoader(AbstractRelationLoader): return (new_execute, None) else: - def new_execute(state, row, **flags): + def new_execute(state, dict_, row, **flags): # we are the primary manager for this attribute on this class - reset its per-instance attribute state, # so that the class-level lazy loader is executed when next referenced on this instance. # this is needed in populate_existing() types of scenarios to reset any existing state. - state.reset(self.key) + state.reset(self.key, dict_) if self._should_log_debug: new_execute = self.debug_callable(new_execute, self.logger, None, @@ -735,24 +735,24 @@ class EagerLoader(AbstractRelationLoader): _instance = self.mapper._instance_processor(context, path + (self.mapper.base_mapper,), eager_adapter) if not self.uselist: - def execute(state, row, isnew, **flags): + def execute(state, dict_, row, isnew, **flags): if isnew: # set a scalar object instance directly on the # parent object, bypassing InstrumentedAttribute # event handlers. - state.dict[key] = _instance(row, None) + dict_[key] = _instance(row, None) else: # call _instance on the row, even though the object has been created, # so that we further descend into properties _instance(row, None) else: - def execute(state, row, isnew, **flags): + def execute(state, dict_, row, isnew, **flags): if isnew or (state, key) not in context.attributes: # appender_key can be absent from context.attributes with isnew=False # when self-referential eager loading is used; the same instance may be present # in two distinct sets of result columns - collection = attributes.init_state_collection(state, key) + collection = attributes.init_state_collection(state, dict_, key) appender = util.UniqueAppender(collection, 'append_without_event') context.attributes[(state, key)] = appender diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 4ac9c765e0..407b702a8b 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -96,6 +96,8 @@ class UOWTransaction(object): # information. self.attributes = {} + self.processors = set() + def get_attribute_history(self, state, key, passive=True): hashkey = ("history", state, key) @@ -136,6 +138,16 @@ class UOWTransaction(object): else: task.append(state, listonly=listonly, isdelete=isdelete) + # ensure the mapper for this object has had its + # DependencyProcessors added. + if mapper not in self.processors: + mapper._register_processors(self) + self.processors.add(mapper) + + if mapper.base_mapper not in self.processors: + mapper.base_mapper._register_processors(self) + self.processors.add(mapper.base_mapper) + def set_row_switch(self, state): """mark a deleted object as a 'row switch'. @@ -147,7 +159,7 @@ class UOWTransaction(object): task = self.get_task_by_mapper(mapper) taskelement = task._objects[state] taskelement.isdelete = "rowswitch" - + def is_deleted(self, state): """return true if the given state is marked as deleted within this UOWTransaction.""" @@ -201,9 +213,9 @@ class UOWTransaction(object): self.dependencies.add((mapper, dependency)) def register_processor(self, mapper, processor, mapperfrom): - """register a dependency processor, corresponding to dependencies between - the two given mappers. - + """register a dependency processor, corresponding to + operations which occur between two mappers. 
+ """ # correct for primary mapper mapper = mapper.primary_mapper() diff --git a/test/orm/attributes.py b/test/orm/attributes.py index 772e1bbd04..76108a713e 100644 --- a/test/orm/attributes.py +++ b/test/orm/attributes.py @@ -38,7 +38,7 @@ class AttributesTest(_base.ORMTest): u.email_address = 'lala@123.com' self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - attributes.instance_state(u).commit_all() + attributes.instance_state(u).commit_all(attributes.instance_dict(u)) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') u.user_name = 'heythere' @@ -158,7 +158,7 @@ class AttributesTest(_base.ORMTest): eq_(f.a, None) eq_(f.b, 12) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) eq_(f.a, None) eq_(f.b, 12) @@ -205,7 +205,7 @@ class AttributesTest(_base.ORMTest): u.addresses.append(a) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') - u, attributes.instance_state(a).commit_all() + u, attributes.instance_state(a).commit_all(attributes.instance_dict(a)) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com') u.user_name = 'heythere' @@ -272,7 +272,7 @@ class AttributesTest(_base.ORMTest): p1 = Post() attributes.instance_state(b).set_callable('posts', lambda:[p1]) attributes.instance_state(p1).set_callable('blog', lambda:b) - p1, attributes.instance_state(b).commit_all() + p1, attributes.instance_state(b).commit_all(attributes.instance_dict(b)) # no orphans (called before the lazy loaders fire off) assert attributes.has_parent(Blog, p1, 'posts', optimistic=True) @@ -353,7 +353,7 @@ class AttributesTest(_base.ORMTest): x = Bar() x.element = el eq_(attributes.get_history(attributes.instance_state(x), 'element'), ([el], (), ())) - attributes.instance_state(x).commit_all() + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) (added, unchanged, deleted) = attributes.get_history(attributes.instance_state(x), 'element') assert added == () @@ -381,7 +381,7 @@ class AttributesTest(_base.ORMTest): attributes.register_attribute(Bar, 'id', uselist=False, useobject=True) x = Foo() - attributes.instance_state(x).commit_all() + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) x.col2.append(bar4) eq_(attributes.get_history(attributes.instance_state(x), 'col2'), ([bar4], [bar1, bar2, bar3], [])) @@ -427,7 +427,7 @@ class AttributesTest(_base.ORMTest): attributes.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False) x = Foo() x.element = ['one', 'two', 'three'] - attributes.instance_state(x).commit_all() + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) x.element[1] = 'five' assert attributes.instance_state(x).check_modified() @@ -437,7 +437,7 @@ class AttributesTest(_base.ORMTest): attributes.register_attribute(Foo, 'element', uselist=False, useobject=False) x = Foo() x.element = ['one', 'two', 'three'] - attributes.instance_state(x).commit_all() + attributes.instance_state(x).commit_all(attributes.instance_dict(x)) x.element[1] = 'five' assert not attributes.instance_state(x).check_modified() @@ -699,8 +699,8 @@ class PendingBackrefTest(_base.ORMTest): b = Blog("blog 1") p1.blog = b - attributes.instance_state(b).commit_all() - attributes.instance_state(p1).commit_all() + 
attributes.instance_state(b).commit_all(attributes.instance_dict(b)) + attributes.instance_state(p1).commit_all(attributes.instance_dict(p1)) assert b.posts == [Post("post 1")] class HistoryTest(_base.ORMTest): @@ -713,17 +713,17 @@ class HistoryTest(_base.ORMTest): attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False) f = Foo() - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None) + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) f.someattr = 3 - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None) + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) f = Foo() f.someattr = 3 - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None) + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None) - attributes.instance_state(f).commit(['someattr']) - eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), 3) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) + eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), 3) def test_scalar(self): class Foo(_base.BasicEntity): @@ -739,13 +739,13 @@ class HistoryTest(_base.ORMTest): f.someattr = "hi" eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['hi'], (), ())) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['hi'], ())) f.someattr = 'there' eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['there'], (), ['hi'])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['there'], ())) @@ -760,7 +760,7 @@ class HistoryTest(_base.ORMTest): f.someattr = 'old' eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['old'], (), ['new'])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['old'], ())) # setting None on uninitialized is currently a change for a scalar attribute @@ -778,7 +778,7 @@ class HistoryTest(_base.ORMTest): # set same value twice f = Foo() - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) f.someattr = 'one' eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ())) f.someattr = 'two' @@ -799,7 +799,7 @@ class HistoryTest(_base.ORMTest): f.someattr = {'foo':'hi'} eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'hi'}], (), ())) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'hi'}], ())) eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'}) @@ -807,7 +807,7 @@ class HistoryTest(_base.ORMTest): eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'}) eq_(attributes.get_history(attributes.instance_state(f), 
'someattr'), ([{'foo':'there'}], (), [{'foo':'hi'}])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'there'}], ())) @@ -819,7 +819,7 @@ class HistoryTest(_base.ORMTest): f.someattr = {'foo':'old'} eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'old'}], (), [{'foo':'new'}])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'old'}], ())) @@ -847,13 +847,13 @@ class HistoryTest(_base.ORMTest): f.someattr = hi eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], (), ())) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) f.someattr = there eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], (), [hi])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ())) @@ -868,7 +868,7 @@ class HistoryTest(_base.ORMTest): f.someattr = old eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], (), ['new'])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ())) # setting None on uninitialized is currently not a change for an object attribute @@ -887,7 +887,7 @@ class HistoryTest(_base.ORMTest): # set same value twice f = Foo() - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) f.someattr = 'one' eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ())) f.someattr = 'two' @@ -915,13 +915,13 @@ class HistoryTest(_base.ORMTest): f.someattr = [hi] eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) f.someattr = [there] eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ())) @@ -935,13 +935,13 @@ class HistoryTest(_base.ORMTest): f = Foo() collection = attributes.init_collection(attributes.instance_state(f), 'someattr') collection.append_without_event(new) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) f.someattr = [old] eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], [new])) - attributes.instance_state(f).commit(['someattr']) + 
attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ())) def test_dict_collections(self): @@ -969,7 +969,7 @@ class HistoryTest(_base.ORMTest): f.someattr['there'] = there eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([hi, there]), set(), set())) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set(), set([hi, there]), set())) def test_object_collections_mutate(self): @@ -994,13 +994,13 @@ class HistoryTest(_base.ORMTest): f.someattr.append(hi) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ())) f.someattr.append(there) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [hi], [])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there], ())) @@ -1010,7 +1010,7 @@ class HistoryTest(_base.ORMTest): f.someattr.append(old) f.someattr.append(new) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [hi], [there])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, old, new], ())) f.someattr.pop(0) @@ -1021,19 +1021,19 @@ class HistoryTest(_base.ORMTest): f.__dict__['id'] = 1 collection = attributes.init_collection(attributes.instance_state(f), 'someattr') collection.append_without_event(new) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) f.someattr.append(old) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [new], [])) - attributes.instance_state(f).commit(['someattr']) + attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr']) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new, old], ())) f = Foo() collection = attributes.init_collection(attributes.instance_state(f), 'someattr') collection.append_without_event(new) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ())) f.id = 1 @@ -1056,7 +1056,7 @@ class HistoryTest(_base.ORMTest): f.someattr.append(hi) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there, hi], [], [])) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there, hi], ())) f.someattr = [] diff --git a/test/orm/extendedattr.py b/test/orm/extendedattr.py index 69164ebafb..aec6c181f2 100644 --- a/test/orm/extendedattr.py +++ b/test/orm/extendedattr.py @@ -117,7 +117,7 @@ 
class UserDefinedExtensionTest(_base.ORMTest): u.user_id = 7 u.user_name = 'john' u.email_address = 'lala@123.com' - self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}}) + self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}}, u.__dict__) def test_basic(self): for base in (object, MyBaseClass, MyClass): @@ -135,7 +135,7 @@ class UserDefinedExtensionTest(_base.ORMTest): u.email_address = 'lala@123.com' self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') - attributes.instance_state(u).commit_all() + attributes.instance_state(u).commit_all(attributes.instance_dict(u)) self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com') u.user_name = 'heythere' @@ -182,7 +182,7 @@ class UserDefinedExtensionTest(_base.ORMTest): self.assertEquals(f.a, None) self.assertEquals(f.b, 12) - attributes.instance_state(f).commit_all() + attributes.instance_state(f).commit_all(attributes.instance_dict(f)) self.assertEquals(f.a, None) self.assertEquals(f.b, 12) @@ -272,8 +272,8 @@ class UserDefinedExtensionTest(_base.ORMTest): f1.bars.append(b1) self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], [])) - attributes.instance_state(f1).commit_all() - attributes.instance_state(b1).commit_all() + attributes.instance_state(f1).commit_all(attributes.instance_dict(f1)) + attributes.instance_state(b1).commit_all(attributes.instance_dict(b1)) self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ())) self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ())) diff --git a/test/orm/instrumentation.py b/test/orm/instrumentation.py index 081c46cdd8..fd15420d0a 100644 --- a/test/orm/instrumentation.py +++ b/test/orm/instrumentation.py @@ -1,8 +1,8 @@ import testenv; testenv.configure_for_tests() from testlib import sa -from testlib.sa import MetaData, Table, Column, Integer, ForeignKey -from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper +from testlib.sa import MetaData, Table, Column, Integer, ForeignKey, util +from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers from testlib.testing import eq_, ne_ from testlib.compat import _function_named from orm import _base @@ -458,25 +458,9 @@ class MapperInitTest(_base.ORMTest): m = mapper(A, self.fixture()) - a = attributes.instance_state(A()) - assert isinstance(a, attributes.InstanceState) - assert type(a) is not attributes.InstanceState - - b = attributes.instance_state(B()) - assert isinstance(b, attributes.InstanceState) - assert type(b) is not attributes.InstanceState - # B is not mapped in the current implementation self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, B) - # the constructor of C is decorated too. - # we don't support unmapped subclasses in any case, - # users should not be expecting any particular behavior - # from this scenario. 
- c = attributes.instance_state(C(3)) - assert isinstance(c, attributes.InstanceState) - assert type(c) is not attributes.InstanceState - # C is not mapped in the current implementation self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, C) @@ -573,6 +557,10 @@ class OnLoadTest(_base.ORMTest): finally: del A + def tearDownAll(self): + clear_mappers() + attributes._install_lookup_strategy(util.symbol('native')) + class ExtendedEventsTest(_base.ORMTest): """Allow custom Events implementations.""" @@ -593,6 +581,7 @@ class ExtendedEventsTest(_base.ORMTest): assert isinstance(manager.events, MyEvents) + class NativeInstrumentationTest(_base.ORMTest): @with_lookup_strategy(sa.util.symbol('native')) def test_register_reserved_attribute(self): diff --git a/test/orm/mapper.py b/test/orm/mapper.py index d2687fa5a0..3f427654f5 100644 --- a/test/orm/mapper.py +++ b/test/orm/mapper.py @@ -1751,9 +1751,9 @@ class CompositeTypesTest(_base.MappedTest): return [self.x, self.y] __hash__ = None def __eq__(self, other): - return other.x == self.x and other.y == self.y + return isinstance(other, Point) and other.x == self.x and other.y == self.y def __ne__(self, other): - return not self.__eq__(other) + return not isinstance(other, Point) or not self.__eq__(other) class Graph(object): pass @@ -1819,6 +1819,12 @@ class CompositeTypesTest(_base.MappedTest): # query by columns eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)]) + e = g.edges[1] + e.end.x = e.end.y = None + sess.flush() + eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, None, None)]) + + @testing.resolve_artifact_names def test_pk(self): """Using a composite type as a primary key""" diff --git a/test/orm/query.py b/test/orm/query.py index e95f10ba29..5abcce689f 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -371,7 +371,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): ) u7 = User(id=7) - attributes.instance_state(u7).commit_all() + attributes.instance_state(u7).commit_all(attributes.instance_dict(u7)) self._test(Address.user == u7, ":param_1 = addresses.user_id") diff --git a/test/orm/session.py b/test/orm/session.py index 818ab03daf..6ae05c77b0 100644 --- a/test/orm/session.py +++ b/test/orm/session.py @@ -5,7 +5,7 @@ import pickle from sqlalchemy.orm import create_session, sessionmaker, attributes from testlib import engines, sa, testing, config from testlib.sa import Table, Column, Integer, String, Sequence -from testlib.sa.orm import mapper, relation, backref +from testlib.sa.orm import mapper, relation, backref, eagerload from testlib.testing import eq_ from engine import _base as engine_base from orm import _base, _fixtures @@ -776,7 +776,66 @@ class SessionTest(_fixtures.FixtureTest): user = s.query(User).one() assert user.name == 'fred' assert s.identity_map + + @testing.resolve_artifact_names + def test_weakref_with_cycles_o2m(self): + s = sessionmaker()() + mapper(User, users, properties={ + "addresses":relation(Address, backref="user") + }) + mapper(Address, addresses) + s.add(User(name="ed", addresses=[Address(email_address="ed1")])) + s.commit() + + user = s.query(User).options(eagerload(User.addresses)).one() + user.addresses[0].user # lazyload + eq_(user, User(name="ed", addresses=[Address(email_address="ed1")])) + + del user + gc.collect() + assert len(s.identity_map) == 0 + user = s.query(User).options(eagerload(User.addresses)).one() + user.addresses[0].email_address='ed2' + user.addresses[0].user # lazyload + del user + gc.collect() + assert 
len(s.identity_map) == 2 + + s.commit() + user = s.query(User).options(eagerload(User.addresses)).one() + eq_(user, User(name="ed", addresses=[Address(email_address="ed2")])) + + @testing.resolve_artifact_names + def test_weakref_with_cycles_o2o(self): + s = sessionmaker()() + mapper(User, users, properties={ + "address":relation(Address, backref="user", uselist=False) + }) + mapper(Address, addresses) + s.add(User(name="ed", address=Address(email_address="ed1"))) + s.commit() + + user = s.query(User).options(eagerload(User.address)).one() + user.address.user + eq_(user, User(name="ed", address=Address(email_address="ed1"))) + + del user + gc.collect() + assert len(s.identity_map) == 0 + + user = s.query(User).options(eagerload(User.address)).one() + user.address.email_address='ed2' + user.address.user # lazyload + + del user + gc.collect() + assert len(s.identity_map) == 2 + + s.commit() + user = s.query(User).options(eagerload(User.address)).one() + eq_(user, User(name="ed", address=Address(email_address="ed2"))) + @testing.resolve_artifact_names def test_strong_ref(self): s = create_session(weak_identity_map=False) diff --git a/test/orm/unitofwork.py b/test/orm/unitofwork.py index 7a8415cdc3..c5e3afd014 100644 --- a/test/orm/unitofwork.py +++ b/test/orm/unitofwork.py @@ -14,6 +14,7 @@ from orm import _base, _fixtures from engine import _base as engine_base import pickleable from testlib.assertsql import AllOf, CompiledSQL +import gc class UnitOfWorkTest(object): pass @@ -366,6 +367,28 @@ class MutableTypesTest(_base.MappedTest): "WHERE mutable_t.id = :mutable_t_id", {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})]) + + @testing.resolve_artifact_names + def test_resurrect(self): + f1 = Foo() + f1.data = pickleable.Bar(4,5) + f1.val = u'hi' + + session = create_session(autocommit=False) + session.add(f1) + session.commit() + + f1.data.y = 19 + del f1 + + gc.collect() + assert len(session.identity_map) == 1 + + session.commit() + + assert session.query(Foo).one().data == pickleable.Bar(4, 19) + + @testing.uses_deprecated() @testing.resolve_artifact_names def test_nocomparison(self): diff --git a/test/profiling/zoomark_orm.py b/test/profiling/zoomark_orm.py index 7a189f8731..5d7192261d 100644 --- a/test/profiling/zoomark_orm.py +++ b/test/profiling/zoomark_orm.py @@ -290,11 +290,11 @@ class ZooMarkTest(TestBase): def test_profile_1_create_tables(self): self.test_baseline_1_create_tables() - @profiling.function_call_count(12925, {'2.4':12478}) + @profiling.function_call_count(12178, {'2.4':12178}) def test_profile_1a_populate(self): self.test_baseline_1a_populate() - @profiling.function_call_count(1185, {'2.4':1184}) + @profiling.function_call_count(903, {'2.4':903}) def test_profile_2_insert(self): self.test_baseline_2_insert() @@ -310,7 +310,7 @@ class ZooMarkTest(TestBase): def test_profile_5_aggregates(self): self.test_baseline_5_aggregates() - @profiling.function_call_count(3545) + @profiling.function_call_count(3343) def test_profile_6_editing(self): self.test_baseline_6_editing() -- 2.47.2
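
Note on the new IdentityMap.replace() used by Session._register_newly_persistent above: an incoming state unconditionally takes over the slot for its key, and any previously mapped, different state has its bookkeeping torn down rather than raising. A rough illustration of that behavior follows; ToyIdentityMap and its method names are purely illustrative, not the ORM's API.

    class ToyIdentityMap(dict):
        """Toy map keyed by identity key; tracks displaced states for demonstration."""

        def __init__(self):
            dict.__init__(self)
            self.removed = []

        def _manage_removed_state(self, state):
            # stand-in for the real removed-state bookkeeping
            self.removed.append(state)

        def replace(self, key, state):
            if key in self:
                existing = dict.__getitem__(self, key)
                if existing is not state:
                    self._manage_removed_state(existing)
                else:
                    return
            dict.__setitem__(self, key, state)


    imap = ToyIdentityMap()
    imap.replace(("users", 7), "state-a")
    imap.replace(("users", 7), "state-b")      # state-a is displaced, not an error
    print(imap[("users", 7)], imap.removed)    # state-b ['state-a']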
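
The Mapper._sorted_tables property added above caches the table-sort result that was previously recomputed on every flush. A minimal sketch of the caching pattern, using a hypothetical memoized_property in the spirit of util.memoized_property and a toy dependency ordering (not sqlutil.sort_tables); the table data is assumed acyclic:

    class memoized_property(object):
        """Cache a property's result in the instance __dict__ on first access."""

        def __init__(self, fget):
            self.fget = fget
            self.__name__ = fget.__name__

        def __get__(self, obj, cls):
            if obj is None:
                return self
            # store the result; later lookups hit the instance __dict__ directly
            obj.__dict__[self.__name__] = result = self.fget(obj)
            return result


    class ToyMapper(object):
        """Hypothetical mapper that knows its tables and their dependencies."""

        def __init__(self, tables, depends_on):
            self.tables = tables              # list of table names
            self.depends_on = depends_on      # {table: set of tables it depends on}

        @memoized_property
        def _sorted_tables(self):
            # naive parents-before-children ordering, computed exactly once
            remaining = set(self.tables)
            ordered = []
            while remaining:
                ready = sorted(t for t in remaining
                               if self.depends_on.get(t, set()) <= set(ordered))
                ordered.extend(ready)
                remaining.difference_update(ready)
            return ordered


    m = ToyMapper(["addresses", "users"], {"addresses": set(["users"])})
    print(m._sorted_tables)    # ['users', 'addresses'] - computed now
    print(m._sorted_tables)    # same list, served from the instance __dict__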
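
The reorganized _save_obj() above collects "row switches" in a separate pass before iterating tables: a pending object whose identity key collides with an already-persistent, deleted object is persisted as an UPDATE rather than an INSERT. A toy version of that decision, with illustrative names only:

    def plan_statements(pending, identity_map, deleted):
        """pending: list of (state, identity_key, has_identity) tuples."""
        row_switches = set()
        plans = []
        for state, key, has_identity in pending:
            if not has_identity and key in identity_map:
                existing = identity_map[key]
                if existing not in deleted:
                    raise RuntimeError(
                        "New instance with identity key %r conflicts with a "
                        "persistent instance" % (key,))
                # the existing element loses its "delete" flag; the pending
                # state will update its row instead of inserting a new one
                row_switches.add(state)
            isinsert = not has_identity and state not in row_switches
            plans.append((state, "INSERT" if isinsert else "UPDATE"))
        return plans


    existing_state = "old-user-7"
    plans = plan_statements(
        [("new-user-7", ("users", 7), False)],
        identity_map={("users", 7): existing_state},
        deleted=set([existing_state]),
    )
    print(plans)    # [('new-user-7', 'UPDATE')] - the pending row becomes a row switch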
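
The state.py module and the session tests above revolve around the "full weak referencing" behavior: an InstanceState holds only a weak reference to its object, pins a strong reference while changes are pending, and simply drops out of the identity map when an unmodified object is garbage collected. A stripped-down sketch of that lifecycle, assuming hypothetical ToyState/Thing classes rather than the real InstanceState:

    import gc
    import weakref

    class ToyState(object):
        """Hypothetical miniature of an InstanceState: weak reference to the
        object, plus a strong pin only while changes are pending."""

        def __init__(self, obj, identity_map):
            self.key = id(obj)                       # stand-in for the identity key
            self.obj = weakref.ref(obj, self._cleanup)
            self._strong_obj = None
            self.modified = False
            self._identity_map = identity_map
            identity_map[self.key] = self

        def modified_event(self):
            # a pending change pins the object strongly, as modified_event() does
            self.modified = True
            self._strong_obj = self.obj()

        def commit_all(self):
            # after flush, the strong reference is released again
            self.modified = False
            self._strong_obj = None

        def _cleanup(self, ref):
            # weakref callback: an unmodified object just leaves the map
            self._identity_map.pop(self.key, None)


    class Thing(object):
        pass

    imap = {}
    t1, t2 = Thing(), Thing()
    s1, s2 = ToyState(t1, imap), ToyState(t2, imap)
    s2.modified_event()          # t2 has pending changes

    del t1, t2
    gc.collect()
    print(len(imap))             # 1 - only the state with pending changes survived

    s2.commit_all()              # "flush" completes, strong reference dropped
    gc.collect()
    print(len(imap))             # 0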
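
PendingCollection in the new orm/state.py buffers appends and removes made against a collection that has not been loaded yet, so they can be merged in once the real collection arrives. A small sketch of that buffering idea; apply_to() is a hypothetical helper for the demo, not part of the ORM:

    class ToyPendingCollection(object):
        def __init__(self):
            self.added_items = []
            self.deleted_items = set()

        def append(self, value):
            # an append cancels a prior pending remove of the same value
            self.deleted_items.discard(value)
            self.added_items.append(value)

        def remove(self, value):
            if value in self.added_items:
                self.added_items.remove(value)
            self.deleted_items.add(value)

        def apply_to(self, loaded):
            """Merge the buffered changes into the freshly loaded collection."""
            return [x for x in loaded if x not in self.deleted_items] + self.added_items


    pending = ToyPendingCollection()
    pending.append("a3")                      # collection not loaded yet
    pending.remove("a1")
    print(pending.apply_to(["a1", "a2"]))     # ['a2', 'a3']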
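
Finally, the UOWTransaction change above registers dependency processors lazily: a mapper's processors are set up at most once, and only when an object of that mapper (or its base) actually enters the flush, tracked via the new self.processors set. A schematic version with hypothetical ToyUOW/ToyMapper stand-ins:

    class ToyUOW(object):
        def __init__(self):
            self.processors = set()
            self.registrations = []

        def register_object(self, mapper):
            # register processors for this mapper and its base, once each
            for m in (mapper, mapper.base_mapper):
                if m not in self.processors:
                    m.register_processors(self)
                    self.processors.add(m)


    class ToyMapper(object):
        def __init__(self, name, base=None):
            self.name = name
            self.base_mapper = base or self

        def register_processors(self, uow):
            uow.registrations.append(self.name)


    base = ToyMapper("Person")
    sub = ToyMapper("Engineer", base=base)
    uow = ToyUOW()
    uow.register_object(sub)
    uow.register_object(sub)          # second call is a no-op
    print(uow.registrations)          # ['Engineer', 'Person']

Mappers whose objects never appear in the flush are simply never visited, which is where the savings on large interconnected mapper graphs come from.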