]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Significant performance enhancements regarding Sessions/flush()
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 May 2009 18:17:46 +0000 (18:17 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 17 May 2009 18:17:46 +0000 (18:17 +0000)
      in conjunction with large mapper graphs, large numbers of
      objects:

      - The Session's "weak referencing" behavior is now *full* -
        no strong references whatsoever are made to a mapped object
        or related items/collections in its __dict__.  Backrefs and
        other cycles in objects no longer affect the Session's ability
        to lose all references to unmodified objects.  Objects with
        pending changes still are maintained strongly until flush.
        [ticket:1398]

        The implementation also improves performance by moving
        the "resurrection" process of garbage collected items
        to only be relevant for mappings that map "mutable"
        attributes (i.e. PickleType, composite attrs).  This removes
        overhead from the gc process and simplifies internal
        behavior.

        If a "mutable" attribute change is the sole change on an object
        which is then dereferenced, the mapper will not have access to
        other attribute state when the UPDATE is issued.  This may present
        itself differently to some MapperExtensions.

        The change also affects the internal attribute API, but not
        the AttributeExtension interface nor any of the publically
        documented attribute functions.

      - The unit of work no longer genererates a graph of "dependency"
        processors for the full graph of mappers during flush(), instead
        creating such processors only for those mappers which represent
        objects with pending changes.  This saves a tremendous number
        of method calls in the context of a large interconnected
        graph of mappers.

      - Cached a wasteful "table sort" operation that previously
        occured multiple times per flush, also removing significant
        method call count from flush().

      - Other redundant behaviors have been simplified in
        mapper._save_obj().

22 files changed:
CHANGES
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/identity.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py [new file with mode: 0644]
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/unitofwork.py
test/orm/attributes.py
test/orm/extendedattr.py
test/orm/instrumentation.py
test/orm/mapper.py
test/orm/query.py
test/orm/session.py
test/orm/unitofwork.py
test/profiling/zoomark_orm.py

diff --git a/CHANGES b/CHANGES
index 6161ae6de222ba48cf063b4f5abbf8abbefac6b2..5130e886fee54b0a7e406393af02159e6d6e3935 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -8,6 +8,48 @@ CHANGES
 =====
 
 - orm
+    - Significant performance enhancements regarding Sessions/flush()
+      in conjunction with large mapper graphs, large numbers of 
+      objects:
+      
+      - The Session's "weak referencing" behavior is now *full* -
+        no strong references whatsoever are made to a mapped object
+        or related items/collections in its __dict__.  Backrefs and 
+        other cycles in objects no longer affect the Session's ability 
+        to lose all references to unmodified objects.  Objects with 
+        pending changes still are maintained strongly until flush.  
+        [ticket:1398]
+        
+        The implementation also improves performance by moving
+        the "resurrection" process of garbage collected items
+        to only be relevant for mappings that map "mutable" 
+        attributes (i.e. PickleType, composite attrs).  This removes
+        overhead from the gc process and simplifies internal 
+        behavior.
+        
+        If a "mutable" attribute change is the sole change on an object 
+        which is then dereferenced, the mapper will not have access to 
+        other attribute state when the UPDATE is issued.  This may present 
+        itself differently to some MapperExtensions.
+        
+        The change also affects the internal attribute API, but not
+        the AttributeExtension interface nor any of the publically
+        documented attribute functions.
+        
+      - The unit of work no longer genererates a graph of "dependency"
+        processors for the full graph of mappers during flush(), instead
+        creating such processors only for those mappers which represent
+        objects with pending changes.  This saves a tremendous number
+        of method calls in the context of a large interconnected 
+        graph of mappers.
+        
+      - Cached a wasteful "table sort" operation that previously
+        occured multiple times per flush, also removing significant
+        method call count from flush().
+        
+      - Other redundant behaviors have been simplified in 
+        mapper._save_obj().
+      
     - Modified query_cls on DynamicAttributeImpl to accept a full
       mixin version of the AppenderQuery, which allows subclassing
       the AppenderMixin.
index 68aa0d93ae4f31ffcfb4499f7eec8f7ea1218831..4fa41ff3b5451539ac2af32e0bdf35aec461f412 100644 (file)
@@ -20,14 +20,13 @@ import types
 import weakref
 
 from sqlalchemy import util
-from sqlalchemy.util import EMPTY_SET
 from sqlalchemy.orm import interfaces, collections, exc
 import sqlalchemy.exceptions as sa_exc
 
 # lazy imports
 _entity_info = None
 identity_equal = None
-
+state = None
 
 PASSIVE_NORESULT = util.symbol('PASSIVE_NORESULT')
 ATTR_WAS_SET = util.symbol('ATTR_WAS_SET')
@@ -105,7 +104,7 @@ class QueryableAttribute(interfaces.PropComparator):
         self.parententity = parententity
 
     def get_history(self, instance, **kwargs):
-        return self.impl.get_history(instance_state(instance), **kwargs)
+        return self.impl.get_history(instance_state(instance), instance_dict(instance), **kwargs)
 
     def __selectable__(self):
         # TODO: conditionally attach this method based on clause_element ?
@@ -148,15 +147,15 @@ class InstrumentedAttribute(QueryableAttribute):
     """Public-facing descriptor, placed in the mapped class dictionary."""
 
     def __set__(self, instance, value):
-        self.impl.set(instance_state(instance), value, None)
+        self.impl.set(instance_state(instance), instance_dict(instance), value, None)
 
     def __delete__(self, instance):
-        self.impl.delete(instance_state(instance))
+        self.impl.delete(instance_state(instance), instance_dict(instance))
 
     def __get__(self, instance, owner):
         if instance is None:
             return self
-        return self.impl.get(instance_state(instance))
+        return self.impl.get(instance_state(instance), instance_dict(instance))
 
 class _ProxyImpl(object):
     accepts_scalar_loader = False
@@ -335,7 +334,7 @@ class AttributeImpl(object):
         else:
             state.callables[self.key] = callable_
 
-    def get_history(self, state, passive=PASSIVE_OFF):
+    def get_history(self, state, dict_, passive=PASSIVE_OFF):
         raise NotImplementedError()
 
     def _get_callable(self, state):
@@ -346,13 +345,13 @@ class AttributeImpl(object):
         else:
             return None
 
-    def initialize(self, state):
+    def initialize(self, state, dict_):
         """Initialize this attribute on the given object instance with an empty value."""
 
-        state.dict[self.key] = None
+        dict_[self.key] = None
         return None
 
-    def get(self, state, passive=PASSIVE_OFF):
+    def get(self, state, dict_, passive=PASSIVE_OFF):
         """Retrieve a value from the given object.
 
         If a callable is assembled on this object's attribute, and
@@ -361,7 +360,7 @@ class AttributeImpl(object):
         """
 
         try:
-            return state.dict[self.key]
+            return dict_[self.key]
         except KeyError:
             # if no history, check for lazy callables, etc.
             if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET:
@@ -374,25 +373,25 @@ class AttributeImpl(object):
                         return PASSIVE_NORESULT
                     value = callable_()
                     if value is not ATTR_WAS_SET:
-                        return self.set_committed_value(state, value)
+                        return self.set_committed_value(state, dict_, value)
                     else:
-                        if self.key not in state.dict:
+                        if self.key not in dict_:
                             return self.get(state, passive=passive)
-                        return state.dict[self.key]
+                        return dict_[self.key]
 
             # Return a new, empty value
-            return self.initialize(state)
+            return self.initialize(state, dict_)
 
-    def append(self, state, value, initiator, passive=PASSIVE_OFF):
-        self.set(state, value, initiator)
+    def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+        self.set(state, dict_, value, initiator)
 
-    def remove(self, state, value, initiator, passive=PASSIVE_OFF):
-        self.set(state, None, initiator)
+    def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+        self.set(state, dict_, None, initiator)
 
-    def set(self, state, value, initiator):
+    def set(self, state, dict_, value, initiator):
         raise NotImplementedError()
 
-    def get_committed_value(self, state, passive=PASSIVE_OFF):
+    def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
         """return the unchanged value of this attribute"""
 
         if self.key in state.committed_state:
@@ -401,12 +400,12 @@ class AttributeImpl(object):
             else:
                 return state.committed_state.get(self.key)
         else:
-            return self.get(state, passive=passive)
+            return self.get(state, dict_, passive=passive)
 
-    def set_committed_value(self, state, value):
+    def set_committed_value(self, state, dict_, value):
         """set an attribute value on the given instance and 'commit' it."""
 
-        state.commit([self.key])
+        state.commit(dict_, [self.key])
 
         state.callables.pop(self.key, None)
         state.dict[self.key] = value
@@ -419,45 +418,45 @@ class ScalarAttributeImpl(AttributeImpl):
     accepts_scalar_loader = True
     uses_objects = False
 
-    def delete(self, state):
+    def delete(self, state, dict_):
 
         # TODO: catch key errors, convert to attributeerror?
         if self.active_history or self.extensions:
-            old = self.get(state)
+            old = self.get(state, dict_)
         else:
-            old = state.dict.get(self.key, NO_VALUE)
+            old = dict_.get(self.key, NO_VALUE)
 
-        state.modified_event(self, False, old)
+        state.modified_event(dict_, self, False, old)
 
         if self.extensions:
-            self.fire_remove_event(state, old, None)
-        del state.dict[self.key]
+            self.fire_remove_event(state, dict_, old, None)
+        del dict_[self.key]
 
-    def get_history(self, state, passive=PASSIVE_OFF):
+    def get_history(self, state, dict_, passive=PASSIVE_OFF):
         return History.from_attribute(
-            self, state, state.dict.get(self.key, NO_VALUE))
+            self, state, dict_.get(self.key, NO_VALUE))
 
-    def set(self, state, value, initiator):
+    def set(self, state, dict_, value, initiator):
         if initiator is self:
             return
 
         if self.active_history or self.extensions:
-            old = self.get(state)
+            old = self.get(state, dict_)
         else:
-            old = state.dict.get(self.key, NO_VALUE)
+            old = dict_.get(self.key, NO_VALUE)
 
-        state.modified_event(self, False, old)
+        state.modified_event(dict_, self, False, old)
 
         if self.extensions:
-            value = self.fire_replace_event(state, value, old, initiator)
-        state.dict[self.key] = value
+            value = self.fire_replace_event(state, dict_, value, old, initiator)
+        dict_[self.key] = value
 
-    def fire_replace_event(self, state, value, previous, initiator):
+    def fire_replace_event(self, state, dict_, value, previous, initiator):
         for ext in self.extensions:
             value = ext.set(state, value, previous, initiator or self)
         return value
 
-    def fire_remove_event(self, state, value, initiator):
+    def fire_remove_event(self, state, dict_, value, initiator):
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
@@ -483,29 +482,48 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
             raise sa_exc.ArgumentError("MutableScalarAttributeImpl requires a copy function")
         self.copy = copy_function
 
-    def get_history(self, state, passive=PASSIVE_OFF):
+    def get_history(self, state, dict_, passive=PASSIVE_OFF):
+        if not dict_:
+            v = state.committed_state.get(self.key, NO_VALUE)
+        else:
+            v = dict_.get(self.key, NO_VALUE)
+            
         return History.from_attribute(
-            self, state, state.dict.get(self.key, NO_VALUE))
+            self, state, v)
 
-    def commit_to_state(self, state, dest):
-        dest[self.key] = self.copy(state.dict[self.key])
+    def commit_to_state(self, state, dict_, dest):
+        dest[self.key] = self.copy(dict_[self.key])
 
-    def check_mutable_modified(self, state):
-        (added, unchanged, deleted) = self.get_history(state, passive=PASSIVE_NO_INITIALIZE)
+    def check_mutable_modified(self, state, dict_):
+        (added, unchanged, deleted) = self.get_history(state, dict_, passive=PASSIVE_NO_INITIALIZE)
         return bool(added or deleted)
 
-    def set(self, state, value, initiator):
+    def get(self, state, dict_, passive=PASSIVE_OFF):
+        if self.key not in state.mutable_dict:
+            ret = ScalarAttributeImpl.get(self, state, dict_, passive=passive)
+            if ret is not PASSIVE_NORESULT:
+                state.mutable_dict[self.key] = ret
+            return ret
+        else:
+            return state.mutable_dict[self.key]
+
+    def delete(self, state, dict_):
+        ScalarAttributeImpl.delete(self, state, dict_)
+        state.mutable_dict.pop(self.key)
+
+    def set(self, state, dict_, value, initiator):
         if initiator is self:
             return
 
-        state.modified_event(self, True, NEVER_SET)
-
+        state.modified_event(dict_, self, True, NEVER_SET)
+        
         if self.extensions:
-            old = self.get(state)
-            value = self.fire_replace_event(state, value, old, initiator)
-            state.dict[self.key] = value
+            old = self.get(state, dict_)
+            value = self.fire_replace_event(state, dict_, value, old, initiator)
+            dict_[self.key] = value
         else:
-            state.dict[self.key] = value
+            dict_[self.key] = value
+        state.mutable_dict[self.key] = value
 
 
 class ScalarObjectAttributeImpl(ScalarAttributeImpl):
@@ -526,22 +544,22 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         if compare_function is None:
             self.is_equal = identity_equal
 
-    def delete(self, state):
-        old = self.get(state)
-        self.fire_remove_event(state, old, self)
-        del state.dict[self.key]
+    def delete(self, state, dict_):
+        old = self.get(state, dict_)
+        self.fire_remove_event(state, dict_, old, self)
+        del dict_[self.key]
 
-    def get_history(self, state, passive=PASSIVE_OFF):
-        if self.key in state.dict:
-            return History.from_attribute(self, state, state.dict[self.key])
+    def get_history(self, state, dict_, passive=PASSIVE_OFF):
+        if self.key in dict_:
+            return History.from_attribute(self, state, dict_[self.key])
         else:
-            current = self.get(state, passive=passive)
+            current = self.get(state, dict_, passive=passive)
             if current is PASSIVE_NORESULT:
                 return HISTORY_BLANK
             else:
                 return History.from_attribute(self, state, current)
 
-    def set(self, state, value, initiator):
+    def set(self, state, dict_, value, initiator):
         """Set a value on the given InstanceState.
 
         `initiator` is the ``InstrumentedAttribute`` that initiated the
@@ -553,12 +571,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
             return
 
         # may want to add options to allow the get() here to be passive
-        old = self.get(state)
-        value = self.fire_replace_event(state, value, old, initiator)
-        state.dict[self.key] = value
+        old = self.get(state, dict_)
+        value = self.fire_replace_event(state, dict_, value, old, initiator)
+        dict_[self.key] = value
 
-    def fire_remove_event(self, state, value, initiator):
-        state.modified_event(self, False, value)
+    def fire_remove_event(self, state, dict_, value, initiator):
+        state.modified_event(dict_, self, False, value)
 
         if self.trackparent and value is not None:
             self.sethasparent(instance_state(value), False)
@@ -566,8 +584,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
-    def fire_replace_event(self, state, value, previous, initiator):
-        state.modified_event(self, False, previous)
+    def fire_replace_event(self, state, dict_, value, previous, initiator):
+        state.modified_event(dict_, self, False, previous)
 
         if self.trackparent:
             if previous is not value and previous is not None:
@@ -615,15 +633,15 @@ class CollectionAttributeImpl(AttributeImpl):
     def __copy(self, item):
         return [y for y in list(collections.collection_adapter(item))]
 
-    def get_history(self, state, passive=PASSIVE_OFF):
-        current = self.get(state, passive=passive)
+    def get_history(self, state, dict_, passive=PASSIVE_OFF):
+        current = self.get(state, dict_, passive=passive)
         if current is PASSIVE_NORESULT:
             return HISTORY_BLANK
         else:
             return History.from_attribute(self, state, current)
 
-    def fire_append_event(self, state, value, initiator):
-        state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+    def fire_append_event(self, state, dict_, value, initiator):
+        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
 
         for ext in self.extensions:
             value = ext.append(state, value, initiator or self)
@@ -633,11 +651,11 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return value
 
-    def fire_pre_remove_event(self, state, initiator):
-        state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+    def fire_pre_remove_event(self, state, dict_, initiator):
+        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
 
-    def fire_remove_event(self, state, value, initiator):
-        state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+    def fire_remove_event(self, state, dict_, value, initiator):
+        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
 
         if self.trackparent and value is not None:
             self.sethasparent(instance_state(value), False)
@@ -645,51 +663,51 @@ class CollectionAttributeImpl(AttributeImpl):
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
-    def delete(self, state):
-        if self.key not in state.dict:
+    def delete(self, state, dict_):
+        if self.key not in dict_:
             return
 
-        state.modified_event(self, True, NEVER_SET)
+        state.modified_event(dict_, self, True, NEVER_SET)
 
-        collection = self.get_collection(state)
+        collection = self.get_collection(state, state.dict)
         collection.clear_with_event()
         # TODO: catch key errors, convert to attributeerror?
-        del state.dict[self.key]
+        del dict_[self.key]
 
-    def initialize(self, state):
+    def initialize(self, state, dict_):
         """Initialize this attribute with an empty collection."""
 
         _, user_data = self._initialize_collection(state)
-        state.dict[self.key] = user_data
+        dict_[self.key] = user_data
         return user_data
 
     def _initialize_collection(self, state):
         return state.manager.initialize_collection(
             self.key, state, self.collection_factory)
 
-    def append(self, state, value, initiator, passive=PASSIVE_OFF):
+    def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         if initiator is self:
             return
 
-        collection = self.get_collection(state, passive=passive)
+        collection = self.get_collection(state, dict_, passive=passive)
         if collection is PASSIVE_NORESULT:
-            value = self.fire_append_event(state, value, initiator)
+            value = self.fire_append_event(state, dict_, value, initiator)
             state.get_pending(self.key).append(value)
         else:
             collection.append_with_event(value, initiator)
 
-    def remove(self, state, value, initiator, passive=PASSIVE_OFF):
+    def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         if initiator is self:
             return
 
-        collection = self.get_collection(state, passive=passive)
+        collection = self.get_collection(state, state.dict, passive=passive)
         if collection is PASSIVE_NORESULT:
-            self.fire_remove_event(state, value, initiator)
+            self.fire_remove_event(state, dict_, value, initiator)
             state.get_pending(self.key).remove(value)
         else:
             collection.remove_with_event(value, initiator)
 
-    def set(self, state, value, initiator):
+    def set(self, state, dict_, value, initiator):
         """Set a value on the given object.
 
         `initiator` is the ``InstrumentedAttribute`` that initiated the
@@ -701,10 +719,10 @@ class CollectionAttributeImpl(AttributeImpl):
             return
 
         self._set_iterable(
-            state, value,
+            state, dict_, value,
             lambda adapter, i: adapter.adapt_like_to_iterable(i))
 
-    def _set_iterable(self, state, iterable, adapter=None):
+    def _set_iterable(self, state, dict_, iterable, adapter=None):
         """Set a collection value from an iterable of state-bearers.
 
         ``adapter`` is an optional callable invoked with a CollectionAdapter
@@ -722,24 +740,24 @@ class CollectionAttributeImpl(AttributeImpl):
         else:
             new_values = list(iterable)
 
-        old = self.get(state)
+        old = self.get(state, dict_)
 
         # ignore re-assignment of the current collection, as happens
         # implicitly with in-place operators (foo.collection |= other)
         if old is iterable:
             return
 
-        state.modified_event(self, True, old)
+        state.modified_event(dict_, self, True, old)
 
-        old_collection = self.get_collection(state, old)
+        old_collection = self.get_collection(state, dict_, old)
 
-        state.dict[self.key] = user_data
+        dict_[self.key] = user_data
 
         collections.bulk_replace(new_values, old_collection, new_collection)
         old_collection.unlink(old)
 
 
-    def set_committed_value(self, state, value):
+    def set_committed_value(self, state, dict_, value):
         """Set an attribute value on the given instance and 'commit' it."""
 
         collection, user_data = self._initialize_collection(state)
@@ -751,13 +769,13 @@ class CollectionAttributeImpl(AttributeImpl):
         state.callables.pop(self.key, None)
         state.dict[self.key] = user_data
 
-        state.commit([self.key])
+        state.commit(dict_, [self.key])
 
         if self.key in state.pending:
             
             # pending items exist.  issue a modified event,
             # add/remove new items.
-            state.modified_event(self, True, user_data)
+            state.modified_event(dict_, self, True, user_data)
 
             pending = state.pending.pop(self.key)
             added = pending.added_items
@@ -769,14 +787,14 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return user_data
 
-    def get_collection(self, state, user_data=None, passive=PASSIVE_OFF):
+    def get_collection(self, state, dict_, user_data=None, passive=PASSIVE_OFF):
         """Retrieve the CollectionAdapter associated with the given state.
 
         Creates a new CollectionAdapter if one does not exist.
 
         """
         if user_data is None:
-            user_data = self.get(state, passive=passive)
+            user_data = self.get(state, dict_, passive=passive)
             if user_data is PASSIVE_NORESULT:
                 return user_data
 
@@ -799,320 +817,26 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
         if oldchild is not None:
             # With lazy=None, there's no guarantee that the full collection is
             # present when updating via a backref.
-            old_state = instance_state(oldchild)
+            old_state, old_dict = instance_state(oldchild), instance_dict(oldchild)
             impl = old_state.get_impl(self.key)
             try:
-                impl.remove(old_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
+                impl.remove(old_state, old_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
             except (ValueError, KeyError, IndexError):
                 pass
         if child is not None:
-            new_state = instance_state(child)
-            new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
+            new_state,  new_dict = instance_state(child), instance_dict(child)
+            new_state.get_impl(self.key).append(new_state, new_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
         return child
 
     def append(self, state, child, initiator):
-        child_state = instance_state(child)
-        child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
+        child_state, child_dict = instance_state(child), instance_dict(child)
+        child_state.get_impl(self.key).append(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
         return child
 
     def remove(self, state, child, initiator):
         if child is not None:
-            child_state = instance_state(child)
-            child_state.get_impl(self.key).remove(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
-
-
-class InstanceState(object):
-    """tracks state information at the instance level."""
-
-    session_id = None
-    key = None
-    runid = None
-    expired_attributes = EMPTY_SET
-    load_options = EMPTY_SET
-    load_path = ()
-    insert_order = None
-    
-    def __init__(self, obj, manager):
-        self.class_ = obj.__class__
-        self.manager = manager
-        self.obj = weakref.ref(obj, self._cleanup)
-        self.dict = obj.__dict__
-        self.modified = False
-        self.callables = {}
-        self.expired = False
-        self.committed_state = {}
-        self.pending = {}
-        self.parents = {}
-        
-    def detach(self):
-        if self.session_id:
-            del self.session_id
-
-    def dispose(self):
-        if self.session_id:
-            del self.session_id
-        del self.obj
-        del self.dict
-    
-    def _cleanup(self, ref):
-        self.dispose()
-    
-    def obj(self):
-        return None
-    
-    @util.memoized_property
-    def dict(self):
-        # return a blank dict
-        # if none is available, so that asynchronous gc
-        # doesn't blow up expiration operations in progress
-        # (usually expire_attributes)
-        return {}
-    
-    @property
-    def sort_key(self):
-        return self.key and self.key[1] or (self.insert_order, )
-
-    def check_modified(self):
-        if self.modified:
-            return True
-        else:
-            for key in self.manager.mutable_attributes:
-                if self.manager[key].impl.check_mutable_modified(self):
-                    return True
-            else:
-                return False
-
-    def initialize_instance(*mixed, **kwargs):
-        self, instance, args = mixed[0], mixed[1], mixed[2:]
-        manager = self.manager
-
-        for fn in manager.events.on_init:
-            fn(self, instance, args, kwargs)
-        try:
-            return manager.events.original_init(*mixed[1:], **kwargs)
-        except:
-            for fn in manager.events.on_init_failure:
-                fn(self, instance, args, kwargs)
-            raise
-
-    def get_history(self, key, **kwargs):
-        return self.manager.get_impl(key).get_history(self, **kwargs)
-
-    def get_impl(self, key):
-        return self.manager.get_impl(key)
-
-    def get_pending(self, key):
-        if key not in self.pending:
-            self.pending[key] = PendingCollection()
-        return self.pending[key]
-
-    def value_as_iterable(self, key, passive=PASSIVE_OFF):
-        """return an InstanceState attribute as a list,
-        regardless of it being a scalar or collection-based
-        attribute.
-
-        returns None if passive is not PASSIVE_OFF and the getter returns
-        PASSIVE_NORESULT.
-        """
-
-        impl = self.get_impl(key)
-        x = impl.get(self, passive=passive)
-        if x is PASSIVE_NORESULT:
-
-            return None
-        elif hasattr(impl, 'get_collection'):
-            return impl.get_collection(self, x, passive=passive)
-        elif isinstance(x, list):
-            return x
-        else:
-            return [x]
-
-    def _run_on_load(self, instance=None):
-        if instance is None:
-            instance = self.obj()
-        self.manager.events.run('on_load', instance)
-
-    def __getstate__(self):
-        return {'key': self.key,
-                'committed_state': self.committed_state,
-                'pending': self.pending,
-                'parents': self.parents,
-                'modified': self.modified,
-                'expired':self.expired,
-                'load_options':self.load_options,
-                'load_path':interfaces.serialize_path(self.load_path),
-                'instance': self.obj(),
-                'expired_attributes':self.expired_attributes,
-                'callables': self.callables}
-
-    def __setstate__(self, state):
-        self.committed_state = state['committed_state']
-        self.parents = state['parents']
-        self.key = state['key']
-        self.session_id = None
-        self.pending = state['pending']
-        self.modified = state['modified']
-        self.obj = weakref.ref(state['instance'])
-        self.load_options = state['load_options'] or EMPTY_SET
-        self.load_path = interfaces.deserialize_path(state['load_path'])
-        self.class_ = self.obj().__class__
-        self.manager = manager_of_class(self.class_)
-        self.dict = self.obj().__dict__
-        self.callables = state['callables']
-        self.runid = None
-        self.expired = state['expired']
-        self.expired_attributes = state['expired_attributes']
-
-    def initialize(self, key):
-        self.manager.get_impl(key).initialize(self)
-
-    def set_callable(self, key, callable_):
-        self.dict.pop(key, None)
-        self.callables[key] = callable_
-
-    def __call__(self):
-        """__call__ allows the InstanceState to act as a deferred
-        callable for loading expired attributes, which is also
-        serializable (picklable).
-
-        """
-        unmodified = self.unmodified
-        class_manager = self.manager
-        class_manager.deferred_scalar_loader(self, [
-            attr.impl.key for attr in class_manager.attributes if
-                attr.impl.accepts_scalar_loader and
-                attr.impl.key in self.expired_attributes and
-                attr.impl.key in unmodified
-            ])
-        for k in self.expired_attributes:
-            self.callables.pop(k, None)
-        del self.expired_attributes
-        return ATTR_WAS_SET
-
-    @property
-    def unmodified(self):
-        """a set of keys which have no uncommitted changes"""
-
-        return set(
-            key for key in self.manager.iterkeys()
-            if (key not in self.committed_state or
-                (key in self.manager.mutable_attributes and
-                 not self.manager[key].impl.check_mutable_modified(self))))
-
-    @property
-    def unloaded(self):
-        """a set of keys which do not have a loaded value.
-
-        This includes expired attributes and any other attribute that
-        was never populated or modified.
-
-        """
-        return set(
-            key for key in self.manager.iterkeys()
-            if key not in self.committed_state and key not in self.dict)
-
-    def expire_attributes(self, attribute_names):
-        self.expired_attributes = set(self.expired_attributes)
-
-        if attribute_names is None:
-            attribute_names = self.manager.keys()
-            self.expired = True
-            self.modified = False
-            filter_deferred = True
-        else:
-            filter_deferred = False
-        for key in attribute_names:
-            impl = self.manager[key].impl
-            if not filter_deferred or \
-                not impl.dont_expire_missing or \
-                key in self.dict:
-                self.expired_attributes.add(key)
-                if impl.accepts_scalar_loader:
-                    self.callables[key] = self
-            self.dict.pop(key, None)
-            self.pending.pop(key, None)
-            self.committed_state.pop(key, None)
-
-    def reset(self, key):
-        """remove the given attribute and any callables associated with it."""
-
-        self.dict.pop(key, None)
-        self.callables.pop(key, None)
-
-    def modified_event(self, attr, should_copy, previous, passive=PASSIVE_OFF):
-        needs_committed = attr.key not in self.committed_state
-
-        if needs_committed:
-            if previous is NEVER_SET:
-                if passive:
-                    if attr.key in self.dict:
-                        previous = self.dict[attr.key]
-                else:
-                    previous = attr.get(self)
-
-            if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
-                previous = attr.copy(previous)
-
-            if needs_committed:
-                self.committed_state[attr.key] = previous
-
-        self.modified = True
-
-    def commit(self, keys):
-        """Commit attributes.
-
-        This is used by a partial-attribute load operation to mark committed
-        those attributes which were refreshed from the database.
-
-        Attributes marked as "expired" can potentially remain "expired" after
-        this step if a value was not populated in state.dict.
-
-        """
-        class_manager = self.manager
-        for key in keys:
-            if key in self.dict and key in class_manager.mutable_attributes:
-                class_manager[key].impl.commit_to_state(self, self.committed_state)
-            else:
-                self.committed_state.pop(key, None)
-
-        self.expired = False
-        # unexpire attributes which have loaded
-        for key in self.expired_attributes.intersection(keys):
-            if key in self.dict:
-                self.expired_attributes.remove(key)
-                self.callables.pop(key, None)
-
-    def commit_all(self):
-        """commit all attributes unconditionally.
-
-        This is used after a flush() or a full load/refresh
-        to remove all pending state from the instance.
-
-         - all attributes are marked as "committed"
-         - the "strong dirty reference" is removed
-         - the "modified" flag is set to False
-         - any "expired" markers/callables are removed.
-
-        Attributes marked as "expired" can potentially remain "expired" after this step
-        if a value was not populated in state.dict.
-
-        """
-        
-        self.committed_state = {}
-        self.pending = {}
-        
-        # unexpire attributes which have loaded
-        if self.expired_attributes:
-            for key in self.expired_attributes.intersection(self.dict):
-                self.callables.pop(key, None)
-            self.expired_attributes.difference_update(self.dict)
-
-        for key in self.manager.mutable_attributes:
-            if key in self.dict:
-                self.manager[key].impl.commit_to_state(self, self.committed_state)
-
-        self.modified = self.expired = False
-        self._strong_obj = None
+            child_state, child_dict = instance_state(child), instance_dict(child)
+            child_state.get_impl(self.key).remove(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
 
 
 class Events(object):
@@ -1121,6 +845,7 @@ class Events(object):
         self.on_init = ()
         self.on_init_failure = ()
         self.on_load = ()
+        self.on_resurrect = ()
 
     def run(self, event, *args, **kwargs):
         for fn in getattr(self, event):
@@ -1146,7 +871,6 @@ class ClassManager(dict):
     STATE_ATTR = '_sa_instance_state'
 
     event_registry_factory = Events
-    instance_state_factory = InstanceState
     deferred_scalar_loader = None
     
     def __init__(self, class_):
@@ -1170,7 +894,6 @@ class ClassManager(dict):
     
     def _configure_create_arguments(self, 
                             _source=None, 
-                            instance_state_factory=None, 
                             deferred_scalar_loader=None):
         """Accept extra **kw arguments passed to create_manager_for_cls.
         
@@ -1185,11 +908,8 @@ class ClassManager(dict):
         
         """
         if _source:
-            instance_state_factory = _source.instance_state_factory
             deferred_scalar_loader = _source.deferred_scalar_loader
 
-        if instance_state_factory:
-            self.instance_state_factory = instance_state_factory
         if deferred_scalar_loader:
             self.deferred_scalar_loader = deferred_scalar_loader
     
@@ -1222,7 +942,16 @@ class ClassManager(dict):
         if self.new_init:
             self.uninstall_member('__init__')
             self.new_init = None
-
+    
+    def _create_instance_state(self, instance):
+        global state
+        if state is None:
+            from sqlalchemy.orm import state
+        if self.mutable_attributes:
+            return state.MutableAttrInstanceState(instance, self)
+        else:
+            return state.InstanceState(instance, self)
+        
     def manage(self):
         """Mark this instance as the manager for its class."""
         
@@ -1330,11 +1059,11 @@ class ClassManager(dict):
 
     def new_instance(self, state=None):
         instance = self.class_.__new__(self.class_)
-        setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self))
+        setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance))
         return instance
 
     def setup_instance(self, instance, state=None):
-        setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self))
+        setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance))
     
     def teardown_instance(self, instance):
         delattr(instance, self.STATE_ATTR)
@@ -1348,13 +1077,10 @@ class ClassManager(dict):
         if hasattr(instance, self.STATE_ATTR):
             return False
         else:
-            state = self.instance_state_factory(instance, self)
+            state = self._create_instance_state(instance)
             setattr(instance, self.STATE_ATTR, state)
             return state
     
-    def state_of(self, instance):
-        return getattr(instance, self.STATE_ATTR)
-        
     def state_getter(self):
         """Return a (instance) -> InstanceState callable.
 
@@ -1365,6 +1091,9 @@ class ClassManager(dict):
 
         return attrgetter(self.STATE_ATTR)
     
+    def dict_getter(self):
+        return attrgetter('__dict__')
+        
     def has_state(self, instance):
         return hasattr(instance, self.STATE_ATTR)
         
@@ -1385,6 +1114,9 @@ class _ClassInstrumentationAdapter(ClassManager):
 
     def __init__(self, class_, override, **kw):
         self._adapted = override
+        self._get_state = self._adapted.state_getter(class_)
+        self._get_dict = self._adapted.dict_getter(class_)
+        
         ClassManager.__init__(self, class_, **kw)
 
     def manage(self):
@@ -1446,36 +1178,27 @@ class _ClassInstrumentationAdapter(ClassManager):
         self._adapted.initialize_instance_dict(self.class_, instance)
         
         if state is None:
-            state = self.instance_state_factory(instance, self)
+            state = self._create_instance_state(instance)
             
         # the given instance is assumed to have no state
         self._adapted.install_state(self.class_, instance, state)
-        state.dict = self._adapted.get_instance_dict(self.class_, instance)
         return state
 
     def teardown_instance(self, instance):
         self._adapted.remove_state(self.class_, instance)
 
-    def state_of(self, instance):
-        if hasattr(self._adapted, 'state_of'):
-            return self._adapted.state_of(self.class_, instance)
-        else:
-            getter = self._adapted.state_getter(self.class_)
-            return getter(instance)
-
     def has_state(self, instance):
-        if hasattr(self._adapted, 'has_state'):
-            return self._adapted.has_state(self.class_, instance)
-        else:
-            try:
-                state = self.state_of(instance)
-                return True
-            except exc.NO_STATE:
-                return False
+        try:
+            state = self._get_state(instance)
+            return True
+        except exc.NO_STATE:
+            return False
 
     def state_getter(self):
-        return self._adapted.state_getter(self.class_)
+        return self._get_state
 
+    def dict_getter(self):
+        return self._get_dict
 
 class History(tuple):
     """A 3-tuple of added, unchanged and deleted values.
@@ -1520,7 +1243,7 @@ class History(tuple):
         original = state.committed_state.get(attribute.key, NEVER_SET)
 
         if hasattr(attribute, 'get_collection'):
-            current = attribute.get_collection(state, current)
+            current = attribute.get_collection(state, state.dict, current)
             if original is NO_VALUE:
                 return cls(list(current), (), ())
             elif original is NEVER_SET:
@@ -1557,30 +1280,8 @@ class History(tuple):
 
 HISTORY_BLANK = History(None, None, None)
 
-class PendingCollection(object):
-    """A writable placeholder for an unloaded collection.
-
-    Stores items appended to and removed from a collection that has not yet
-    been loaded. When the collection is loaded, the changes stored in
-    PendingCollection are applied to it to produce the final result.
-
-    """
-    def __init__(self):
-        self.deleted_items = util.IdentitySet()
-        self.added_items = util.OrderedIdentitySet()
-
-    def append(self, value):
-        if value in self.deleted_items:
-            self.deleted_items.remove(value)
-        self.added_items.add(value)
-
-    def remove(self, value):
-        if value in self.added_items:
-            self.added_items.remove(value)
-        self.deleted_items.add(value)
-
 def _conditional_instance_state(obj):
-    if not isinstance(obj, InstanceState):
+    if not isinstance(obj, state.InstanceState):
         obj = instance_state(obj)
     return obj
         
@@ -1690,15 +1391,16 @@ def init_collection(obj, key):
     this usage is deprecated.
     
     """
-
-    return init_state_collection(_conditional_instance_state(obj), key)
+    state = _conditional_instance_state(obj)
+    dict_ = state.dict
+    return init_state_collection(state, dict_, key)
     
-def init_state_collection(state, key):
+def init_state_collection(state, dict_, key):
     """Initialize a collection attribute and return the collection adapter."""
     
     attr = state.get_impl(key)
-    user_data = attr.initialize(state)
-    return attr.get_collection(state, user_data)
+    user_data = attr.initialize(state, dict_)
+    return attr.get_collection(state, dict_, user_data)
 
 def set_committed_value(instance, key, value):
     """Set the value of an attribute with no history events.
@@ -1715,8 +1417,8 @@ def set_committed_value(instance, key, value):
     as though it were part of its original loaded state.
     
     """
-    state = instance_state(instance)
-    state.get_impl(key).set_committed_value(instance, key, value)
+    state, dict_ = instance_state(instance), instance_dict(instance)
+    state.get_impl(key).set_committed_value(state, dict_, key, value)
     
 def set_attribute(instance, key, value):
     """Set the value of an attribute, firing history events.
@@ -1728,8 +1430,8 @@ def set_attribute(instance, key, value):
     by SQLAlchemy.
     
     """
-    state = instance_state(instance)
-    state.get_impl(key).set(state, value, None)
+    state, dict_ = instance_state(instance), instance_dict(instance)
+    state.get_impl(key).set(state, dict_, value, None)
 
 def get_attribute(instance, key):
     """Get the value of an attribute, firing any callables required.
@@ -1741,8 +1443,8 @@ def get_attribute(instance, key):
     by SQLAlchemy.
     
     """
-    state = instance_state(instance)
-    return state.get_impl(key).get(state)
+    state, dict_ = instance_state(instance), instance_dict(instance)
+    return state.get_impl(key).get(state, dict_)
 
 def del_attribute(instance, key):
     """Delete the value of an attribute, firing history events.
@@ -1754,8 +1456,8 @@ def del_attribute(instance, key):
     by SQLAlchemy.
     
     """
-    state = instance_state(instance)
-    state.get_impl(key).delete(state)
+    state, dict_ = instance_state(instance), instance_dict(instance)
+    state.get_impl(key).delete(state, dict_)
 
 def is_instrumented(instance, key):
     """Return True if the given attribute on the given instance is instrumented
@@ -1772,6 +1474,7 @@ class InstrumentationRegistry(object):
 
     _manager_finders = weakref.WeakKeyDictionary()
     _state_finders = util.WeakIdentityMapping()
+    _dict_finders = util.WeakIdentityMapping()
     _extended = False
 
     def create_manager_for_cls(self, class_, **kw):
@@ -1806,6 +1509,7 @@ class InstrumentationRegistry(object):
         manager.factory = factory
         self._manager_finders[class_] = manager.manager_getter()
         self._state_finders[class_] = manager.state_getter()
+        self._dict_finders[class_] = manager.dict_getter()
         return manager
 
     def _collect_management_factories_for(self, cls):
@@ -1845,6 +1549,7 @@ class InstrumentationRegistry(object):
             return finder(cls)
 
     def state_of(self, instance):
+        # this is only called when alternate instrumentation has been established
         if instance is None:
             raise AttributeError("None has no persistent state.")
         try:
@@ -1852,21 +1557,15 @@ class InstrumentationRegistry(object):
         except KeyError:
             raise AttributeError("%r is not instrumented" % instance.__class__)
 
-    def state_or_default(self, instance, default=None):
+    def dict_of(self, instance):
+        # this is only called when alternate instrumentation has been established
         if instance is None:
-            return default
+            raise AttributeError("None has no persistent state.")
         try:
-            finder = self._state_finders[instance.__class__]
+            return self._dict_finders[instance.__class__](instance)
         except KeyError:
-            return default
-        else:
-            try:
-                return finder(instance)
-            except exc.NO_STATE:
-                return default
-            except:
-                raise
-
+            raise AttributeError("%r is not instrumented" % instance.__class__)
+        
     def unregister(self, class_):
         if class_ in self._manager_finders:
             manager = self.manager_of_class(class_)
@@ -1874,6 +1573,7 @@ class InstrumentationRegistry(object):
             manager.dispose()
             del self._manager_finders[class_]
             del self._state_finders[class_]
+            del self._dict_finders[class_]
 
 instrumentation_registry = InstrumentationRegistry()
 
@@ -1887,12 +1587,14 @@ def _install_lookup_strategy(implementation):
     and unit tests specific to this behavior.
     
     """
-    global instance_state
+    global instance_state, instance_dict
     if implementation is util.symbol('native'):
         instance_state = attrgetter(ClassManager.STATE_ATTR)
+        instance_dict = attrgetter("__dict__")
     else:
         instance_state = instrumentation_registry.state_of
-    
+        instance_dict = instrumentation_registry.dict_of
+        
 manager_of_class = instrumentation_registry.manager_of_class
 _create_manager_for_cls = instrumentation_registry.create_manager_for_cls
 _install_lookup_strategy(util.symbol('native'))
index 5638a7e4a517ea3214856055f1eeed715d27a34c..4ca4c5719eb514c1e357b89effd2ca86b9208a82 100644 (file)
@@ -472,6 +472,7 @@ class CollectionAdapter(object):
     """
     def __init__(self, attr, owner_state, data):
         self.attr = attr
+        # TODO: figure out what this being a weakref buys us
         self._data = weakref.ref(data)
         self.owner_state = owner_state
         self.link_to_self(data)
@@ -578,7 +579,7 @@ class CollectionAdapter(object):
 
         """
         if initiator is not False and item is not None:
-            return self.attr.fire_append_event(self.owner_state, item, initiator)
+            return self.attr.fire_append_event(self.owner_state, self.owner_state.dict, item, initiator)
         else:
             return item
 
@@ -591,7 +592,7 @@ class CollectionAdapter(object):
 
         """
         if initiator is not False and item is not None:
-            self.attr.fire_remove_event(self.owner_state, item, initiator)
+            self.attr.fire_remove_event(self.owner_state, self.owner_state.dict, item, initiator)
 
     def fire_pre_remove_event(self, initiator=None):
         """Notify that an entity is about to be removed from the collection.
@@ -600,7 +601,7 @@ class CollectionAdapter(object):
         fire_remove_event().
 
         """
-        self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator)
+        self.attr.fire_pre_remove_event(self.owner_state, self.owner_state.dict, initiator=initiator)
 
     def __getstate__(self):
         return {'key': self.attr.key,
index a80727b7f2245c80f9c3fb9d37b5fd058a76e5ed..151c557d712420bef0b3e66e837e76bf43147a5d 100644 (file)
@@ -64,17 +64,21 @@ class DependencyProcessor(object):
     def register_dependencies(self, uowcommit):
         """Tell a ``UOWTransaction`` what mappers are dependent on
         which, with regards to the two or three mappers handled by
-        this ``PropertyLoader``.
+        this ``DependencyProcessor``.
 
-        Also register itself as a *processor* for one of its mappers,
-        which will be executed after that mapper's objects have been
-        saved or before they've been deleted.  The process operation
-        manages attributes and dependent operations upon the objects
-        of one of the involved mappers.
         """
 
         raise NotImplementedError()
 
+    def register_processors(self, uowcommit):
+        """Tell a ``UOWTransaction`` about this object as a processor,
+        which will be executed after that mapper's objects have been
+        saved or before they've been deleted.  The process operation
+        manages attributes and dependent operations between two mappers.
+        
+        """
+        raise NotImplementedError()
+        
     def whose_dependent_on_who(self, state1, state2):
         """Given an object pair assuming `obj2` is a child of `obj1`,
         return a tuple with the dependent object second, or None if
@@ -181,9 +185,13 @@ class OneToManyDP(DependencyProcessor):
         if self.post_update:
             uowcommit.register_dependency(self.mapper, self.dependency_marker)
             uowcommit.register_dependency(self.parent, self.dependency_marker)
-            uowcommit.register_processor(self.dependency_marker, self, self.parent)
         else:
             uowcommit.register_dependency(self.parent, self.mapper)
+
+    def register_processors(self, uowcommit):
+        if self.post_update:
+            uowcommit.register_processor(self.dependency_marker, self, self.parent)
+        else:
             uowcommit.register_processor(self.parent, self, self.parent)
 
     def process_dependencies(self, task, deplist, uowcommit, delete = False):
@@ -285,6 +293,9 @@ class DetectKeySwitch(DependencyProcessor):
     no_dependencies = True
 
     def register_dependencies(self, uowcommit):
+        pass
+
+    def register_processors(self, uowcommit):
         uowcommit.register_processor(self.parent, self, self.mapper)
 
     def preprocess_dependencies(self, task, deplist, uowcommit, delete=False):
@@ -330,12 +341,15 @@ class ManyToOneDP(DependencyProcessor):
         if self.post_update:
             uowcommit.register_dependency(self.mapper, self.dependency_marker)
             uowcommit.register_dependency(self.parent, self.dependency_marker)
-            uowcommit.register_processor(self.dependency_marker, self, self.parent)
         else:
             uowcommit.register_dependency(self.mapper, self.parent)
+    
+    def register_processors(self, uowcommit):
+        if self.post_update:
+            uowcommit.register_processor(self.dependency_marker, self, self.parent)
+        else:
             uowcommit.register_processor(self.mapper, self, self.parent)
 
-
     def process_dependencies(self, task, deplist, uowcommit, delete=False):
         if delete:
             if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all':
@@ -408,8 +422,10 @@ class ManyToManyDP(DependencyProcessor):
 
         uowcommit.register_dependency(self.parent, self.dependency_marker)
         uowcommit.register_dependency(self.mapper, self.dependency_marker)
-        uowcommit.register_processor(self.dependency_marker, self, self.parent)
 
+    def register_processors(self, uowcommit):
+        uowcommit.register_processor(self.dependency_marker, self, self.parent)
+        
     def process_dependencies(self, task, deplist, uowcommit, delete = False):
         connection = uowcommit.transaction.connection(self.mapper)
         secondary_delete = []
@@ -527,6 +543,9 @@ class MapperStub(object):
     def _register_dependencies(self, uowcommit):
         pass
 
+    def _register_procesors(self, uowcommit):
+        pass
+
     def _save_obj(self, *args, **kwargs):
         pass
 
index 3d31a686a2e5f75f3cb039682804b414e92d6007..70243291dc3e279bdae383bbe02b7ae1b157f213 100644 (file)
@@ -55,21 +55,21 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         else:
             self.query_class = mixin_user_query(query_class)
 
-    def get(self, state, passive=False):
+    def get(self, state, dict_, passive=False):
         if passive:
             return self._get_collection_history(state, passive=True).added_items
         else:
             return self.query_class(self, state)
 
-    def get_collection(self, state, user_data=None, passive=True):
+    def get_collection(self, state, dict_, user_data=None, passive=True):
         if passive:
             return self._get_collection_history(state, passive=passive).added_items
         else:
             history = self._get_collection_history(state, passive=passive)
             return history.added_items + history.unchanged_items
 
-    def fire_append_event(self, state, value, initiator):
-        collection_history = self._modified_event(state)
+    def fire_append_event(self, state, dict_, value, initiator):
+        collection_history = self._modified_event(state, dict_)
         collection_history.added_items.append(value)
 
         for ext in self.extensions:
@@ -78,8 +78,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         if self.trackparent and value is not None:
             self.sethasparent(attributes.instance_state(value), True)
 
-    def fire_remove_event(self, state, value, initiator):
-        collection_history = self._modified_event(state)
+    def fire_remove_event(self, state, dict_, value, initiator):
+        collection_history = self._modified_event(state, dict_)
         collection_history.deleted_items.append(value)
 
         if self.trackparent and value is not None:
@@ -88,31 +88,31 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
-    def _modified_event(self, state):
+    def _modified_event(self, state, dict_):
 
         if self.key not in state.committed_state:
             state.committed_state[self.key] = CollectionHistory(self, state)
 
-        state.modified_event(self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE)
+        state.modified_event(dict_, self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE)
 
         # this is a hack to allow the _base.ComparableEntity fixture
         # to work
-        state.dict[self.key] = True
+        dict_[self.key] = True
         return state.committed_state[self.key]
 
-    def set(self, state, value, initiator):
+    def set(self, state, dict_, value, initiator):
         if initiator is self:
             return
 
-        self._set_iterable(state, value)
+        self._set_iterable(state, dict_, value)
 
-    def _set_iterable(self, state, iterable, adapter=None):
+    def _set_iterable(self, state, dict_, iterable, adapter=None):
 
-        collection_history = self._modified_event(state)
+        collection_history = self._modified_event(state, dict_)
         new_values = list(iterable)
 
         if _state_has_identity(state):
-            old_collection = list(self.get(state))
+            old_collection = list(self.get(state, dict_))
         else:
             old_collection = []
 
@@ -121,7 +121,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
     def delete(self, *args, **kwargs):
         raise NotImplementedError()
 
-    def get_history(self, state, passive=False):
+    def get_history(self, state, dict_, passive=False):
         c = self._get_collection_history(state, passive)
         return attributes.History(c.added_items, c.unchanged_items, c.deleted_items)
 
@@ -136,13 +136,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         else:
             return c
 
-    def append(self, state, value, initiator, passive=False):
+    def append(self, state, dict_, value, initiator, passive=False):
         if initiator is not self:
-            self.fire_append_event(state, value, initiator)
+            self.fire_append_event(state, dict_, value, initiator)
 
-    def remove(self, state, value, initiator, passive=False):
+    def remove(self, state, dict_, value, initiator, passive=False):
         if initiator is not self:
-            self.fire_remove_event(state, value, initiator)
+            self.fire_remove_event(state, dict_, value, initiator)
 
 class DynCollectionAdapter(object):
     """the dynamic analogue to orm.collections.CollectionAdapter"""
@@ -156,10 +156,10 @@ class DynCollectionAdapter(object):
         return iter(self.data)
 
     def append_with_event(self, item, initiator=None):
-        self.attr.append(self.state, item, initiator)
+        self.attr.append(self.state, self.state.dict, item, initiator)
 
     def remove_with_event(self, item, initiator=None):
-        self.attr.remove(self.state, item, initiator)
+        self.attr.remove(self.state, self.state.dict, item, initiator)
 
     def append_without_event(self, item):
         pass
@@ -240,10 +240,10 @@ class AppenderMixin(object):
         return query
 
     def append(self, item):
-        self.attr.append(attributes.instance_state(self.instance), item, None)
+        self.attr.append(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None)
 
     def remove(self, item):
-        self.attr.remove(attributes.instance_state(self.instance), item, None)
+        self.attr.remove(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None)
 
 
 class AppenderQuery(AppenderMixin, Query):
index 0753ea991f40bf8b9611b7108eeca9fea53bd55f..aa041a5855547bdae65c5d82769557cc684eaf42 100644 (file)
@@ -15,6 +15,9 @@ class IdentityMap(dict):
         self._mutable_attrs = {}
         self.modified = False
         self._wr = weakref.ref(self)
+
+    def replace(self, state):
+        raise NotImplementedError()
         
     def add(self, state):
         raise NotImplementedError()
@@ -102,6 +105,17 @@ class WeakInstanceDict(IdentityMap):
     def contains_state(self, state):
         return dict.get(self, state.key) is state
         
+    def replace(self, state):
+        if dict.__contains__(self, state.key):
+            existing = dict.__getitem__(self, state.key)
+            if existing is not state:
+                self._manage_removed_state(existing)
+            else:
+                return
+                
+        dict.__setitem__(self, state.key, state)
+        self._manage_incoming_state(state)
+                 
     def add(self, state):
         if state.key in self:
             if dict.__getitem__(self, state.key) is not state:
@@ -161,12 +175,24 @@ class StrongInstanceDict(IdentityMap):
     def contains_state(self, state):
         return state.key in self and attributes.instance_state(self[state.key]) is state
     
+    def replace(self, state):
+        if dict.__contains__(self, state.key):
+            existing = dict.__getitem__(self, state.key)
+            existing = attributes.instance_state(existing)
+            if existing is not state:
+                self._manage_removed_state(existing)
+            else:
+                return
+
+        dict.__setitem__(self, state.key, state.obj())
+        self._manage_incoming_state(state)
+        
     def add(self, state):
         dict.__setitem__(self, state.key, state.obj())
         self._manage_incoming_state(state)
     
     def remove(self, state):
-        if dict.pop(self, state.key) is not state:
+        if attributes.instance_state(dict.pop(self, state.key)) is not state:
             raise AssertionError("State %s is not present in this identity map" % state)
         self._manage_removed_state(state)
     
@@ -176,7 +202,7 @@ class StrongInstanceDict(IdentityMap):
             self._manage_removed_state(state)
             
     def remove_key(self, key):
-        state = dict.__getitem__(self, key)
+        state = attributes.instance_state(dict.__getitem__(self, key))
         self.remove(state)
 
     def prune(self):
@@ -190,62 +216,3 @@ class StrongInstanceDict(IdentityMap):
         self.modified = bool(dirty)
         return ref_count - len(self)
         
-class IdentityManagedState(attributes.InstanceState):
-    def _instance_dict(self):
-        return None
-    
-    def modified_event(self, attr, should_copy, previous, passive=False):
-        attributes.InstanceState.modified_event(self, attr, should_copy, previous, passive)
-        
-        instance_dict = self._instance_dict()
-        if instance_dict:
-            instance_dict.modified = True
-    
-    def _is_really_none(self):
-        """do a check modified/resurrect.
-        
-        This would be called in the extremely rare
-        race condition that the weakref returned None but
-        the cleanup handler had not yet established the 
-        __resurrect callable as its replacement.
-        
-        """
-        if self.check_modified():
-            self.obj = self.__resurrect
-            return self.obj()
-        else:
-            return None
-            
-    def _cleanup(self, ref):
-        """weakref callback.
-        
-        This method may be called by an asynchronous
-        gc.
-        
-        If the state shows pending changes, the weakref
-        is replaced by the __resurrect callable which will
-        re-establish an object reference on next access,
-        else removes this InstanceState from the owning
-        identity map, if any.
-        
-        """
-        if self.check_modified():
-            self.obj = self.__resurrect
-        else:
-            instance_dict = self._instance_dict()
-            if instance_dict:
-                instance_dict.remove(self)
-            self.dispose()
-            
-    def __resurrect(self):
-        """A substitute for the obj() weakref function which resurrects."""
-        
-        # store strong ref'ed version of the object; will revert
-        # to weakref when changes are persisted
-        obj = self.manager.new_instance(state=self)
-        self.obj = weakref.ref(obj, self._cleanup)
-        self._strong_obj = obj
-        obj.__dict__.update(self.dict)
-        self.dict = obj.__dict__
-        self._run_on_load(obj)
-        return obj
index d36f51194e09aec35cd6bd08490998fca4e06f66..0ac771305833137686b1ac126a43cc4023c4a445 100644 (file)
@@ -359,7 +359,7 @@ class MapperProperty(object):
 
         Callables are of the following form::
 
-            def new_execute(state, row, **flags):
+            def new_execute(state, dict_, row, **flags):
                 # process incoming instance state and given row.  the instance is
                 # "new" and was just created upon receipt of this row.
                 # flags is a dictionary containing at least the following
@@ -368,7 +368,7 @@ class MapperProperty(object):
                 #           result of reading this row
                 #   instancekey - identity key of the instance
 
-            def existing_execute(state, row, **flags):
+            def existing_execute(state, dict_, row, **flags):
                 # process incoming instance state and given row.  the instance is
                 # "existing" and was created based on a previous row.
 
@@ -427,13 +427,23 @@ class MapperProperty(object):
     def register_dependencies(self, *args, **kwargs):
         """Called by the ``Mapper`` in response to the UnitOfWork
         calling the ``Mapper``'s register_dependencies operation.
-        Should register with the UnitOfWork all inter-mapper
-        dependencies as well as dependency processors (see UOW docs
-        for more details).
+        Establishes a topological dependency between two mappers
+        which will affect the order in which mappers persist data.
+        
         """
 
         pass
 
+    def register_processors(self, *args, **kwargs):
+        """Called by the ``Mapper`` in response to the UnitOfWork
+        calling the ``Mapper``'s register_processors operation.
+        Establishes a processor object between two mappers which
+        will link data and state between parent/child objects.
+        
+        """
+
+        pass
+        
     def is_primary(self):
         """Return True if this ``MapperProperty``'s mapper is the
         primary mapper for its class.
@@ -939,3 +949,7 @@ class InstrumentationManager(object):
 
     def state_getter(self, class_):
         return lambda instance: getattr(instance, '_default_state')
+
+    def dict_getter(self, class_):
+        return lambda inst: self.get_instance_dict(class_, inst)
+        
\ No newline at end of file
index 8af6153d6bf463b65b4a9c2bbf9472bf27027449..87c4c8100fa6176ba591fd2f6a3f9103e8d3e299 100644 (file)
@@ -23,7 +23,6 @@ deque = __import__('collections').deque
 from sqlalchemy import sql, util, log, exc as sa_exc
 from sqlalchemy.sql import expression, visitors, operators, util as sqlutil
 from sqlalchemy.orm import attributes, exc, sync
-from sqlalchemy.orm.identity import IdentityManagedState
 from sqlalchemy.orm.interfaces import (
     MapperProperty, EXT_CONTINUE, PropComparator
     )
@@ -255,7 +254,8 @@ class Mapper(object):
 
             for mapper in self.iterate_to_root():
                 util.reset_memoized(mapper, '_equivalent_columns')
-
+                util.reset_memoized(mapper, '_sorted_tables')
+                
             if self.order_by is False and not self.concrete and self.inherits.order_by is not False:
                 self.order_by = self.inherits.order_by
 
@@ -357,7 +357,6 @@ class Mapper(object):
 
         if manager is None:
             manager = attributes.register_class(self.class_, 
-                instance_state_factory = IdentityManagedState,
                 deferred_scalar_loader = _load_scalar_attributes
             )
 
@@ -372,6 +371,8 @@ class Mapper(object):
         event_registry = manager.events
         event_registry.add_listener('on_init', _event_on_init)
         event_registry.add_listener('on_init_failure', _event_on_init_failure)
+        event_registry.add_listener('on_resurrect', _event_on_resurrect)
+        
         for key, method in util.iterate_attributes(self.class_):
             if isinstance(method, types.FunctionType):
                 if hasattr(method, '__sa_reconstructor__'):
@@ -1173,6 +1174,19 @@ class Mapper(object):
 
     # persistence
 
+    @util.memoized_property
+    def _sorted_tables(self):
+        table_to_mapper = {}
+        for mapper in self.base_mapper.polymorphic_iterator():
+            for t in mapper.tables:
+                table_to_mapper[t] = mapper
+        
+        sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys())
+        ret = util.OrderedDict()
+        for t in sorted_:
+            ret[t] = table_to_mapper[t]
+        return ret
+
     def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False):
         """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
 
@@ -1198,16 +1212,37 @@ class Mapper(object):
 
         # if session has a connection callable,
         # organize individual states with the connection to use for insert/update
+        tups = []
         if 'connection_callable' in uowtransaction.mapper_flush_opts:
             connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
-            tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in _sort_states(states)]
+            for state in _sort_states(states):
+                m = _state_mapper(state)
+                tups.append(
+                    (
+                        state, 
+                        m, 
+                        connection_callable(self, state.obj()), 
+                        _state_has_identity(state), 
+                        state.key or m._identity_key_from_state(state)
+                    )
+                )
         else:
             connection = uowtransaction.transaction.connection(self)
-            tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in _sort_states(states)]
+            for state in _sort_states(states):
+                m = _state_mapper(state)
+                tups.append(
+                    (
+                        state, 
+                        m, 
+                        connection,
+                        _state_has_identity(state), 
+                        state.key or m._identity_key_from_state(state)
+                    )
+                )
 
         if not postupdate:
             # call before_XXX extensions
-            for state, mapper, connection, has_identity in tups:
+            for state, mapper, connection, has_identity, instance_key in tups:
                 if not has_identity:
                     if 'before_insert' in mapper.extension:
                         mapper.extension.before_insert(mapper, connection, state.obj())
@@ -1215,39 +1250,44 @@ class Mapper(object):
                     if 'before_update' in mapper.extension:
                         mapper.extension.before_update(mapper, connection, state.obj())
 
-        for state, mapper, connection, has_identity in tups:
-            # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
-            # and another instance with the same identity key already exists as persistent.  convert to an
-            # UPDATE if so.
-            instance_key = mapper._identity_key_from_state(state)
-            if not postupdate and not has_identity and instance_key in uowtransaction.session.identity_map:
-                instance = uowtransaction.session.identity_map[instance_key]
-                existing = attributes.instance_state(instance)
-                if not uowtransaction.is_deleted(existing):
-                    raise exc.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing)))
-                if self._should_log_debug:
-                    self._log_debug("detected row switch for identity %s.  will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing)))
-                uowtransaction.set_row_switch(existing)
-
-        table_to_mapper = {}
-        for mapper in self.base_mapper.polymorphic_iterator():
-            for t in mapper.tables:
-                table_to_mapper[t] = mapper
+        row_switches = set()
+        if not postupdate:
+            for state, mapper, connection, has_identity, instance_key in tups:
+                # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
+                # and another instance with the same identity key already exists as persistent.  convert to an
+                # UPDATE if so.
+                if not has_identity and instance_key in uowtransaction.session.identity_map:
+                    instance = uowtransaction.session.identity_map[instance_key]
+                    existing = attributes.instance_state(instance)
+                    if not uowtransaction.is_deleted(existing):
+                        raise exc.FlushError(
+                            "New instance %s with identity key %s conflicts with persistent instance %s" % 
+                            (state_str(state), instance_key, state_str(existing)))
+                    if self._should_log_debug:
+                        self._log_debug(
+                            "detected row switch for identity %s.  will update %s, remove %s from transaction", 
+                            instance_key, state_str(state), state_str(existing))
+                            
+                    # remove the "delete" flag from the existing element
+                    uowtransaction.set_row_switch(existing)
+                    row_switches.add(state)
+        
+        table_to_mapper = self._sorted_tables
 
-        for table in sqlutil.sort_tables(table_to_mapper.iterkeys()):
+        for table in table_to_mapper.iterkeys():
             insert = []
             update = []
 
-            for state, mapper, connection, has_identity in tups:
+            for state, mapper, connection, has_identity, instance_key in tups:
                 if table not in mapper._pks_by_table:
                     continue
+                    
                 pks = mapper._pks_by_table[table]
-                instance_key = mapper._identity_key_from_state(state)
-
+                
                 if self._should_log_debug:
                     self._log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key)))
 
-                isinsert = not instance_key in uowtransaction.session.identity_map and not postupdate and not has_identity
+                isinsert = not has_identity and not postupdate and state not in row_switches
                 
                 params = {}
                 value_params = {}
@@ -1364,7 +1404,7 @@ class Mapper(object):
                             sync.populate(state, m, state, m, m._inherits_equated_pairs)
 
         if not postupdate:
-            for state, mapper, connection, has_identity in tups:
+            for state, mapper, connection, has_identity, instance_key in tups:
 
                 # expire readonly attributes
                 readonly = state.unmodified.intersection(
@@ -1434,12 +1474,9 @@ class Mapper(object):
             if 'before_delete' in mapper.extension:
                 mapper.extension.before_delete(mapper, connection, state.obj())
 
-        table_to_mapper = {}
-        for mapper in self.base_mapper.polymorphic_iterator():
-            for t in mapper.tables:
-                table_to_mapper[t] = mapper
+        table_to_mapper = self._sorted_tables
 
-        for table in reversed(sqlutil.sort_tables(table_to_mapper.iterkeys())):
+        for table in reversed(table_to_mapper.keys()):
             delete = {}
             for state, mapper, connection in tups:
                 if table not in mapper._pks_by_table:
@@ -1485,6 +1522,10 @@ class Mapper(object):
         for dep in self._props.values() + self._dependency_processors:
             dep.register_dependencies(uowcommit)
 
+    def _register_processors(self, uowcommit):
+        for dep in self._props.values() + self._dependency_processors:
+            dep.register_processors(uowcommit)
+
     # result set conversion
 
     def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None, polymorphic_discriminator=None):
@@ -1514,7 +1555,7 @@ class Mapper(object):
         new_populators = []
         existing_populators = []
 
-        def populate_state(state, row, isnew, only_load_props, **flags):
+        def populate_state(state, dict_, row, isnew, only_load_props, **flags):
             if isnew:
                 if context.options:
                     state.load_options = context.options
@@ -1533,7 +1574,7 @@ class Mapper(object):
                 populators = [p for p in populators if p[0] in only_load_props]
 
             for key, populator in populators:
-                populator(state, row, isnew=isnew, **flags)
+                populator(state, dict_, row, isnew=isnew, **flags)
 
         session_identity_map = context.session.identity_map
 
@@ -1573,9 +1614,11 @@ class Mapper(object):
             if identitykey in session_identity_map:
                 instance = session_identity_map[identitykey]
                 state = attributes.instance_state(instance)
+                dict_ = attributes.instance_dict(instance)
 
                 if self._should_log_debug:
-                    self._log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), identitykey))
+                    self._log_debug("_instance(): using existing instance %s identity %s",
+                                        instance_str(instance), identitykey)
 
                 isnew = state.runid != context.runid
                 currentload = not isnew
@@ -1592,12 +1635,13 @@ class Mapper(object):
                 # when eager_defaults is True.
                 state = refresh_state
                 instance = state.obj()
+                dict_ = attributes.instance_dict(instance)
                 isnew = state.runid != context.runid
                 currentload = True
                 loaded_instance = False
             else:
                 if self._should_log_debug:
-                    self._log_debug("_instance(): identity key %s not in session" % str(identitykey))
+                    self._log_debug("_instance(): identity key %s not in session", identitykey)
 
                 if self.allow_null_pks:
                     for x in identitykey[1]:
@@ -1625,8 +1669,10 @@ class Mapper(object):
                     instance = self.class_manager.new_instance()
 
                 if self._should_log_debug:
-                    self._log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey)))
+                    self._log_debug("_instance(): created new instance %s identity %s",
+                                instance_str(instance), identitykey)
 
+                dict_ = attributes.instance_dict(instance)
                 state = attributes.instance_state(instance)
                 state.key = identitykey
 
@@ -1638,12 +1684,12 @@ class Mapper(object):
             if currentload or populate_existing:
                 if isnew:
                     state.runid = context.runid
-                    context.progress.add(state)
+                    context.progress[state] = dict_
 
                 if not populate_instance or \
                         populate_instance(self, context, row, instance, 
                             only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
-                    populate_state(state, row, isnew, only_load_props)
+                    populate_state(state, dict_, row, isnew, only_load_props)
 
             else:
                 # populate attributes on non-loading instances which have been expired
@@ -1652,16 +1698,16 @@ class Mapper(object):
 
                     if state in context.partials:
                         isnew = False
-                        attrs = context.partials[state]
+                        (d_, attrs) = context.partials[state]
                     else:
                         isnew = True
                         attrs = state.unloaded
-                        context.partials[state] = attrs  #<-- allow query.instances to commit the subset of attrs
+                        context.partials[state] = (dict_, attrs)  #<-- allow query.instances to commit the subset of attrs
 
                     if not populate_instance or \
                             populate_instance(self, context, row, instance, 
                                 only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
-                        populate_state(state, row, isnew, attrs, instancekey=identitykey)
+                        populate_state(state, dict_, row, isnew, attrs, instancekey=identitykey)
 
             if loaded_instance:
                 state._run_on_load(instance)
@@ -1759,6 +1805,14 @@ def _event_on_init_failure(state, instance, args, kwargs):
             instrumenting_mapper, instrumenting_mapper.class_,
             state.manager.events.original_init, instance, args, kwargs)
 
+def _event_on_resurrect(state, instance):
+    # re-populate the primary key elements
+    # of the dict based on the mapping.
+    instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
+    for col, val in zip(instrumenting_mapper.primary_key, state.key[1]):
+        instrumenting_mapper._set_state_attr_by_column(state, col, val)
+    
+    
 def _sort_states(states):
     return sorted(states, key=operator.attrgetter('sort_key'))
 
index d0cca2dc1145a0c6d1f60af45787de550beef63f..5605cdcd1e83fbbf1dc31a4f00a11d32603b068a 100644 (file)
@@ -96,13 +96,13 @@ class ColumnProperty(StrategizedProperty):
         return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
 
     def getattr(self, state, column):
-        return state.get_impl(self.key).get(state)
+        return state.get_impl(self.key).get(state, state.dict)
 
     def getcommitted(self, state, column, passive=False):
-        return state.get_impl(self.key).get_committed_value(state, passive=passive)
+        return state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive)
 
     def setattr(self, state, value, column):
-        state.get_impl(self.key).set(state, value, None)
+        state.get_impl(self.key).set(state, state.dict, value, None)
 
     def merge(self, session, source, dest, dont_load, _recursive):
         value = attributes.instance_state(source).value_as_iterable(
@@ -159,7 +159,7 @@ class CompositeProperty(ColumnProperty):
         super(ColumnProperty, self).do_init()
 
     def getattr(self, state, column):
-        obj = state.get_impl(self.key).get(state)
+        obj = state.get_impl(self.key).get(state, state.dict)
         return self.get_col_value(column, obj)
 
     def getcommitted(self, state, column, passive=False):
@@ -168,7 +168,7 @@ class CompositeProperty(ColumnProperty):
 
     def setattr(self, state, value, column):
 
-        obj = state.get_impl(self.key).get(state)
+        obj = state.get_impl(self.key).get(state, state.dict)
         if obj is None:
             obj = self.composite_class(*[None for c in self.columns])
             state.get_impl(self.key).set(state, obj, None)
@@ -635,7 +635,7 @@ class RelationProperty(StrategizedProperty):
                     return
 
         source_state = attributes.instance_state(source)
-        dest_state = attributes.instance_state(dest)
+        dest_state, dest_dict = attributes.instance_state(dest), attributes.instance_dict(dest)
 
         if not "merge" in self.cascade:
             dest_state.expire_attributes([self.key])
@@ -658,7 +658,7 @@ class RelationProperty(StrategizedProperty):
                 for c in dest_list:
                     coll.append_without_event(c)
             else:
-                getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_list)
+                getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_dict, dest_list)
         else:
             current = instances[0]
             if current is not None:
@@ -1119,6 +1119,10 @@ class RelationProperty(StrategizedProperty):
         if not self.viewonly:
             self._dependency_processor.register_dependencies(uowcommit)
 
+    def register_processors(self, uowcommit):
+        if not self.viewonly:
+            self._dependency_processor.register_processors(uowcommit)
+
 PropertyLoader = RelationProperty
 log.class_logger(RelationProperty)
 
index 28ddcc5eaa06a388fab948591ae882783bf71bb1..e3cc3c75697ef84a140ccc326d2c7a47913d1ee5 100644 (file)
@@ -1330,7 +1330,7 @@ class Query(object):
             rowtuple.keys = labels.keys
 
         while True:
-            context.progress = set()
+            context.progress = {}
             context.partials = {}
 
             if self._yield_per:
@@ -1354,13 +1354,13 @@ class Query(object):
                 rows = filter(rows)
 
             if context.refresh_state and self._only_load_props and context.refresh_state in context.progress:
-                context.refresh_state.commit(self._only_load_props)
-                context.progress.remove(context.refresh_state)
+                context.refresh_state.commit(context.refresh_state.dict, self._only_load_props)
+                context.progress.pop(context.refresh_state)
 
             session._finalize_loaded(context.progress)
 
-            for ii, attrs in context.partials.items():
-                ii.commit(attrs)
+            for ii, (dict_, attrs) in context.partials.items():
+                ii.commit(dict_, attrs)
 
             for row in rows:
                 yield row
@@ -1683,14 +1683,14 @@ class Query(object):
                 evaluated_keys = value_evaluators.keys()
 
                 if issubclass(cls, target_cls) and eval_condition(obj):
-                    state = attributes.instance_state(obj)
+                    state, dict_ = attributes.instance_state(obj), attributes.instance_dict(obj)
 
                     # only evaluate unmodified attributes
                     to_evaluate = state.unmodified.intersection(evaluated_keys)
                     for key in to_evaluate:
-                        state.dict[key] = value_evaluators[key](obj)
+                        dict_[key] = value_evaluators[key](obj)
 
-                    state.commit(list(to_evaluate))
+                    state.commit(dict_, list(to_evaluate))
 
                     # expire attributes with pending changes (there was no autoflush, so they are overwritten)
                     state.expire_attributes(set(evaluated_keys).difference(to_evaluate))
index 1e3a750d950fb51864d34f6b4c770fd8178986b4..00a7d55e5ecdc32c72744d7f2bf97cafd1eb6ecb 100644 (file)
@@ -12,7 +12,7 @@ import sqlalchemy.exceptions as sa_exc
 from sqlalchemy import util, sql, engine, log
 from sqlalchemy.sql import util as sql_util, expression
 from sqlalchemy.orm import (
-    SessionExtension, attributes, exc, query, unitofwork, util as mapperutil,
+    SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, state
     )
 from sqlalchemy.orm.util import object_mapper as _object_mapper
 from sqlalchemy.orm.util import class_mapper as _class_mapper
@@ -899,8 +899,8 @@ class Session(object):
             self.flush()
 
     def _finalize_loaded(self, states):
-        for state in states:
-            state.commit_all()
+        for state, dict_ in states.items():
+            state.commit_all(dict_)
 
     def refresh(self, instance, attribute_names=None):
         """Refresh the attributes on the given instance.
@@ -1020,11 +1020,9 @@ class Session(object):
                 # primary key switch
                 self.identity_map.remove(state)
                 state.key = instance_key
-
-            if state.key in self.identity_map and not self.identity_map.contains_state(state):
-                self.identity_map.remove_key(state.key)
-            self.identity_map.add(state)
-            state.commit_all()
+            
+            self.identity_map.replace(state)
+            state.commit_all(state.dict)
 
         # remove from new last, might be the last strong ref
         if state in self._new:
@@ -1213,7 +1211,7 @@ class Session(object):
             prop.merge(self, instance, merged, dont_load, _recursive)
 
         if dont_load:
-            attributes.instance_state(merged).commit_all()  # remove any history
+            attributes.instance_state(merged).commit_all(attributes.instance_dict(merged))  # remove any history
 
         if new_instance:
             merged_state._run_on_load(merged)
@@ -1368,7 +1366,7 @@ class Session(object):
             self.identity_map.modified = False
             return
 
-        flush_context   = UOWTransaction(self)
+        flush_context = UOWTransaction(self)
 
         if self.extensions:
             for ext in self.extensions:
@@ -1489,7 +1487,7 @@ class Session(object):
         return util.IdentitySet(
             [state
              for state in self.identity_map.all_states()
-             if state.check_modified()])
+             if state.modified])
 
     @property
     def dirty(self):
@@ -1528,7 +1526,7 @@ class Session(object):
 
         return util.IdentitySet(self._new.values())
 
-_expire_state = attributes.InstanceState.expire_attributes
+_expire_state = state.InstanceState.expire_attributes
     
 UOWEventHandler = unitofwork.UOWEventHandler
 
@@ -1548,16 +1546,19 @@ def _cascade_unknown_state_iterator(cascade, state, **kwargs):
         yield _state_for_unknown_persistence_instance(o), m
 
 def _state_for_unsaved_instance(instance, create=False):
-    manager = attributes.manager_of_class(instance.__class__)
-    if manager is None:
+    try:
+        state = attributes.instance_state(instance)
+    except AttributeError:
         raise exc.UnmappedInstanceError(instance)
-    if manager.has_state(instance):
-        state = manager.state_of(instance)
+    if state:
         if state.key is not None:
             raise sa_exc.InvalidRequestError(
                 "Instance '%s' is already persistent" %
                 mapperutil.state_str(state))
     elif create:
+        manager = attributes.manager_of_class(instance.__class__)
+        if manager is None:
+            raise exc.UnmappedInstanceError(instance)
         state = manager.setup_instance(instance)
     else:
         raise exc.UnmappedInstanceError(instance)
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
new file mode 100644 (file)
index 0000000..c99dfe7
--- /dev/null
@@ -0,0 +1,429 @@
+from sqlalchemy.util import EMPTY_SET
+import weakref
+from sqlalchemy import util
+from sqlalchemy.orm.attributes import PASSIVE_NORESULT, PASSIVE_OFF, NEVER_SET, NO_VALUE, manager_of_class, ATTR_WAS_SET
+from sqlalchemy.orm import attributes
+from sqlalchemy.orm import interfaces
+
+class InstanceState(object):
+    """tracks state information at the instance level."""
+
+    session_id = None
+    key = None
+    runid = None
+    expired_attributes = EMPTY_SET
+    load_options = EMPTY_SET
+    load_path = ()
+    insert_order = None
+    mutable_dict = None
+    
+    def __init__(self, obj, manager):
+        self.class_ = obj.__class__
+        self.manager = manager
+        self.obj = weakref.ref(obj, self._cleanup)
+        self.modified = False
+        self.callables = {}
+        self.expired = False
+        self.committed_state = {}
+        self.pending = {}
+        self.parents = {}
+        
+    def detach(self):
+        if self.session_id:
+            del self.session_id
+
+    def dispose(self):
+        if self.session_id:
+            del self.session_id
+        del self.obj
+    
+    def _cleanup(self, ref):
+        instance_dict = self._instance_dict()
+        if instance_dict:
+            instance_dict.remove(self)
+        self.dispose()
+    
+    def obj(self):
+        return None
+    
+    @property
+    def dict(self):
+        o = self.obj()
+        if o is not None:
+            return attributes.instance_dict(o)
+        else:
+            return {}
+        
+    @property
+    def sort_key(self):
+        return self.key and self.key[1] or (self.insert_order, )
+
+    def check_modified(self):
+        # TODO: deprecate
+        return self.modified
+
+    def initialize_instance(*mixed, **kwargs):
+        self, instance, args = mixed[0], mixed[1], mixed[2:]
+        manager = self.manager
+
+        for fn in manager.events.on_init:
+            fn(self, instance, args, kwargs)
+            
+        # LESSTHANIDEAL:
+        # adjust for the case where the InstanceState was created before
+        # mapper compilation, and this actually needs to be a MutableAttrInstanceState
+        if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState:
+            self.__class__ = MutableAttrInstanceState
+            self.obj = weakref.ref(self.obj(), self._cleanup)
+            self.mutable_dict = {}
+            
+        try:
+            return manager.events.original_init(*mixed[1:], **kwargs)
+        except:
+            for fn in manager.events.on_init_failure:
+                fn(self, instance, args, kwargs)
+            raise
+
+    def get_history(self, key, **kwargs):
+        return self.manager.get_impl(key).get_history(self, self.dict, **kwargs)
+
+    def get_impl(self, key):
+        return self.manager.get_impl(key)
+
+    def get_pending(self, key):
+        if key not in self.pending:
+            self.pending[key] = PendingCollection()
+        return self.pending[key]
+
+    def value_as_iterable(self, key, passive=PASSIVE_OFF):
+        """return an InstanceState attribute as a list,
+        regardless of it being a scalar or collection-based
+        attribute.
+
+        returns None if passive is not PASSIVE_OFF and the getter returns
+        PASSIVE_NORESULT.
+        """
+
+        impl = self.get_impl(key)
+        dict_ = self.dict
+        x = impl.get(self, dict_, passive=passive)
+        if x is PASSIVE_NORESULT:
+            return None
+        elif hasattr(impl, 'get_collection'):
+            return impl.get_collection(self, dict_, x, passive=passive)
+        elif isinstance(x, list):
+            return x
+        else:
+            return [x]
+
+    def _run_on_load(self, instance):
+        self.manager.events.run('on_load', instance)
+
+    def __getstate__(self):
+        return {'key': self.key,
+                'committed_state': self.committed_state,
+                'pending': self.pending,
+                'parents': self.parents,
+                'modified': self.modified,
+                'expired':self.expired,
+                'load_options':self.load_options,
+                'load_path':interfaces.serialize_path(self.load_path),
+                'instance': self.obj(),
+                'expired_attributes':self.expired_attributes,
+                'callables': self.callables}
+
+    def __setstate__(self, state):
+        self.committed_state = state['committed_state']
+        self.parents = state['parents']
+        self.key = state['key']
+        self.session_id = None
+        self.pending = state['pending']
+        self.modified = state['modified']
+        self.obj = weakref.ref(state['instance'])
+        self.load_options = state['load_options'] or EMPTY_SET
+        self.load_path = interfaces.deserialize_path(state['load_path'])
+        self.class_ = self.obj().__class__
+        self.manager = manager_of_class(self.class_)
+        self.callables = state['callables']
+        self.runid = None
+        self.expired = state['expired']
+        self.expired_attributes = state['expired_attributes']
+
+    def initialize(self, key):
+        self.manager.get_impl(key).initialize(self, self.dict)
+
+    def set_callable(self, key, callable_):
+        self.dict.pop(key, None)
+        self.callables[key] = callable_
+
+    def __call__(self):
+        """__call__ allows the InstanceState to act as a deferred
+        callable for loading expired attributes, which is also
+        serializable (picklable).
+
+        """
+        unmodified = self.unmodified
+        class_manager = self.manager
+        class_manager.deferred_scalar_loader(self, [
+            attr.impl.key for attr in class_manager.attributes if
+                attr.impl.accepts_scalar_loader and
+                attr.impl.key in self.expired_attributes and
+                attr.impl.key in unmodified
+            ])
+        for k in self.expired_attributes:
+            self.callables.pop(k, None)
+        del self.expired_attributes
+        return ATTR_WAS_SET
+
+    @property
+    def unmodified(self):
+        """a set of keys which have no uncommitted changes"""
+        
+        return set(self.manager).difference(self.committed_state)
+
+    @property
+    def unloaded(self):
+        """a set of keys which do not have a loaded value.
+
+        This includes expired attributes and any other attribute that
+        was never populated or modified.
+
+        """
+        return set(
+            key for key in self.manager.iterkeys()
+            if key not in self.committed_state and key not in self.dict)
+
+    def expire_attributes(self, attribute_names):
+        self.expired_attributes = set(self.expired_attributes)
+
+        if attribute_names is None:
+            attribute_names = self.manager.keys()
+            self.expired = True
+            self.modified = False
+            filter_deferred = True
+        else:
+            filter_deferred = False
+        dict_ = self.dict
+        
+        for key in attribute_names:
+            impl = self.manager[key].impl
+            if not filter_deferred or \
+                not impl.dont_expire_missing or \
+                key in dict_:
+                self.expired_attributes.add(key)
+                if impl.accepts_scalar_loader:
+                    self.callables[key] = self
+            dict_.pop(key, None)
+            self.pending.pop(key, None)
+            self.committed_state.pop(key, None)
+            if self.mutable_dict:
+                self.mutable_dict.pop(key, None)
+                
+    def reset(self, key, dict_):
+        """remove the given attribute and any callables associated with it."""
+
+        dict_.pop(key, None)
+        self.callables.pop(key, None)
+
+    def _instance_dict(self):
+        return None
+
+    def _is_really_none(self):
+        return self.obj()
+        
+    def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF):
+        needs_committed = attr.key not in self.committed_state
+
+        if needs_committed:
+            if previous is NEVER_SET:
+                if passive:
+                    if attr.key in dict_:
+                        previous = dict_[attr.key]
+                else:
+                    previous = attr.get(self, dict_)
+
+            if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
+                previous = attr.copy(previous)
+
+            if needs_committed:
+                self.committed_state[attr.key] = previous
+
+        self.modified = True
+        self._strong_obj = self.obj()
+
+        instance_dict = self._instance_dict()
+        if instance_dict:
+            instance_dict.modified = True
+        
+    def commit(self, dict_, keys):
+        """Commit attributes.
+
+        This is used by a partial-attribute load operation to mark committed
+        those attributes which were refreshed from the database.
+
+        Attributes marked as "expired" can potentially remain "expired" after
+        this step if a value was not populated in state.dict.
+
+        """
+        class_manager = self.manager
+        for key in keys:
+            if key in dict_ and key in class_manager.mutable_attributes:
+                class_manager[key].impl.commit_to_state(self, dict_, self.committed_state)
+            else:
+                self.committed_state.pop(key, None)
+
+        self.expired = False
+        # unexpire attributes which have loaded
+        for key in self.expired_attributes.intersection(keys):
+            if key in dict_:
+                self.expired_attributes.remove(key)
+                self.callables.pop(key, None)
+
+    def commit_all(self, dict_):
+        """commit all attributes unconditionally.
+
+        This is used after a flush() or a full load/refresh
+        to remove all pending state from the instance.
+
+         - all attributes are marked as "committed"
+         - the "strong dirty reference" is removed
+         - the "modified" flag is set to False
+         - any "expired" markers/callables are removed.
+
+        Attributes marked as "expired" can potentially remain "expired" after this step
+        if a value was not populated in state.dict.
+
+        """
+        
+        self.committed_state = {}
+        self.pending = {}
+        
+        # unexpire attributes which have loaded
+        if self.expired_attributes:
+            for key in self.expired_attributes.intersection(dict_):
+                self.callables.pop(key, None)
+            self.expired_attributes.difference_update(dict_)
+
+        for key in self.manager.mutable_attributes:
+            if key in dict_:
+                self.manager[key].impl.commit_to_state(self, dict_, self.committed_state)
+
+        self.modified = self.expired = False
+        self._strong_obj = None
+
+class MutableAttrInstanceState(InstanceState):
+    def __init__(self, obj, manager):
+        self.mutable_dict = {}
+        InstanceState.__init__(self, obj, manager)
+        
+    def _get_modified(self, dict_=None):
+        if self.__dict__.get('modified', False):
+            return True
+        else:
+            if dict_ is None:
+                dict_ = self.dict
+            for key in self.manager.mutable_attributes:
+                if self.manager[key].impl.check_mutable_modified(self, dict_):
+                    return True
+            else:
+                return False
+    
+    def _set_modified(self, value):
+        self.__dict__['modified'] = value
+        
+    modified = property(_get_modified, _set_modified)
+    
+    @property
+    def unmodified(self):
+        """a set of keys which have no uncommitted changes"""
+
+        dict_ = self.dict
+        return set(
+            key for key in self.manager.iterkeys()
+            if (key not in self.committed_state or
+                (key in self.manager.mutable_attributes and
+                 not self.manager[key].impl.check_mutable_modified(self, dict_))))
+
+    def _is_really_none(self):
+        """do a check modified/resurrect.
+        
+        This would be called in the extremely rare
+        race condition that the weakref returned None but
+        the cleanup handler had not yet established the 
+        __resurrect callable as its replacement.
+        
+        """
+        if self.modified:
+            self.obj = self.__resurrect
+            return self.obj()
+        else:
+            return None
+
+    def reset(self, key, dict_):
+        self.mutable_dict.pop(key, None)
+        InstanceState.reset(self, key, dict_)
+    
+    def _cleanup(self, ref):
+        """weakref callback.
+        
+        This method may be called by an asynchronous
+        gc.
+        
+        If the state shows pending changes, the weakref
+        is replaced by the __resurrect callable which will
+        re-establish an object reference on next access,
+        else removes this InstanceState from the owning
+        identity map, if any.
+        
+        """
+        if self._get_modified(self.mutable_dict):
+            self.obj = self.__resurrect
+        else:
+            instance_dict = self._instance_dict()
+            if instance_dict:
+                instance_dict.remove(self)
+            self.dispose()
+            
+    def __resurrect(self):
+        """A substitute for the obj() weakref function which resurrects."""
+        
+        # store strong ref'ed version of the object; will revert
+        # to weakref when changes are persisted
+        
+        obj = self.manager.new_instance(state=self)
+        self.obj = weakref.ref(obj, self._cleanup)
+        self._strong_obj = obj
+        obj.__dict__.update(self.mutable_dict)
+
+        # re-establishes identity attributes from the key
+        self.manager.events.run('on_resurrect', self, obj)
+        
+        # TODO: don't really think we should run this here.
+        # resurrect is only meant to preserve the minimal state needed to
+        # do an UPDATE, not to produce a fully usable object
+        self._run_on_load(obj)
+        
+        return obj
+
+class PendingCollection(object):
+    """A writable placeholder for an unloaded collection.
+
+    Stores items appended to and removed from a collection that has not yet
+    been loaded. When the collection is loaded, the changes stored in
+    PendingCollection are applied to it to produce the final result.
+
+    """
+    def __init__(self):
+        self.deleted_items = util.IdentitySet()
+        self.added_items = util.OrderedIdentitySet()
+
+    def append(self, value):
+        if value in self.deleted_items:
+            self.deleted_items.remove(value)
+        self.added_items.add(value)
+
+    def remove(self, value):
+        if value in self.added_items:
+            self.added_items.remove(value)
+        self.deleted_items.add(value)
+
index 1aeb311e1c99e573fcd2711918b0fc9377ba5859..20cbb8f4dcdeb09300203d25cbd6779626a297f2 100644 (file)
@@ -115,8 +115,8 @@ class ColumnLoader(LoaderStrategy):
         if adapter:
             col = adapter.columns[col]
         if col in row:
-            def new_execute(state, row, **flags):
-                state.dict[key] = row[col]
+            def new_execute(state, dict_, row, **flags):
+                dict_[key] = row[col]
                 
             if self._should_log_debug:
                 new_execute = self.debug_callable(new_execute, self.logger,
@@ -125,7 +125,7 @@ class ColumnLoader(LoaderStrategy):
                 )
             return (new_execute, None)
         else:
-            def new_execute(state, row, isnew, **flags):
+            def new_execute(state, dict_, row, isnew, **flags):
                 if isnew:
                     state.expire_attributes([key])
             if self._should_log_debug:
@@ -171,15 +171,15 @@ class CompositeColumnLoader(ColumnLoader):
             columns = [adapter.columns[c] for c in columns]
         for c in columns:
             if c not in row:
-                def new_execute(state, row, isnew, **flags):
+                def new_execute(state, dict_, row, isnew, **flags):
                     if isnew:
                         state.expire_attributes([key])
                 if self._should_log_debug:
                     self.logger.debug("%s deferring load" % self)
                 return (new_execute, None)
         else:
-            def new_execute(state, row, **flags):
-                state.dict[key] = composite_class(*[row[c] for c in columns])
+            def new_execute(state, dict_, row, **flags):
+                dict_[key] = composite_class(*[row[c] for c in columns])
 
             if self._should_log_debug:
                 new_execute = self.debug_callable(new_execute, self.logger,
@@ -202,13 +202,13 @@ class DeferredColumnLoader(LoaderStrategy):
             return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, path, mapper, row, adapter)
 
         elif not self.is_class_level:
-            def new_execute(state, row, **flags):
+            def new_execute(state, dict_, row, **flags):
                 state.set_callable(self.key, LoadDeferredColumns(state, self.key))
         else:
-            def new_execute(state, row, **flags):
+            def new_execute(state, dict_, row, **flags):
                 # reset state on the key so that deferred callables
                 # fire off on next access.
-                state.reset(self.key)
+                state.reset(self.key, dict_)
 
         if self._should_log_debug:
             new_execute = self.debug_callable(new_execute, self.logger, None,
@@ -340,7 +340,7 @@ class NoLoader(AbstractRelationLoader):
         )
 
     def create_row_processor(self, selectcontext, path, mapper, row, adapter):
-        def new_execute(state, row, **flags):
+        def new_execute(state, dict_, row, **flags):
             self._init_instance_attribute(state)
 
         if self._should_log_debug:
@@ -437,7 +437,7 @@ class LazyLoader(AbstractRelationLoader):
 
     def create_row_processor(self, selectcontext, path, mapper, row, adapter):
         if not self.is_class_level:
-            def new_execute(state, row, **flags):
+            def new_execute(state, dict_, row, **flags):
                 # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
                 # which will override the class-level behavior.
                 # this currently only happens when using a "lazyload" option on a "no load" attribute -
@@ -451,11 +451,11 @@ class LazyLoader(AbstractRelationLoader):
 
             return (new_execute, None)
         else:
-            def new_execute(state, row, **flags):
+            def new_execute(state, dict_, row, **flags):
                 # we are the primary manager for this attribute on this class - reset its per-instance attribute state, 
                 # so that the class-level lazy loader is executed when next referenced on this instance.
                 # this is needed in populate_existing() types of scenarios to reset any existing state.
-                state.reset(self.key)
+                state.reset(self.key, dict_)
 
             if self._should_log_debug:
                 new_execute = self.debug_callable(new_execute, self.logger, None,
@@ -735,24 +735,24 @@ class EagerLoader(AbstractRelationLoader):
             _instance = self.mapper._instance_processor(context, path + (self.mapper.base_mapper,), eager_adapter)
             
             if not self.uselist:
-                def execute(state, row, isnew, **flags):
+                def execute(state, dict_, row, isnew, **flags):
                     if isnew:
                         # set a scalar object instance directly on the
                         # parent object, bypassing InstrumentedAttribute
                         # event handlers.
-                        state.dict[key] = _instance(row, None)
+                        dict_[key] = _instance(row, None)
                     else:
                         # call _instance on the row, even though the object has been created,
                         # so that we further descend into properties
                         _instance(row, None)
             else:
-                def execute(state, row, isnew, **flags):
+                def execute(state, dict_, row, isnew, **flags):
                     if isnew or (state, key) not in context.attributes:
                         # appender_key can be absent from context.attributes with isnew=False
                         # when self-referential eager loading is used; the same instance may be present
                         # in two distinct sets of result columns
 
-                        collection = attributes.init_state_collection(state, key)
+                        collection = attributes.init_state_collection(state, dict_, key)
                         appender = util.UniqueAppender(collection, 'append_without_event')
 
                         context.attributes[(state, key)] = appender
index 4ac9c765e03bb9643f2e1c09f9e7b26f7c8d02c3..407b702a8bc09f468407f41aa07b11de466c9bea 100644 (file)
@@ -96,6 +96,8 @@ class UOWTransaction(object):
         # information.
         self.attributes = {}
         
+        self.processors = set()
+        
     def get_attribute_history(self, state, key, passive=True):
         hashkey = ("history", state, key)
 
@@ -136,6 +138,16 @@ class UOWTransaction(object):
         else:
             task.append(state, listonly=listonly, isdelete=isdelete)
 
+        # ensure the mapper for this object has had its 
+        # DependencyProcessors added.
+        if mapper not in self.processors:
+            mapper._register_processors(self)
+            self.processors.add(mapper)
+
+            if mapper.base_mapper not in self.processors:
+                mapper.base_mapper._register_processors(self)
+                self.processors.add(mapper.base_mapper)
+            
     def set_row_switch(self, state):
         """mark a deleted object as a 'row switch'.
 
@@ -147,7 +159,7 @@ class UOWTransaction(object):
         task = self.get_task_by_mapper(mapper)
         taskelement = task._objects[state]
         taskelement.isdelete = "rowswitch"
-
+    
     def is_deleted(self, state):
         """return true if the given state is marked as deleted within this UOWTransaction."""
 
@@ -201,9 +213,9 @@ class UOWTransaction(object):
         self.dependencies.add((mapper, dependency))
 
     def register_processor(self, mapper, processor, mapperfrom):
-        """register a dependency processor, corresponding to dependencies between
-        the two given mappers.
-
+        """register a dependency processor, corresponding to 
+        operations which occur between two mappers.
+        
         """
         # correct for primary mapper
         mapper = mapper.primary_mapper()
index 772e1bbd0417d22a3ad81ef7d6e9feaf015f206b..76108a713e401ff7d0af295fae394c32b803f448 100644 (file)
@@ -38,7 +38,7 @@ class AttributesTest(_base.ORMTest):
         u.email_address = 'lala@123.com'
 
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
-        attributes.instance_state(u).commit_all()
+        attributes.instance_state(u).commit_all(attributes.instance_dict(u))
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
 
         u.user_name = 'heythere'
@@ -158,7 +158,7 @@ class AttributesTest(_base.ORMTest):
         eq_(f.a, None)
         eq_(f.b, 12)
 
-        attributes.instance_state(f).commit_all()
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
         eq_(f.a, None)
         eq_(f.b, 12)
 
@@ -205,7 +205,7 @@ class AttributesTest(_base.ORMTest):
         u.addresses.append(a)
 
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
-        u, attributes.instance_state(a).commit_all()
+        u, attributes.instance_state(a).commit_all(attributes.instance_dict(a))
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
 
         u.user_name = 'heythere'
@@ -272,7 +272,7 @@ class AttributesTest(_base.ORMTest):
         p1 = Post()
         attributes.instance_state(b).set_callable('posts', lambda:[p1])
         attributes.instance_state(p1).set_callable('blog', lambda:b)
-        p1, attributes.instance_state(b).commit_all()
+        p1, attributes.instance_state(b).commit_all(attributes.instance_dict(b))
 
         # no orphans (called before the lazy loaders fire off)
         assert attributes.has_parent(Blog, p1, 'posts', optimistic=True)
@@ -353,7 +353,7 @@ class AttributesTest(_base.ORMTest):
         x = Bar()
         x.element = el
         eq_(attributes.get_history(attributes.instance_state(x), 'element'), ([el], (), ()))
-        attributes.instance_state(x).commit_all()
+        attributes.instance_state(x).commit_all(attributes.instance_dict(x))
 
         (added, unchanged, deleted) = attributes.get_history(attributes.instance_state(x), 'element')
         assert added == ()
@@ -381,7 +381,7 @@ class AttributesTest(_base.ORMTest):
         attributes.register_attribute(Bar, 'id', uselist=False, useobject=True)
 
         x = Foo()
-        attributes.instance_state(x).commit_all()
+        attributes.instance_state(x).commit_all(attributes.instance_dict(x))
         x.col2.append(bar4)
         eq_(attributes.get_history(attributes.instance_state(x), 'col2'), ([bar4], [bar1, bar2, bar3], []))
 
@@ -427,7 +427,7 @@ class AttributesTest(_base.ORMTest):
         attributes.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False)
         x = Foo()
         x.element = ['one', 'two', 'three']
-        attributes.instance_state(x).commit_all()
+        attributes.instance_state(x).commit_all(attributes.instance_dict(x))
         x.element[1] = 'five'
         assert attributes.instance_state(x).check_modified()
 
@@ -437,7 +437,7 @@ class AttributesTest(_base.ORMTest):
         attributes.register_attribute(Foo, 'element', uselist=False, useobject=False)
         x = Foo()
         x.element = ['one', 'two', 'three']
-        attributes.instance_state(x).commit_all()
+        attributes.instance_state(x).commit_all(attributes.instance_dict(x))
         x.element[1] = 'five'
         assert not attributes.instance_state(x).check_modified()
 
@@ -699,8 +699,8 @@ class PendingBackrefTest(_base.ORMTest):
 
         b = Blog("blog 1")
         p1.blog = b
-        attributes.instance_state(b).commit_all()
-        attributes.instance_state(p1).commit_all()
+        attributes.instance_state(b).commit_all(attributes.instance_dict(b))
+        attributes.instance_state(p1).commit_all(attributes.instance_dict(p1))
         assert b.posts == [Post("post 1")]
 
 class HistoryTest(_base.ORMTest):
@@ -713,17 +713,17 @@ class HistoryTest(_base.ORMTest):
         attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False)
 
         f = Foo()
-        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
+        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None)
 
         f.someattr = 3
-        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
+        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None)
 
         f = Foo()
         f.someattr = 3
-        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
+        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None)
         
-        attributes.instance_state(f).commit(['someattr'])
-        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), 3)
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
+        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), 3)
 
     def test_scalar(self):
         class Foo(_base.BasicEntity):
@@ -739,13 +739,13 @@ class HistoryTest(_base.ORMTest):
         f.someattr = "hi"
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['hi'], (), ()))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['hi'], ()))
 
         f.someattr = 'there'
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['there'], (), ['hi']))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['there'], ()))
 
@@ -760,7 +760,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr = 'old'
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['old'], (), ['new']))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['old'], ()))
 
         # setting None on uninitialized is currently a change for a scalar attribute
@@ -778,7 +778,7 @@ class HistoryTest(_base.ORMTest):
 
         # set same value twice
         f = Foo()
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         f.someattr = 'one'
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ()))
         f.someattr = 'two'
@@ -799,7 +799,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr = {'foo':'hi'}
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'hi'}], (), ()))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'hi'}], ()))
         eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'})
 
@@ -807,7 +807,7 @@ class HistoryTest(_base.ORMTest):
         eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'})
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'there'}], (), [{'foo':'hi'}]))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'there'}], ()))
 
@@ -819,7 +819,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr = {'foo':'old'}
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'old'}], (), [{'foo':'new'}]))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'old'}], ()))
 
 
@@ -847,13 +847,13 @@ class HistoryTest(_base.ORMTest):
         f.someattr = hi
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], (), ()))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ()))
 
         f.someattr = there
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], (), [hi]))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ()))
 
@@ -868,7 +868,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr = old
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], (), ['new']))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ()))
 
         # setting None on uninitialized is currently not a change for an object attribute
@@ -887,7 +887,7 @@ class HistoryTest(_base.ORMTest):
 
         # set same value twice
         f = Foo()
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         f.someattr = 'one'
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ()))
         f.someattr = 'two'
@@ -915,13 +915,13 @@ class HistoryTest(_base.ORMTest):
         f.someattr = [hi]
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ()))
 
         f.someattr = [there]
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi]))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ()))
 
@@ -935,13 +935,13 @@ class HistoryTest(_base.ORMTest):
         f = Foo()
         collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
         collection.append_without_event(new)
-        attributes.instance_state(f).commit_all()
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ()))
 
         f.someattr = [old]
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], [new]))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ()))
 
     def test_dict_collections(self):
@@ -969,7 +969,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr['there'] = there
         eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([hi, there]), set(), set()))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set(), set([hi, there]), set()))
 
     def test_object_collections_mutate(self):
@@ -994,13 +994,13 @@ class HistoryTest(_base.ORMTest):
         f.someattr.append(hi)
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ()))
 
         f.someattr.append(there)
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [hi], []))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there], ()))
 
@@ -1010,7 +1010,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr.append(old)
         f.someattr.append(new)
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [hi], [there]))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, old, new], ()))
 
         f.someattr.pop(0)
@@ -1021,19 +1021,19 @@ class HistoryTest(_base.ORMTest):
         f.__dict__['id'] = 1
         collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
         collection.append_without_event(new)
-        attributes.instance_state(f).commit_all()
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ()))
 
         f.someattr.append(old)
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [new], []))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new, old], ()))
 
         f = Foo()
         collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
         collection.append_without_event(new)
-        attributes.instance_state(f).commit_all()
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ()))
 
         f.id = 1
@@ -1056,7 +1056,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr.append(hi)
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there, hi], [], []))
 
-        attributes.instance_state(f).commit_all()
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there, hi], ()))
         
         f.someattr = []
index 69164ebafb41f638ca6b60fa4884c918b31e840e..aec6c181f26c6af1c72828b566f6268b33478892 100644 (file)
@@ -117,7 +117,7 @@ class UserDefinedExtensionTest(_base.ORMTest):
         u.user_id = 7
         u.user_name = 'john'
         u.email_address = 'lala@123.com'
-        self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}})
+        self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}}, u.__dict__)
         
     def test_basic(self):
         for base in (object, MyBaseClass, MyClass):
@@ -135,7 +135,7 @@ class UserDefinedExtensionTest(_base.ORMTest):
             u.email_address = 'lala@123.com'
 
             self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
-            attributes.instance_state(u).commit_all()
+            attributes.instance_state(u).commit_all(attributes.instance_dict(u))
             self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
 
             u.user_name = 'heythere'
@@ -182,7 +182,7 @@ class UserDefinedExtensionTest(_base.ORMTest):
             self.assertEquals(f.a, None)
             self.assertEquals(f.b, 12)
 
-            attributes.instance_state(f).commit_all()
+            attributes.instance_state(f).commit_all(attributes.instance_dict(f))
             self.assertEquals(f.a, None)
             self.assertEquals(f.b, 12)
 
@@ -272,8 +272,8 @@ class UserDefinedExtensionTest(_base.ORMTest):
             f1.bars.append(b1)
             self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
 
-            attributes.instance_state(f1).commit_all()
-            attributes.instance_state(b1).commit_all()
+            attributes.instance_state(f1).commit_all(attributes.instance_dict(f1))
+            attributes.instance_state(b1).commit_all(attributes.instance_dict(b1))
 
             self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ()))
             self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ()))
index 081c46cdd886540738d916e335e71fc7462f6653..fd15420d0ad8b20dd929dde1072bcba7aba9de1b 100644 (file)
@@ -1,8 +1,8 @@
 import testenv; testenv.configure_for_tests()
 
 from testlib import sa
-from testlib.sa import MetaData, Table, Column, Integer, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper
+from testlib.sa import MetaData, Table, Column, Integer, ForeignKey, util
+from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers
 from testlib.testing import eq_, ne_
 from testlib.compat import _function_named
 from orm import _base
@@ -458,25 +458,9 @@ class MapperInitTest(_base.ORMTest):
 
         m = mapper(A, self.fixture())
 
-        a = attributes.instance_state(A())
-        assert isinstance(a, attributes.InstanceState)
-        assert type(a) is not attributes.InstanceState
-
-        b = attributes.instance_state(B())
-        assert isinstance(b, attributes.InstanceState)
-        assert type(b) is not attributes.InstanceState
-
         # B is not mapped in the current implementation
         self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, B)
 
-        # the constructor of C is decorated too.  
-        # we don't support unmapped subclasses in any case,
-        # users should not be expecting any particular behavior
-        # from this scenario.
-        c = attributes.instance_state(C(3))
-        assert isinstance(c, attributes.InstanceState)
-        assert type(c) is not attributes.InstanceState
-
         # C is not mapped in the current implementation
         self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, C)
 
@@ -573,6 +557,10 @@ class OnLoadTest(_base.ORMTest):
         finally:
             del A
 
+    def tearDownAll(self):
+        clear_mappers()
+        attributes._install_lookup_strategy(util.symbol('native'))
+
 
 class ExtendedEventsTest(_base.ORMTest):
     """Allow custom Events implementations."""
@@ -593,6 +581,7 @@ class ExtendedEventsTest(_base.ORMTest):
         assert isinstance(manager.events, MyEvents)
 
 
+
 class NativeInstrumentationTest(_base.ORMTest):
     @with_lookup_strategy(sa.util.symbol('native'))
     def test_register_reserved_attribute(self):
index d2687fa5a02691ba2c0595f5ad0747f3d373b343..3f427654f5c321162537c4be57f8bd4d6a302681 100644 (file)
@@ -1751,9 +1751,9 @@ class CompositeTypesTest(_base.MappedTest):
                 return [self.x, self.y]
             __hash__ = None
             def __eq__(self, other):
-                return other.x == self.x and other.y == self.y
+                return isinstance(other, Point) and other.x == self.x and other.y == self.y
             def __ne__(self, other):
-                return not self.__eq__(other)
+                return not isinstance(other, Point) or not self.__eq__(other)
 
         class Graph(object):
             pass
@@ -1819,6 +1819,12 @@ class CompositeTypesTest(_base.MappedTest):
         # query by columns
         eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)])
 
+        e = g.edges[1]
+        e.end.x = e.end.y = None
+        sess.flush()
+        eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, None, None)])
+
+
     @testing.resolve_artifact_names
     def test_pk(self):
         """Using a composite type as a primary key"""
index e95f10ba29691d1836dfc2a7c5f305c35d1720ea..5abcce689f246629687546aabf27a9957696ed54 100644 (file)
@@ -371,7 +371,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL):
                     )
 
         u7 = User(id=7)
-        attributes.instance_state(u7).commit_all()
+        attributes.instance_state(u7).commit_all(attributes.instance_dict(u7))
         
         self._test(Address.user == u7, ":param_1 = addresses.user_id")
 
index 818ab03daf50091624c09e2f95602e6a568f823d..6ae05c77b05e0b057d1f6be212e346b116f59fd7 100644 (file)
@@ -5,7 +5,7 @@ import pickle
 from sqlalchemy.orm import create_session, sessionmaker, attributes
 from testlib import engines, sa, testing, config
 from testlib.sa import Table, Column, Integer, String, Sequence
-from testlib.sa.orm import mapper, relation, backref
+from testlib.sa.orm import mapper, relation, backref, eagerload
 from testlib.testing import eq_
 from engine import _base as engine_base
 from orm import _base, _fixtures
@@ -776,7 +776,66 @@ class SessionTest(_fixtures.FixtureTest):
         user = s.query(User).one()
         assert user.name == 'fred'
         assert s.identity_map
+    
+    @testing.resolve_artifact_names
+    def test_weakref_with_cycles_o2m(self):
+        s = sessionmaker()()
+        mapper(User, users, properties={
+            "addresses":relation(Address, backref="user")
+        })
+        mapper(Address, addresses)
+        s.add(User(name="ed", addresses=[Address(email_address="ed1")]))
+        s.commit()
+        
+        user = s.query(User).options(eagerload(User.addresses)).one()
+        user.addresses[0].user # lazyload
+        eq_(user, User(name="ed", addresses=[Address(email_address="ed1")]))
+        
+        del user
+        gc.collect()
+        assert len(s.identity_map) == 0
 
+        user = s.query(User).options(eagerload(User.addresses)).one()
+        user.addresses[0].email_address='ed2'
+        user.addresses[0].user # lazyload
+        del user
+        gc.collect()
+        assert len(s.identity_map) == 2
+        
+        s.commit()
+        user = s.query(User).options(eagerload(User.addresses)).one()
+        eq_(user, User(name="ed", addresses=[Address(email_address="ed2")]))
+        
+    @testing.resolve_artifact_names
+    def test_weakref_with_cycles_o2o(self):
+        s = sessionmaker()()
+        mapper(User, users, properties={
+            "address":relation(Address, backref="user", uselist=False)
+        })
+        mapper(Address, addresses)
+        s.add(User(name="ed", address=Address(email_address="ed1")))
+        s.commit()
+
+        user = s.query(User).options(eagerload(User.address)).one()
+        user.address.user
+        eq_(user, User(name="ed", address=Address(email_address="ed1")))
+
+        del user
+        gc.collect()
+        assert len(s.identity_map) == 0
+
+        user = s.query(User).options(eagerload(User.address)).one()
+        user.address.email_address='ed2'
+        user.address.user # lazyload
+
+        del user
+        gc.collect()
+        assert len(s.identity_map) == 2
+        
+        s.commit()
+        user = s.query(User).options(eagerload(User.address)).one()
+        eq_(user, User(name="ed", address=Address(email_address="ed2")))
+    
     @testing.resolve_artifact_names
     def test_strong_ref(self):
         s = create_session(weak_identity_map=False)
index 7a8415cdc31ee33b2f4bc0378b92ff8ddf7ae013..c5e3afd01484f091888ff9c50387c63371992e7e 100644 (file)
@@ -14,6 +14,7 @@ from orm import _base, _fixtures
 from engine import _base as engine_base
 import pickleable
 from testlib.assertsql import AllOf, CompiledSQL
+import gc
 
 class UnitOfWorkTest(object):
     pass
@@ -366,6 +367,28 @@ class MutableTypesTest(_base.MappedTest):
              "WHERE mutable_t.id = :mutable_t_id",
              {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})])
 
+
+    @testing.resolve_artifact_names
+    def test_resurrect(self):
+        f1 = Foo()
+        f1.data = pickleable.Bar(4,5)
+        f1.val = u'hi'
+
+        session = create_session(autocommit=False)
+        session.add(f1)
+        session.commit()
+
+        f1.data.y = 19
+        del f1
+        
+        gc.collect()
+        assert len(session.identity_map) == 1
+        
+        session.commit()
+        
+        assert session.query(Foo).one().data == pickleable.Bar(4, 19)
+        
+        
     @testing.uses_deprecated()
     @testing.resolve_artifact_names
     def test_nocomparison(self):
index 7a189f87313ed027268cbd93f28dbe25d4c11b8b..5d7192261d61867845f362eca3786ac0f85eb9ec 100644 (file)
@@ -290,11 +290,11 @@ class ZooMarkTest(TestBase):
     def test_profile_1_create_tables(self):
         self.test_baseline_1_create_tables()
 
-    @profiling.function_call_count(12925, {'2.4':12478})
+    @profiling.function_call_count(12178, {'2.4':12178})
     def test_profile_1a_populate(self):
         self.test_baseline_1a_populate()
 
-    @profiling.function_call_count(1185, {'2.4':1184})
+    @profiling.function_call_count(903, {'2.4':903})
     def test_profile_2_insert(self):
         self.test_baseline_2_insert()
 
@@ -310,7 +310,7 @@ class ZooMarkTest(TestBase):
     def test_profile_5_aggregates(self):
         self.test_baseline_5_aggregates()
 
-    @profiling.function_call_count(3545)
+    @profiling.function_call_count(3343)
     def test_profile_6_editing(self):
         self.test_baseline_6_editing()