]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
merge -r5936:5974 of trunk
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 May 2009 15:20:44 +0000 (15:20 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 May 2009 15:20:44 +0000 (15:20 +0000)
36 files changed:
CHANGES
doc/build/sqlexpression.rst
lib/sqlalchemy/__init__.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/exc.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/identity.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/state.py [new file with mode: 0644]
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/util.py
test/dialect/sqlite.py
test/orm/attributes.py
test/orm/extendedattr.py
test/orm/inheritance/basic.py
test/orm/instrumentation.py
test/orm/mapper.py
test/orm/merge.py
test/orm/naturalpks.py
test/orm/onetoone.py
test/orm/query.py
test/orm/relationships.py
test/orm/session.py
test/orm/unitofwork.py
test/profiling/zoomark_orm.py

diff --git a/CHANGES b/CHANGES
index e950ca0a1fa8203c2fce6952c661528d6f0935fb..165d2ec07930c8e10180ad49becfa6ef259b6eab 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -4,14 +4,77 @@
 CHANGES
 =======
 
+0.5.4p1
+=======
+
+- orm
+    - Fixed an attribute error introduced in 0.5.4 which would 
+      occur when merge() was used with an incomplete object.
+    
 0.5.4
 =====
 
 - orm
+    - Significant performance enhancements regarding Sessions/flush()
+      in conjunction with large mapper graphs, large numbers of 
+      objects:
+      
+      - Removed all* O(N) scanning behavior from the flush() process,
+        i.e. operations that were scanning the full session, 
+        including an extremely expensive one that was erroneously
+        assuming primary key values were changing when this 
+        was not the case.
+        
+        * one edge case remains which may invoke a full scan,
+          if an existing primary key attribute is modified
+          to a new value.
+      
+      - The Session's "weak referencing" behavior is now *full* -
+        no strong references whatsoever are made to a mapped object
+        or related items/collections in its __dict__.  Backrefs and 
+        other cycles in objects no longer affect the Session's ability 
+        to lose all references to unmodified objects.  Objects with 
+        pending changes still are maintained strongly until flush.  
+        [ticket:1398]
+        
+        The implementation also improves performance by moving
+        the "resurrection" process of garbage collected items
+        to only be relevant for mappings that map "mutable" 
+        attributes (i.e. PickleType, composite attrs).  This removes
+        overhead from the gc process and simplifies internal 
+        behavior.
+        
+        If a "mutable" attribute change is the sole change on an object 
+        which is then dereferenced, the mapper will not have access to 
+        other attribute state when the UPDATE is issued.  This may present 
+        itself differently to some MapperExtensions.
+        
+        The change also affects the internal attribute API, but not
+        the AttributeExtension interface nor any of the publicly
+        documented attribute functions.
+        
+      - The unit of work no longer generates a graph of "dependency"
+        processors for the full graph of mappers during flush(), instead
+        creating such processors only for those mappers which represent
+        objects with pending changes.  This saves a tremendous number
+        of method calls in the context of a large interconnected 
+        graph of mappers.
+        
+      - Cached a wasteful "table sort" operation that previously
+        occurred multiple times per flush, also removing significant
+        method call count from flush().
+        
+      - Other redundant behaviors have been simplified in 
+        mapper._save_obj().
+      
     - Modified query_cls on DynamicAttributeImpl to accept a full
       mixin version of the AppenderQuery, which allows subclassing
       the AppenderMixin.
 
+    - The "polymorphic discriminator" column may be part of a 
+      primary key, and it will be populated with the correct 
+      discriminator value.  [ticket:1300]
+      
     - Fixed the evaluator not being able to evaluate IS NULL clauses.
 
     - Fixed the "set collection" function on "dynamic" relations to
@@ -44,12 +107,20 @@ CHANGES
     - Fixed another location where autoflush was interfering
       with session.merge().  autoflush is disabled completely
       for the duration of merge() now. [ticket:1360]
-
+    
+    - Fixed bug which prevented "mutable primary key" dependency
+      logic from functioning properly on a one-to-one
+      relation().  [ticket:1406]
+      
     - Fixed bug in relation(), introduced in 0.5.3, 
       whereby a self referential relation
       from a base class to a joined-table subclass would 
       not configure correctly.
 
+    - Fixed obscure mapper compilation issue when inheriting
+      mappers are used which would result in un-initialized
+      attributes.
+      
     - Fixed documentation for session weak_identity_map - 
       the default value is True, indicating a weak
       referencing map in use.
@@ -62,6 +133,11 @@ CHANGES
     - Fixed Query.update() and Query.delete() failures with eagerloaded
       relations. [ticket:1378]
 
+    - It is now an error to specify both columns of a binary primaryjoin
+      condition in the foreign_keys or remote_side collection.  Whereas
+      previously it was just nonsensical, but would succeed in a 
+      non-deterministic way.
+      
 - schema
     - Added a quote_schema() method to the IdentifierPreparer class
       so that dialects can override how schemas get handled. This
@@ -69,6 +145,18 @@ CHANGES
       identifiers, such as 'database.owner'. [ticket: 594, 1341]
 
 - sql
+    - Back-ported the "compiler" extension from SQLA 0.6.  This
+      is a standardized interface which allows the creation of custom
+      ClauseElement subclasses and compilers.  In particular it's
+      handy as an alternative to text() when you'd like to 
+      build a construct that has database-specific compilations.
+      See the extension docs for details.
+      
+    - Exception messages are truncated when the list of bound 
+      parameters is larger than 10, preventing enormous
+      multi-page exceptions from filling up screens and logfiles
+      for large executemany() statements. [ticket:1413]
+      
     - ``sqlalchemy.extract()`` is now dialect sensitive and can
       extract components of timestamps idiomatically across the
       supported databases, including SQLite.
@@ -77,6 +165,11 @@ CHANGES
       ForeignKey constructed from __clause_element__() style
       construct (i.e. declarative columns).  [ticket:1353]
 
+- mysql
+    - Reflecting a FOREIGN KEY construct will take into account
+      a dotted schema.tablename combination, if the foreign key
+      references a table in a remote schema. [ticket:1405]
+      
 - mssql
     - Modified how savepoint logic works to prevent it from
       stepping on non-savepoint oriented routines. Savepoint
@@ -90,6 +183,9 @@ CHANGES
       since it is only used by mssql now. [ticket:1343]
 
 - sqlite
+    - Corrected the SLBoolean type so that it properly treats only 1
+      as True. [ticket:1402]
+
     - Corrected the float type so that it correctly maps to a
       SLFloat type when being reflected. [ticket:1273]
 
index 70aaf6dced0d75e30a8afbd0cac82a8e89e4e704..4d54d036bda150fed14b6233d759db17b81b8770 100644 (file)
@@ -260,7 +260,7 @@ Integer indexes work as well:
 
     >>> row = result.fetchone()
     >>> print "name:", row[1], "; fullname:", row[2]
-    name: jack ; fullname: Jack Jones
+    name: wendy ; fullname: Wendy Williams
 
 But another way, whose usefulness will become apparent later on, is to use the ``Column`` objects directly as keys:
 
index 5d15f5d763ec46f11cb9cb3b27bde8856750695a..1b1b96855772312b8bb418f39b19151ffee79cc9 100644 (file)
@@ -109,6 +109,6 @@ from sqlalchemy.engine import create_engine, engine_from_config
 __all__ = sorted(name for name, obj in locals().items()
                  if not (name.startswith('_') or inspect.ismodule(obj)))
                  
-__version__ = '0.5.3'
+__version__ = '0.6beta1'
 
 del inspect, sys
index fd5ba7348e41105cbdf1ca8b36334b8cb4c63502..3bb6536a3ceec329f9cd6defbda035bf0c3b9fcb 100644 (file)
@@ -2507,7 +2507,7 @@ class MySQLTableDefinitionParser(object):
             r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +'
             r'FOREIGN KEY +'
             r'\((?P<local>[^\)]+?)\) REFERENCES +'
-            r'(?P<table>%(iq)s[^%(fq)s]+%(fq)s) +'
+            r'(?P<table>%(iq)s[^%(fq)s]+%(fq)s(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +'
             r'\((?P<foreign>[^\)]+?)\)'
             r'(?: +(?P<match>MATCH \w+))?'
             r'(?: +ON DELETE (?P<ondelete>%(on)s))?'
index 0c7400c2bddb027a0d969ef7bcaa3d2ec5af5542..c6228ca2f3fdafe131403e00be488f413a6b7bd1 100644 (file)
@@ -134,7 +134,7 @@ class SLBoolean(sqltypes.Boolean):
         def process(value):
             if value is None:
                 return None
-            return value and True or False
+            return value == 1
         return process
 
 colspecs = {
index 799abbf0d3dabc5f92feb60ff40f44d19068d036..d9fdd5df924c0aff58da4ce996bdad50abcf7119 100644 (file)
@@ -132,6 +132,11 @@ class DBAPIError(SQLAlchemyError):
         self.connection_invalidated = connection_invalidated
 
     def __str__(self):
+        if len(self.params) > 10:
+            return ' '.join((SQLAlchemyError.__str__(self),
+                             repr(self.statement),
+                             repr(self.params[:2]),
+                             '... and a total of %i bound parameters' % len(self.params)))
         return ' '.join((SQLAlchemyError.__str__(self),
                          repr(self.statement), repr(self.params)))
 
index a6861ee45256336f108039eed397d178a7b514b3..1df37b4e1eeb92e0410a1ff19e3885c63c7c9025 100644 (file)
@@ -20,14 +20,13 @@ import types
 import weakref
 
 from sqlalchemy import util
-from sqlalchemy.util import EMPTY_SET
 from sqlalchemy.orm import interfaces, collections, exc
 import sqlalchemy.exceptions as sa_exc
 
 # lazy imports
 _entity_info = None
 identity_equal = None
-
+state = None
 
 PASSIVE_NORESULT = util.symbol('PASSIVE_NORESULT')
 ATTR_WAS_SET = util.symbol('ATTR_WAS_SET')
@@ -105,7 +104,7 @@ class QueryableAttribute(interfaces.PropComparator):
         self.parententity = parententity
 
     def get_history(self, instance, **kwargs):
-        return self.impl.get_history(instance_state(instance), **kwargs)
+        return self.impl.get_history(instance_state(instance), instance_dict(instance), **kwargs)
 
     def __selectable__(self):
         # TODO: conditionally attach this method based on clause_element ?
@@ -148,15 +147,15 @@ class InstrumentedAttribute(QueryableAttribute):
     """Public-facing descriptor, placed in the mapped class dictionary."""
 
     def __set__(self, instance, value):
-        self.impl.set(instance_state(instance), value, None)
+        self.impl.set(instance_state(instance), instance_dict(instance), value, None)
 
     def __delete__(self, instance):
-        self.impl.delete(instance_state(instance))
+        self.impl.delete(instance_state(instance), instance_dict(instance))
 
     def __get__(self, instance, owner):
         if instance is None:
             return self
-        return self.impl.get(instance_state(instance))
+        return self.impl.get(instance_state(instance), instance_dict(instance))
 
 class _ProxyImpl(object):
     accepts_scalar_loader = False
@@ -335,7 +334,7 @@ class AttributeImpl(object):
         else:
             state.callables[self.key] = callable_
 
-    def get_history(self, state, passive=PASSIVE_OFF):
+    def get_history(self, state, dict_, passive=PASSIVE_OFF):
         raise NotImplementedError()
 
     def _get_callable(self, state):
@@ -346,13 +345,13 @@ class AttributeImpl(object):
         else:
             return None
 
-    def initialize(self, state):
+    def initialize(self, state, dict_):
         """Initialize this attribute on the given object instance with an empty value."""
 
-        state.dict[self.key] = None
+        dict_[self.key] = None
         return None
 
-    def get(self, state, passive=PASSIVE_OFF):
+    def get(self, state, dict_, passive=PASSIVE_OFF):
         """Retrieve a value from the given object.
 
         If a callable is assembled on this object's attribute, and
@@ -361,7 +360,7 @@ class AttributeImpl(object):
         """
 
         try:
-            return state.dict[self.key]
+            return dict_[self.key]
         except KeyError:
             # if no history, check for lazy callables, etc.
             if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET:
@@ -374,25 +373,25 @@ class AttributeImpl(object):
                         return PASSIVE_NORESULT
                     value = callable_()
                     if value is not ATTR_WAS_SET:
-                        return self.set_committed_value(state, value)
+                        return self.set_committed_value(state, dict_, value)
                     else:
-                        if self.key not in state.dict:
-                            return self.get(state, passive=passive)
-                        return state.dict[self.key]
+                        if self.key not in dict_:
+                            return self.get(state, dict_, passive=passive)
+                        return dict_[self.key]
 
             # Return a new, empty value
-            return self.initialize(state)
+            return self.initialize(state, dict_)
 
-    def append(self, state, value, initiator, passive=PASSIVE_OFF):
-        self.set(state, value, initiator)
+    def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+        self.set(state, dict_, value, initiator)
 
-    def remove(self, state, value, initiator, passive=PASSIVE_OFF):
-        self.set(state, None, initiator)
+    def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
+        self.set(state, dict_, None, initiator)
 
-    def set(self, state, value, initiator):
+    def set(self, state, dict_, value, initiator):
         raise NotImplementedError()
 
-    def get_committed_value(self, state, passive=PASSIVE_OFF):
+    def get_committed_value(self, state, dict_, passive=PASSIVE_OFF):
         """return the unchanged value of this attribute"""
 
         if self.key in state.committed_state:
@@ -401,12 +400,12 @@ class AttributeImpl(object):
             else:
                 return state.committed_state.get(self.key)
         else:
-            return self.get(state, passive=passive)
+            return self.get(state, dict_, passive=passive)
 
-    def set_committed_value(self, state, value):
+    def set_committed_value(self, state, dict_, value):
         """set an attribute value on the given instance and 'commit' it."""
 
-        state.commit([self.key])
+        state.commit(dict_, [self.key])
 
         state.callables.pop(self.key, None)
         state.dict[self.key] = value
@@ -419,45 +418,45 @@ class ScalarAttributeImpl(AttributeImpl):
     accepts_scalar_loader = True
     uses_objects = False
 
-    def delete(self, state):
+    def delete(self, state, dict_):
 
         # TODO: catch key errors, convert to attributeerror?
         if self.active_history or self.extensions:
-            old = self.get(state)
+            old = self.get(state, dict_)
         else:
-            old = state.dict.get(self.key, NO_VALUE)
+            old = dict_.get(self.key, NO_VALUE)
 
-        state.modified_event(self, False, old)
+        state.modified_event(dict_, self, False, old)
 
         if self.extensions:
-            self.fire_remove_event(state, old, None)
-        del state.dict[self.key]
+            self.fire_remove_event(state, dict_, old, None)
+        del dict_[self.key]
 
-    def get_history(self, state, passive=PASSIVE_OFF):
+    def get_history(self, state, dict_, passive=PASSIVE_OFF):
         return History.from_attribute(
-            self, state, state.dict.get(self.key, NO_VALUE))
+            self, state, dict_.get(self.key, NO_VALUE))
 
-    def set(self, state, value, initiator):
+    def set(self, state, dict_, value, initiator):
         if initiator is self:
             return
 
         if self.active_history or self.extensions:
-            old = self.get(state)
+            old = self.get(state, dict_)
         else:
-            old = state.dict.get(self.key, NO_VALUE)
+            old = dict_.get(self.key, NO_VALUE)
 
-        state.modified_event(self, False, old)
+        state.modified_event(dict_, self, False, old)
 
         if self.extensions:
-            value = self.fire_replace_event(state, value, old, initiator)
-        state.dict[self.key] = value
+            value = self.fire_replace_event(state, dict_, value, old, initiator)
+        dict_[self.key] = value
 
-    def fire_replace_event(self, state, value, previous, initiator):
+    def fire_replace_event(self, state, dict_, value, previous, initiator):
         for ext in self.extensions:
             value = ext.set(state, value, previous, initiator or self)
         return value
 
-    def fire_remove_event(self, state, value, initiator):
+    def fire_remove_event(self, state, dict_, value, initiator):
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
@@ -483,29 +482,48 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
             raise sa_exc.ArgumentError("MutableScalarAttributeImpl requires a copy function")
         self.copy = copy_function
 
-    def get_history(self, state, passive=PASSIVE_OFF):
+    def get_history(self, state, dict_, passive=PASSIVE_OFF):
+        if not dict_:
+            v = state.committed_state.get(self.key, NO_VALUE)
+        else:
+            v = dict_.get(self.key, NO_VALUE)
+            
         return History.from_attribute(
-            self, state, state.dict.get(self.key, NO_VALUE))
+            self, state, v)
 
-    def commit_to_state(self, state, dest):
-        dest[self.key] = self.copy(state.dict[self.key])
+    def commit_to_state(self, state, dict_, dest):
+        dest[self.key] = self.copy(dict_[self.key])
 
-    def check_mutable_modified(self, state):
-        (added, unchanged, deleted) = self.get_history(state, passive=PASSIVE_NO_INITIALIZE)
+    def check_mutable_modified(self, state, dict_):
+        (added, unchanged, deleted) = self.get_history(state, dict_, passive=PASSIVE_NO_INITIALIZE)
         return bool(added or deleted)
 
-    def set(self, state, value, initiator):
+    def get(self, state, dict_, passive=PASSIVE_OFF):
+        if self.key not in state.mutable_dict:
+            ret = ScalarAttributeImpl.get(self, state, dict_, passive=passive)
+            if ret is not PASSIVE_NORESULT:
+                state.mutable_dict[self.key] = ret
+            return ret
+        else:
+            return state.mutable_dict[self.key]
+
+    def delete(self, state, dict_):
+        ScalarAttributeImpl.delete(self, state, dict_)
+        state.mutable_dict.pop(self.key)
+
+    def set(self, state, dict_, value, initiator):
         if initiator is self:
             return
 
-        state.modified_event(self, True, NEVER_SET)
-
+        state.modified_event(dict_, self, True, NEVER_SET)
+        
         if self.extensions:
-            old = self.get(state)
-            value = self.fire_replace_event(state, value, old, initiator)
-            state.dict[self.key] = value
+            old = self.get(state, dict_)
+            value = self.fire_replace_event(state, dict_, value, old, initiator)
+            dict_[self.key] = value
         else:
-            state.dict[self.key] = value
+            dict_[self.key] = value
+        state.mutable_dict[self.key] = value
 
 
 class ScalarObjectAttributeImpl(ScalarAttributeImpl):
@@ -526,22 +544,22 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         if compare_function is None:
             self.is_equal = identity_equal
 
-    def delete(self, state):
-        old = self.get(state)
-        self.fire_remove_event(state, old, self)
-        del state.dict[self.key]
+    def delete(self, state, dict_):
+        old = self.get(state, dict_)
+        self.fire_remove_event(state, dict_, old, self)
+        del dict_[self.key]
 
-    def get_history(self, state, passive=PASSIVE_OFF):
-        if self.key in state.dict:
-            return History.from_attribute(self, state, state.dict[self.key])
+    def get_history(self, state, dict_, passive=PASSIVE_OFF):
+        if self.key in dict_:
+            return History.from_attribute(self, state, dict_[self.key])
         else:
-            current = self.get(state, passive=passive)
+            current = self.get(state, dict_, passive=passive)
             if current is PASSIVE_NORESULT:
                 return HISTORY_BLANK
             else:
                 return History.from_attribute(self, state, current)
 
-    def set(self, state, value, initiator):
+    def set(self, state, dict_, value, initiator):
         """Set a value on the given InstanceState.
 
         `initiator` is the ``InstrumentedAttribute`` that initiated the
@@ -553,12 +571,12 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
             return
 
         # may want to add options to allow the get() here to be passive
-        old = self.get(state)
-        value = self.fire_replace_event(state, value, old, initiator)
-        state.dict[self.key] = value
+        old = self.get(state, dict_)
+        value = self.fire_replace_event(state, dict_, value, old, initiator)
+        dict_[self.key] = value
 
-    def fire_remove_event(self, state, value, initiator):
-        state.modified_event(self, False, value)
+    def fire_remove_event(self, state, dict_, value, initiator):
+        state.modified_event(dict_, self, False, value)
 
         if self.trackparent and value is not None:
             self.sethasparent(instance_state(value), False)
@@ -566,8 +584,8 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
-    def fire_replace_event(self, state, value, previous, initiator):
-        state.modified_event(self, False, previous)
+    def fire_replace_event(self, state, dict_, value, previous, initiator):
+        state.modified_event(dict_, self, False, previous)
 
         if self.trackparent:
             if previous is not value and previous is not None:
@@ -615,15 +633,15 @@ class CollectionAttributeImpl(AttributeImpl):
     def __copy(self, item):
         return [y for y in list(collections.collection_adapter(item))]
 
-    def get_history(self, state, passive=PASSIVE_OFF):
-        current = self.get(state, passive=passive)
+    def get_history(self, state, dict_, passive=PASSIVE_OFF):
+        current = self.get(state, dict_, passive=passive)
         if current is PASSIVE_NORESULT:
             return HISTORY_BLANK
         else:
             return History.from_attribute(self, state, current)
 
-    def fire_append_event(self, state, value, initiator):
-        state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+    def fire_append_event(self, state, dict_, value, initiator):
+        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
 
         for ext in self.extensions:
             value = ext.append(state, value, initiator or self)
@@ -633,11 +651,11 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return value
 
-    def fire_pre_remove_event(self, state, initiator):
-        state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+    def fire_pre_remove_event(self, state, dict_, initiator):
+        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
 
-    def fire_remove_event(self, state, value, initiator):
-        state.modified_event(self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
+    def fire_remove_event(self, state, dict_, value, initiator):
+        state.modified_event(dict_, self, True, NEVER_SET, passive=PASSIVE_NO_INITIALIZE)
 
         if self.trackparent and value is not None:
             self.sethasparent(instance_state(value), False)
@@ -645,51 +663,51 @@ class CollectionAttributeImpl(AttributeImpl):
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
-    def delete(self, state):
-        if self.key not in state.dict:
+    def delete(self, state, dict_):
+        if self.key not in dict_:
             return
 
-        state.modified_event(self, True, NEVER_SET)
+        state.modified_event(dict_, self, True, NEVER_SET)
 
-        collection = self.get_collection(state)
+        collection = self.get_collection(state, state.dict)
         collection.clear_with_event()
         # TODO: catch key errors, convert to attributeerror?
-        del state.dict[self.key]
+        del dict_[self.key]
 
-    def initialize(self, state):
+    def initialize(self, state, dict_):
         """Initialize this attribute with an empty collection."""
 
         _, user_data = self._initialize_collection(state)
-        state.dict[self.key] = user_data
+        dict_[self.key] = user_data
         return user_data
 
     def _initialize_collection(self, state):
         return state.manager.initialize_collection(
             self.key, state, self.collection_factory)
 
-    def append(self, state, value, initiator, passive=PASSIVE_OFF):
+    def append(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         if initiator is self:
             return
 
-        collection = self.get_collection(state, passive=passive)
+        collection = self.get_collection(state, dict_, passive=passive)
         if collection is PASSIVE_NORESULT:
-            value = self.fire_append_event(state, value, initiator)
+            value = self.fire_append_event(state, dict_, value, initiator)
             state.get_pending(self.key).append(value)
         else:
             collection.append_with_event(value, initiator)
 
-    def remove(self, state, value, initiator, passive=PASSIVE_OFF):
+    def remove(self, state, dict_, value, initiator, passive=PASSIVE_OFF):
         if initiator is self:
             return
 
-        collection = self.get_collection(state, passive=passive)
+        collection = self.get_collection(state, state.dict, passive=passive)
         if collection is PASSIVE_NORESULT:
-            self.fire_remove_event(state, value, initiator)
+            self.fire_remove_event(state, dict_, value, initiator)
             state.get_pending(self.key).remove(value)
         else:
             collection.remove_with_event(value, initiator)
 
-    def set(self, state, value, initiator):
+    def set(self, state, dict_, value, initiator):
         """Set a value on the given object.
 
         `initiator` is the ``InstrumentedAttribute`` that initiated the
@@ -701,10 +719,10 @@ class CollectionAttributeImpl(AttributeImpl):
             return
 
         self._set_iterable(
-            state, value,
+            state, dict_, value,
             lambda adapter, i: adapter.adapt_like_to_iterable(i))
 
-    def _set_iterable(self, state, iterable, adapter=None):
+    def _set_iterable(self, state, dict_, iterable, adapter=None):
         """Set a collection value from an iterable of state-bearers.
 
         ``adapter`` is an optional callable invoked with a CollectionAdapter
@@ -722,24 +740,24 @@ class CollectionAttributeImpl(AttributeImpl):
         else:
             new_values = list(iterable)
 
-        old = self.get(state)
+        old = self.get(state, dict_)
 
         # ignore re-assignment of the current collection, as happens
         # implicitly with in-place operators (foo.collection |= other)
         if old is iterable:
             return
 
-        state.modified_event(self, True, old)
+        state.modified_event(dict_, self, True, old)
 
-        old_collection = self.get_collection(state, old)
+        old_collection = self.get_collection(state, dict_, old)
 
-        state.dict[self.key] = user_data
+        dict_[self.key] = user_data
 
         collections.bulk_replace(new_values, old_collection, new_collection)
         old_collection.unlink(old)
 
 
-    def set_committed_value(self, state, value):
+    def set_committed_value(self, state, dict_, value):
         """Set an attribute value on the given instance and 'commit' it."""
 
         collection, user_data = self._initialize_collection(state)
@@ -751,13 +769,13 @@ class CollectionAttributeImpl(AttributeImpl):
         state.callables.pop(self.key, None)
         state.dict[self.key] = user_data
 
-        state.commit([self.key])
+        state.commit(dict_, [self.key])
 
         if self.key in state.pending:
             
             # pending items exist.  issue a modified event,
             # add/remove new items.
-            state.modified_event(self, True, user_data)
+            state.modified_event(dict_, self, True, user_data)
 
             pending = state.pending.pop(self.key)
             added = pending.added_items
@@ -769,14 +787,14 @@ class CollectionAttributeImpl(AttributeImpl):
 
         return user_data
 
-    def get_collection(self, state, user_data=None, passive=PASSIVE_OFF):
+    def get_collection(self, state, dict_, user_data=None, passive=PASSIVE_OFF):
         """Retrieve the CollectionAdapter associated with the given state.
 
         Creates a new CollectionAdapter if one does not exist.
 
         """
         if user_data is None:
-            user_data = self.get(state, passive=passive)
+            user_data = self.get(state, dict_, passive=passive)
             if user_data is PASSIVE_NORESULT:
                 return user_data
 
@@ -799,327 +817,26 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
         if oldchild is not None:
             # With lazy=None, there's no guarantee that the full collection is
             # present when updating via a backref.
-            old_state = instance_state(oldchild)
+            old_state, old_dict = instance_state(oldchild), instance_dict(oldchild)
             impl = old_state.get_impl(self.key)
             try:
-                impl.remove(old_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
+                impl.remove(old_state, old_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
             except (ValueError, KeyError, IndexError):
                 pass
         if child is not None:
-            new_state = instance_state(child)
-            new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
+            new_state,  new_dict = instance_state(child), instance_dict(child)
+            new_state.get_impl(self.key).append(new_state, new_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
         return child
 
     def append(self, state, child, initiator):
-        child_state = instance_state(child)
-        child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
+        child_state, child_dict = instance_state(child), instance_dict(child)
+        child_state.get_impl(self.key).append(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
         return child
 
     def remove(self, state, child, initiator):
         if child is not None:
-            child_state = instance_state(child)
-            child_state.get_impl(self.key).remove(child_state, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
-
-
-class InstanceState(object):
-    """tracks state information at the instance level."""
-
-    session_id = None
-    key = None
-    runid = None
-    expired_attributes = EMPTY_SET
-    load_options = EMPTY_SET
-    load_path = ()
-    insert_order = None
-    
-    def __init__(self, obj, manager):
-        self.class_ = obj.__class__
-        self.manager = manager
-        
-        self.obj = weakref.ref(obj, self._cleanup)
-        self.dict = obj.__dict__
-        self.modified = False
-        self.callables = {}
-        self.expired = False
-        self.committed_state = {}
-        self.pending = {}
-        self.parents = {}
-        
-    def detach(self):
-        if self.session_id:
-            try:
-                del self.session_id
-            except AttributeError:
-                pass
-
-    def dispose(self):
-        if self.session_id:
-            try:
-                del self.session_id
-            except AttributeError:
-                pass
-        del self.obj
-        del self.dict
-    
-    def _cleanup(self, ref):
-        self.dispose()
-    
-    def obj(self):
-        return None
-    
-    @util.memoized_property
-    def dict(self):
-        # return a blank dict
-        # if none is available, so that asynchronous gc
-        # doesn't blow up expiration operations in progress
-        # (usually expire_attributes)
-        return {}
-    
-    @property
-    def sort_key(self):
-        return self.key and self.key[1] or (self.insert_order, )
-
-    def check_modified(self):
-        if self.modified:
-            return True
-        else:
-            for key in self.manager.mutable_attributes:
-                if self.manager[key].impl.check_mutable_modified(self):
-                    return True
-            else:
-                return False
-
-    def initialize_instance(*mixed, **kwargs):
-        self, instance, args = mixed[0], mixed[1], mixed[2:]
-        manager = self.manager
-
-        for fn in manager.events.on_init:
-            fn(self, instance, args, kwargs)
-        try:
-            return manager.events.original_init(*mixed[1:], **kwargs)
-        except:
-            for fn in manager.events.on_init_failure:
-                fn(self, instance, args, kwargs)
-            raise
-
-    def get_history(self, key, **kwargs):
-        return self.manager.get_impl(key).get_history(self, **kwargs)
-
-    def get_impl(self, key):
-        return self.manager.get_impl(key)
-
-    def get_pending(self, key):
-        if key not in self.pending:
-            self.pending[key] = PendingCollection()
-        return self.pending[key]
-
-    def value_as_iterable(self, key, passive=PASSIVE_OFF):
-        """return an InstanceState attribute as a list,
-        regardless of it being a scalar or collection-based
-        attribute.
-
-        returns None if passive is not PASSIVE_OFF and the getter returns
-        PASSIVE_NORESULT.
-        """
-
-        impl = self.get_impl(key)
-        x = impl.get(self, passive=passive)
-        if x is PASSIVE_NORESULT:
-
-            return None
-        elif hasattr(impl, 'get_collection'):
-            return impl.get_collection(self, x, passive=passive)
-        elif isinstance(x, list):
-            return x
-        else:
-            return [x]
-
-    def _run_on_load(self, instance=None):
-        if instance is None:
-            instance = self.obj()
-        self.manager.events.run('on_load', instance)
-
-    def __getstate__(self):
-        return {'key': self.key,
-                'committed_state': self.committed_state,
-                'pending': self.pending,
-                'parents': self.parents,
-                'modified': self.modified,
-                'expired':self.expired,
-                'load_options':self.load_options,
-                'load_path':interfaces.serialize_path(self.load_path),
-                'instance': self.obj(),
-                'expired_attributes':self.expired_attributes,
-                'callables': self.callables}
-
-    def __setstate__(self, state):
-        self.committed_state = state['committed_state']
-        self.parents = state['parents']
-        self.key = state['key']
-        self.session_id = None
-        self.pending = state['pending']
-        self.modified = state['modified']
-        self.obj = weakref.ref(state['instance'])
-        self.load_options = state['load_options'] or EMPTY_SET
-        self.load_path = interfaces.deserialize_path(state['load_path'])
-        self.class_ = self.obj().__class__
-        self.manager = manager_of_class(self.class_)
-        self.dict = self.obj().__dict__
-        self.callables = state['callables']
-        self.runid = None
-        self.expired = state['expired']
-        self.expired_attributes = state['expired_attributes']
-
-    def initialize(self, key):
-        self.manager.get_impl(key).initialize(self)
-
-    def set_callable(self, key, callable_):
-        self.dict.pop(key, None)
-        self.callables[key] = callable_
-
-    def __call__(self):
-        """__call__ allows the InstanceState to act as a deferred
-        callable for loading expired attributes, which is also
-        serializable (picklable).
-
-        """
-        unmodified = self.unmodified
-        class_manager = self.manager
-        class_manager.deferred_scalar_loader(self, [
-            attr.impl.key for attr in class_manager.attributes if
-                attr.impl.accepts_scalar_loader and
-                attr.impl.key in self.expired_attributes and
-                attr.impl.key in unmodified
-            ])
-        for k in self.expired_attributes:
-            self.callables.pop(k, None)
-        del self.expired_attributes
-        return ATTR_WAS_SET
-
-    @property
-    def unmodified(self):
-        """a set of keys which have no uncommitted changes"""
-
-        return set(
-            key for key in self.manager.iterkeys()
-            if (key not in self.committed_state or
-                (key in self.manager.mutable_attributes and
-                 not self.manager[key].impl.check_mutable_modified(self))))
-
-    @property
-    def unloaded(self):
-        """a set of keys which do not have a loaded value.
-
-        This includes expired attributes and any other attribute that
-        was never populated or modified.
-
-        """
-        return set(
-            key for key in self.manager.iterkeys()
-            if key not in self.committed_state and key not in self.dict)
-
-    def expire_attributes(self, attribute_names):
-        self.expired_attributes = set(self.expired_attributes)
-
-        if attribute_names is None:
-            attribute_names = self.manager.keys()
-            self.expired = True
-            self.modified = False
-            filter_deferred = True
-        else:
-            filter_deferred = False
-        for key in attribute_names:
-            impl = self.manager[key].impl
-            if not filter_deferred or \
-                not impl.dont_expire_missing or \
-                key in self.dict:
-                self.expired_attributes.add(key)
-                if impl.accepts_scalar_loader:
-                    self.callables[key] = self
-            self.dict.pop(key, None)
-            self.pending.pop(key, None)
-            self.committed_state.pop(key, None)
-
-    def reset(self, key):
-        """remove the given attribute and any callables associated with it."""
-
-        self.dict.pop(key, None)
-        self.callables.pop(key, None)
-
-    def modified_event(self, attr, should_copy, previous, passive=PASSIVE_OFF):
-        needs_committed = attr.key not in self.committed_state
-
-        if needs_committed:
-            if previous is NEVER_SET:
-                if passive:
-                    if attr.key in self.dict:
-                        previous = self.dict[attr.key]
-                else:
-                    previous = attr.get(self)
-
-            if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
-                previous = attr.copy(previous)
-
-            if needs_committed:
-                self.committed_state[attr.key] = previous
-
-        self.modified = True
-
-    def commit(self, keys):
-        """Commit attributes.
-
-        This is used by a partial-attribute load operation to mark committed
-        those attributes which were refreshed from the database.
-
-        Attributes marked as "expired" can potentially remain "expired" after
-        this step if a value was not populated in state.dict.
-
-        """
-        class_manager = self.manager
-        for key in keys:
-            if key in self.dict and key in class_manager.mutable_attributes:
-                class_manager[key].impl.commit_to_state(self, self.committed_state)
-            else:
-                self.committed_state.pop(key, None)
-
-        self.expired = False
-        # unexpire attributes which have loaded
-        for key in self.expired_attributes.intersection(keys):
-            if key in self.dict:
-                self.expired_attributes.remove(key)
-                self.callables.pop(key, None)
-
-    def commit_all(self):
-        """commit all attributes unconditionally.
-
-        This is used after a flush() or a full load/refresh
-        to remove all pending state from the instance.
-
-         - all attributes are marked as "committed"
-         - the "strong dirty reference" is removed
-         - the "modified" flag is set to False
-         - any "expired" markers/callables are removed.
-
-        Attributes marked as "expired" can potentially remain "expired" after this step
-        if a value was not populated in state.dict.
-
-        """
-        
-        self.committed_state = {}
-        self.pending = {}
-        
-        # unexpire attributes which have loaded
-        if self.expired_attributes:
-            for key in self.expired_attributes.intersection(self.dict):
-                self.callables.pop(key, None)
-            self.expired_attributes.difference_update(self.dict)
-
-        for key in self.manager.mutable_attributes:
-            if key in self.dict:
-                self.manager[key].impl.commit_to_state(self, self.committed_state)
-
-        self.modified = self.expired = False
-        self._strong_obj = None
+            child_state, child_dict = instance_state(child), instance_dict(child)
+            child_state.get_impl(self.key).remove(child_state, child_dict, state.obj(), initiator, passive=PASSIVE_NO_CALLABLES)
 
 
 class Events(object):
@@ -1128,6 +845,7 @@ class Events(object):
         self.on_init = ()
         self.on_init_failure = ()
         self.on_load = ()
+        self.on_resurrect = ()
 
     def run(self, event, *args, **kwargs):
         for fn in getattr(self, event):
@@ -1153,7 +871,6 @@ class ClassManager(dict):
     STATE_ATTR = '_sa_instance_state'
 
     event_registry_factory = Events
-    instance_state_factory = InstanceState
     deferred_scalar_loader = None
     
     def __init__(self, class_):
@@ -1177,7 +894,6 @@ class ClassManager(dict):
     
     def _configure_create_arguments(self, 
                             _source=None, 
-                            instance_state_factory=None, 
                             deferred_scalar_loader=None):
         """Accept extra **kw arguments passed to create_manager_for_cls.
         
@@ -1192,11 +908,8 @@ class ClassManager(dict):
         
         """
         if _source:
-            instance_state_factory = _source.instance_state_factory
             deferred_scalar_loader = _source.deferred_scalar_loader
 
-        if instance_state_factory:
-            self.instance_state_factory = instance_state_factory
         if deferred_scalar_loader:
             self.deferred_scalar_loader = deferred_scalar_loader
     
@@ -1229,7 +942,16 @@ class ClassManager(dict):
         if self.new_init:
             self.uninstall_member('__init__')
             self.new_init = None
-
+    
+    def _create_instance_state(self, instance):
+        global state
+        if state is None:
+            from sqlalchemy.orm import state
+        if self.mutable_attributes:
+            return state.MutableAttrInstanceState(instance, self)
+        else:
+            return state.InstanceState(instance, self)
+        
     def manage(self):
         """Mark this instance as the manager for its class."""
         
@@ -1337,11 +1059,11 @@ class ClassManager(dict):
 
     def new_instance(self, state=None):
         instance = self.class_.__new__(self.class_)
-        setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self))
+        setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance))
         return instance
 
     def setup_instance(self, instance, state=None):
-        setattr(instance, self.STATE_ATTR, state or self.instance_state_factory(instance, self))
+        setattr(instance, self.STATE_ATTR, state or self._create_instance_state(instance))
     
     def teardown_instance(self, instance):
         delattr(instance, self.STATE_ATTR)
@@ -1355,13 +1077,10 @@ class ClassManager(dict):
         if hasattr(instance, self.STATE_ATTR):
             return False
         else:
-            state = self.instance_state_factory(instance, self)
+            state = self._create_instance_state(instance)
             setattr(instance, self.STATE_ATTR, state)
             return state
     
-    def state_of(self, instance):
-        return getattr(instance, self.STATE_ATTR)
-        
     def state_getter(self):
         """Return a (instance) -> InstanceState callable.
 
@@ -1372,6 +1091,9 @@ class ClassManager(dict):
 
         return attrgetter(self.STATE_ATTR)
     
+    def dict_getter(self):
+        return attrgetter('__dict__')
+        
     def has_state(self, instance):
         return hasattr(instance, self.STATE_ATTR)
         
@@ -1392,6 +1114,9 @@ class _ClassInstrumentationAdapter(ClassManager):
 
     def __init__(self, class_, override, **kw):
         self._adapted = override
+        self._get_state = self._adapted.state_getter(class_)
+        self._get_dict = self._adapted.dict_getter(class_)
+        
         ClassManager.__init__(self, class_, **kw)
 
     def manage(self):
@@ -1453,36 +1178,27 @@ class _ClassInstrumentationAdapter(ClassManager):
         self._adapted.initialize_instance_dict(self.class_, instance)
         
         if state is None:
-            state = self.instance_state_factory(instance, self)
+            state = self._create_instance_state(instance)
             
         # the given instance is assumed to have no state
         self._adapted.install_state(self.class_, instance, state)
-        state.dict = self._adapted.get_instance_dict(self.class_, instance)
         return state
 
     def teardown_instance(self, instance):
         self._adapted.remove_state(self.class_, instance)
 
-    def state_of(self, instance):
-        if hasattr(self._adapted, 'state_of'):
-            return self._adapted.state_of(self.class_, instance)
-        else:
-            getter = self._adapted.state_getter(self.class_)
-            return getter(instance)
-
     def has_state(self, instance):
-        if hasattr(self._adapted, 'has_state'):
-            return self._adapted.has_state(self.class_, instance)
-        else:
-            try:
-                state = self.state_of(instance)
-                return True
-            except exc.NO_STATE:
-                return False
+        try:
+            state = self._get_state(instance)
+            return True
+        except exc.NO_STATE:
+            return False
 
     def state_getter(self):
-        return self._adapted.state_getter(self.class_)
+        return self._get_state
 
+    def dict_getter(self):
+        return self._get_dict
 
 class History(tuple):
     """A 3-tuple of added, unchanged and deleted values.
@@ -1527,7 +1243,7 @@ class History(tuple):
         original = state.committed_state.get(attribute.key, NEVER_SET)
 
         if hasattr(attribute, 'get_collection'):
-            current = attribute.get_collection(state, current)
+            current = attribute.get_collection(state, state.dict, current)
             if original is NO_VALUE:
                 return cls(list(current), (), ())
             elif original is NEVER_SET:
@@ -1564,30 +1280,8 @@ class History(tuple):
 
 HISTORY_BLANK = History(None, None, None)
 
-class PendingCollection(object):
-    """A writable placeholder for an unloaded collection.
-
-    Stores items appended to and removed from a collection that has not yet
-    been loaded. When the collection is loaded, the changes stored in
-    PendingCollection are applied to it to produce the final result.
-
-    """
-    def __init__(self):
-        self.deleted_items = util.IdentitySet()
-        self.added_items = util.OrderedIdentitySet()
-
-    def append(self, value):
-        if value in self.deleted_items:
-            self.deleted_items.remove(value)
-        self.added_items.add(value)
-
-    def remove(self, value):
-        if value in self.added_items:
-            self.added_items.remove(value)
-        self.deleted_items.add(value)
-
 def _conditional_instance_state(obj):
-    if not isinstance(obj, InstanceState):
+    if not isinstance(obj, state.InstanceState):
         obj = instance_state(obj)
     return obj
         
@@ -1697,15 +1391,16 @@ def init_collection(obj, key):
     this usage is deprecated.
     
     """
-
-    return init_state_collection(_conditional_instance_state(obj), key)
+    state = _conditional_instance_state(obj)
+    dict_ = state.dict
+    return init_state_collection(state, dict_, key)
     
-def init_state_collection(state, key):
+def init_state_collection(state, dict_, key):
     """Initialize a collection attribute and return the collection adapter."""
     
     attr = state.get_impl(key)
-    user_data = attr.initialize(state)
-    return attr.get_collection(state, user_data)
+    user_data = attr.initialize(state, dict_)
+    return attr.get_collection(state, dict_, user_data)
 
 def set_committed_value(instance, key, value):
     """Set the value of an attribute with no history events.
@@ -1722,8 +1417,8 @@ def set_committed_value(instance, key, value):
     as though it were part of its original loaded state.
     
     """
-    state = instance_state(instance)
-    state.get_impl(key).set_committed_value(instance, key, value)
+    state, dict_ = instance_state(instance), instance_dict(instance)
+    state.get_impl(key).set_committed_value(state, dict_, key, value)
     
 def set_attribute(instance, key, value):
     """Set the value of an attribute, firing history events.
@@ -1735,8 +1430,8 @@ def set_attribute(instance, key, value):
     by SQLAlchemy.
     
     """
-    state = instance_state(instance)
-    state.get_impl(key).set(state, value, None)
+    state, dict_ = instance_state(instance), instance_dict(instance)
+    state.get_impl(key).set(state, dict_, value, None)
 
 def get_attribute(instance, key):
     """Get the value of an attribute, firing any callables required.
@@ -1748,8 +1443,8 @@ def get_attribute(instance, key):
     by SQLAlchemy.
     
     """
-    state = instance_state(instance)
-    return state.get_impl(key).get(state)
+    state, dict_ = instance_state(instance), instance_dict(instance)
+    return state.get_impl(key).get(state, dict_)
 
 def del_attribute(instance, key):
     """Delete the value of an attribute, firing history events.
@@ -1761,8 +1456,8 @@ def del_attribute(instance, key):
     by SQLAlchemy.
     
     """
-    state = instance_state(instance)
-    state.get_impl(key).delete(state)
+    state, dict_ = instance_state(instance), instance_dict(instance)
+    state.get_impl(key).delete(state, dict_)
 
 def is_instrumented(instance, key):
     """Return True if the given attribute on the given instance is instrumented
@@ -1779,6 +1474,7 @@ class InstrumentationRegistry(object):
 
     _manager_finders = weakref.WeakKeyDictionary()
     _state_finders = util.WeakIdentityMapping()
+    _dict_finders = util.WeakIdentityMapping()
     _extended = False
 
     def create_manager_for_cls(self, class_, **kw):
@@ -1813,6 +1509,7 @@ class InstrumentationRegistry(object):
         manager.factory = factory
         self._manager_finders[class_] = manager.manager_getter()
         self._state_finders[class_] = manager.state_getter()
+        self._dict_finders[class_] = manager.dict_getter()
         return manager
 
     def _collect_management_factories_for(self, cls):
@@ -1852,6 +1549,7 @@ class InstrumentationRegistry(object):
             return finder(cls)
 
     def state_of(self, instance):
+        # this is only called when alternate instrumentation has been established
         if instance is None:
             raise AttributeError("None has no persistent state.")
         try:
@@ -1859,21 +1557,15 @@ class InstrumentationRegistry(object):
         except KeyError:
             raise AttributeError("%r is not instrumented" % instance.__class__)
 
-    def state_or_default(self, instance, default=None):
+    def dict_of(self, instance):
+        # this is only called when alternate instrumentation has been established
         if instance is None:
-            return default
+            raise AttributeError("None has no persistent state.")
         try:
-            finder = self._state_finders[instance.__class__]
+            return self._dict_finders[instance.__class__](instance)
         except KeyError:
-            return default
-        else:
-            try:
-                return finder(instance)
-            except exc.NO_STATE:
-                return default
-            except:
-                raise
-
+            raise AttributeError("%r is not instrumented" % instance.__class__)
+        
     def unregister(self, class_):
         if class_ in self._manager_finders:
             manager = self.manager_of_class(class_)
@@ -1881,6 +1573,7 @@ class InstrumentationRegistry(object):
             manager.dispose()
             del self._manager_finders[class_]
             del self._state_finders[class_]
+            del self._dict_finders[class_]
 
 instrumentation_registry = InstrumentationRegistry()
 
@@ -1894,12 +1587,14 @@ def _install_lookup_strategy(implementation):
     and unit tests specific to this behavior.
     
     """
-    global instance_state
+    global instance_state, instance_dict
     if implementation is util.symbol('native'):
         instance_state = attrgetter(ClassManager.STATE_ATTR)
+        instance_dict = attrgetter("__dict__")
     else:
         instance_state = instrumentation_registry.state_of
-    
+        instance_dict = instrumentation_registry.dict_of
+        
 manager_of_class = instrumentation_registry.manager_of_class
 _create_manager_for_cls = instrumentation_registry.create_manager_for_cls
 _install_lookup_strategy(util.symbol('native'))
index 5903d349276ffe4aba0addc932b64c00ac0b019a..b865c11f46a967cc398330258c529cfe44b6f646 100644 (file)
@@ -472,6 +472,7 @@ class CollectionAdapter(object):
     """
     def __init__(self, attr, owner_state, data):
         self.attr = attr
+        # TODO: figure out what this being a weakref buys us
         self._data = weakref.ref(data)
         self.owner_state = owner_state
         self.link_to_self(data)
@@ -578,7 +579,7 @@ class CollectionAdapter(object):
 
         """
         if initiator is not False and item is not None:
-            return self.attr.fire_append_event(self.owner_state, item, initiator)
+            return self.attr.fire_append_event(self.owner_state, self.owner_state.dict, item, initiator)
         else:
             return item
 
@@ -591,7 +592,7 @@ class CollectionAdapter(object):
 
         """
         if initiator is not False and item is not None:
-            self.attr.fire_remove_event(self.owner_state, item, initiator)
+            self.attr.fire_remove_event(self.owner_state, self.owner_state.dict, item, initiator)
 
     def fire_pre_remove_event(self, initiator=None):
         """Notify that an entity is about to be removed from the collection.
@@ -600,7 +601,7 @@ class CollectionAdapter(object):
         fire_remove_event().
 
         """
-        self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator)
+        self.attr.fire_pre_remove_event(self.owner_state, self.owner_state.dict, initiator=initiator)
 
     def __getstate__(self):
         return {'key': self.attr.key,
index c4ba7852f9b003766e310522cba5d01cbc6bf2d4..f3820eb7cdae0b113584aa00f584d227ef1d2488 100644 (file)
@@ -64,17 +64,21 @@ class DependencyProcessor(object):
     def register_dependencies(self, uowcommit):
         """Tell a ``UOWTransaction`` what mappers are dependent on
         which, with regards to the two or three mappers handled by
-        this ``PropertyLoader``.
+        this ``DependencyProcessor``.
 
-        Also register itself as a *processor* for one of its mappers,
-        which will be executed after that mapper's objects have been
-        saved or before they've been deleted.  The process operation
-        manages attributes and dependent operations upon the objects
-        of one of the involved mappers.
         """
 
         raise NotImplementedError()
 
+    def register_processors(self, uowcommit):
+        """Tell a ``UOWTransaction`` about this object as a processor,
+        which will be executed after that mapper's objects have been
+        saved or before they've been deleted.  The process operation
+        manages attributes and dependent operations between two mappers.
+        
+        """
+        raise NotImplementedError()
+        
     def whose_dependent_on_who(self, state1, state2):
         """Given an object pair assuming `obj2` is a child of `obj1`,
         return a tuple with the dependent object second, or None if
@@ -181,9 +185,13 @@ class OneToManyDP(DependencyProcessor):
         if self.post_update:
             uowcommit.register_dependency(self.mapper, self.dependency_marker)
             uowcommit.register_dependency(self.parent, self.dependency_marker)
-            uowcommit.register_processor(self.dependency_marker, self, self.parent)
         else:
             uowcommit.register_dependency(self.parent, self.mapper)
+
+    def register_processors(self, uowcommit):
+        if self.post_update:
+            uowcommit.register_processor(self.dependency_marker, self, self.parent)
+        else:
             uowcommit.register_processor(self.parent, self, self.parent)
 
     def process_dependencies(self, task, deplist, uowcommit, delete = False):
@@ -257,11 +265,13 @@ class OneToManyDP(DependencyProcessor):
                                 uowcommit.register_object(
                                     attributes.instance_state(c),
                                     isdelete=True)
-                if not self.passive_updates and self._pks_changed(uowcommit, state):
+                if self._pks_changed(uowcommit, state):
                     if not history:
-                        history = uowcommit.get_attribute_history(state, self.key, passive=False)
-                    for child in history.unchanged:
-                        uowcommit.register_object(child)
+                        history = uowcommit.get_attribute_history(state, self.key, passive=self.passive_updates)
+                    if history:
+                        for child in history.unchanged:
+                            if child is not None:
+                                uowcommit.register_object(child)
 
     def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
         source = state
@@ -275,7 +285,7 @@ class OneToManyDP(DependencyProcessor):
             sync.populate(source, self.parent, dest, self.mapper, self.prop.synchronize_pairs)
 
     def _pks_changed(self, uowcommit, state):
-        return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
+        return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
 
 class DetectKeySwitch(DependencyProcessor):
     """a special DP that works for many-to-one relations, fires off for
@@ -284,6 +294,9 @@ class DetectKeySwitch(DependencyProcessor):
     no_dependencies = True
 
     def register_dependencies(self, uowcommit):
+        pass
+
+    def register_processors(self, uowcommit):
         uowcommit.register_processor(self.parent, self, self.mapper)
 
     def preprocess_dependencies(self, task, deplist, uowcommit, delete=False):
@@ -314,11 +327,11 @@ class DetectKeySwitch(DependencyProcessor):
                     elem.dict[self.key] is not None and 
                     attributes.instance_state(elem.dict[self.key]) in switchers
                 ]:
-                uowcommit.register_object(s, listonly=self.passive_updates)
+                uowcommit.register_object(s)
                 sync.populate(attributes.instance_state(s.dict[self.key]), self.mapper, s, self.parent, self.prop.synchronize_pairs)
 
     def _pks_changed(self, uowcommit, state):
-        return sync.source_changes(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
+        return sync.source_modified(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
 
 class ManyToOneDP(DependencyProcessor):
     def __init__(self, prop):
@@ -329,12 +342,15 @@ class ManyToOneDP(DependencyProcessor):
         if self.post_update:
             uowcommit.register_dependency(self.mapper, self.dependency_marker)
             uowcommit.register_dependency(self.parent, self.dependency_marker)
-            uowcommit.register_processor(self.dependency_marker, self, self.parent)
         else:
             uowcommit.register_dependency(self.mapper, self.parent)
+    
+    def register_processors(self, uowcommit):
+        if self.post_update:
+            uowcommit.register_processor(self.dependency_marker, self, self.parent)
+        else:
             uowcommit.register_processor(self.mapper, self, self.parent)
 
-
     def process_dependencies(self, task, deplist, uowcommit, delete=False):
         if delete:
             if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all':
@@ -407,8 +423,10 @@ class ManyToManyDP(DependencyProcessor):
 
         uowcommit.register_dependency(self.parent, self.dependency_marker)
         uowcommit.register_dependency(self.mapper, self.dependency_marker)
-        uowcommit.register_processor(self.dependency_marker, self, self.parent)
 
+    def register_processors(self, uowcommit):
+        uowcommit.register_processor(self.dependency_marker, self, self.parent)
+        
     def process_dependencies(self, task, deplist, uowcommit, delete = False):
         connection = uowcommit.transaction.connection(self.mapper)
         secondary_delete = []
@@ -502,7 +520,7 @@ class ManyToManyDP(DependencyProcessor):
         sync.populate_dict(child, self.mapper, associationrow, self.prop.secondary_synchronize_pairs)
 
     def _pks_changed(self, uowcommit, state):
-        return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
+        return sync.source_modified(uowcommit, state, self.parent, self.prop.synchronize_pairs)
 
 class MapperStub(object):
     """Represent a many-to-many dependency within a flush 
@@ -526,6 +544,9 @@ class MapperStub(object):
     def _register_dependencies(self, uowcommit):
         pass
 
+    def _register_processors(self, uowcommit):
+        pass
+
     def _save_obj(self, *args, **kwargs):
         pass
 
index 3d31a686a2e5f75f3cb039682804b414e92d6007..70243291dc3e279bdae383bbe02b7ae1b157f213 100644 (file)
@@ -55,21 +55,21 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         else:
             self.query_class = mixin_user_query(query_class)
 
-    def get(self, state, passive=False):
+    def get(self, state, dict_, passive=False):
         if passive:
             return self._get_collection_history(state, passive=True).added_items
         else:
             return self.query_class(self, state)
 
-    def get_collection(self, state, user_data=None, passive=True):
+    def get_collection(self, state, dict_, user_data=None, passive=True):
         if passive:
             return self._get_collection_history(state, passive=passive).added_items
         else:
             history = self._get_collection_history(state, passive=passive)
             return history.added_items + history.unchanged_items
 
-    def fire_append_event(self, state, value, initiator):
-        collection_history = self._modified_event(state)
+    def fire_append_event(self, state, dict_, value, initiator):
+        collection_history = self._modified_event(state, dict_)
         collection_history.added_items.append(value)
 
         for ext in self.extensions:
@@ -78,8 +78,8 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         if self.trackparent and value is not None:
             self.sethasparent(attributes.instance_state(value), True)
 
-    def fire_remove_event(self, state, value, initiator):
-        collection_history = self._modified_event(state)
+    def fire_remove_event(self, state, dict_, value, initiator):
+        collection_history = self._modified_event(state, dict_)
         collection_history.deleted_items.append(value)
 
         if self.trackparent and value is not None:
@@ -88,31 +88,31 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         for ext in self.extensions:
             ext.remove(state, value, initiator or self)
 
-    def _modified_event(self, state):
+    def _modified_event(self, state, dict_):
 
         if self.key not in state.committed_state:
             state.committed_state[self.key] = CollectionHistory(self, state)
 
-        state.modified_event(self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE)
+        state.modified_event(dict_, self, False, attributes.NEVER_SET, passive=attributes.PASSIVE_NO_INITIALIZE)
 
         # this is a hack to allow the _base.ComparableEntity fixture
         # to work
-        state.dict[self.key] = True
+        dict_[self.key] = True
         return state.committed_state[self.key]
 
-    def set(self, state, value, initiator):
+    def set(self, state, dict_, value, initiator):
         if initiator is self:
             return
 
-        self._set_iterable(state, value)
+        self._set_iterable(state, dict_, value)
 
-    def _set_iterable(self, state, iterable, adapter=None):
+    def _set_iterable(self, state, dict_, iterable, adapter=None):
 
-        collection_history = self._modified_event(state)
+        collection_history = self._modified_event(state, dict_)
         new_values = list(iterable)
 
         if _state_has_identity(state):
-            old_collection = list(self.get(state))
+            old_collection = list(self.get(state, dict_))
         else:
             old_collection = []
 
@@ -121,7 +121,7 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
     def delete(self, *args, **kwargs):
         raise NotImplementedError()
 
-    def get_history(self, state, passive=False):
+    def get_history(self, state, dict_, passive=False):
         c = self._get_collection_history(state, passive)
         return attributes.History(c.added_items, c.unchanged_items, c.deleted_items)
 
@@ -136,13 +136,13 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         else:
             return c
 
-    def append(self, state, value, initiator, passive=False):
+    def append(self, state, dict_, value, initiator, passive=False):
         if initiator is not self:
-            self.fire_append_event(state, value, initiator)
+            self.fire_append_event(state, dict_, value, initiator)
 
-    def remove(self, state, value, initiator, passive=False):
+    def remove(self, state, dict_, value, initiator, passive=False):
         if initiator is not self:
-            self.fire_remove_event(state, value, initiator)
+            self.fire_remove_event(state, dict_, value, initiator)
 
 class DynCollectionAdapter(object):
     """the dynamic analogue to orm.collections.CollectionAdapter"""
@@ -156,10 +156,10 @@ class DynCollectionAdapter(object):
         return iter(self.data)
 
     def append_with_event(self, item, initiator=None):
-        self.attr.append(self.state, item, initiator)
+        self.attr.append(self.state, self.state.dict, item, initiator)
 
     def remove_with_event(self, item, initiator=None):
-        self.attr.remove(self.state, item, initiator)
+        self.attr.remove(self.state, self.state.dict, item, initiator)
 
     def append_without_event(self, item):
         pass
@@ -240,10 +240,10 @@ class AppenderMixin(object):
         return query
 
     def append(self, item):
-        self.attr.append(attributes.instance_state(self.instance), item, None)
+        self.attr.append(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None)
 
     def remove(self, item):
-        self.attr.remove(attributes.instance_state(self.instance), item, None)
+        self.attr.remove(attributes.instance_state(self.instance), attributes.instance_dict(self.instance), item, None)
 
 
 class AppenderQuery(AppenderMixin, Query):
index 0829d18015ee9af2e863989fb2bf068ffb0f319b..71527c686dd1046f0fa5da7deefef0f4b3b0901b 100644 (file)
@@ -12,9 +12,12 @@ from sqlalchemy.orm import attributes
 
 class IdentityMap(dict):
     def __init__(self):
-        self._mutable_attrs = {}
-        self.modified = False
+        self._mutable_attrs = set()
+        self._modified = set()
         self._wr = weakref.ref(self)
+
+    def replace(self, state):
+        raise NotImplementedError()
         
     def add(self, state):
         raise NotImplementedError()
@@ -31,28 +34,29 @@ class IdentityMap(dict):
     def _manage_incoming_state(self, state):
         state._instance_dict = self._wr
         
-        if state.modified:  
-            self.modified = True
+        if state.modified:
+            self._modified.add(state)  
         if state.manager.mutable_attributes:
-            self._mutable_attrs[state] = True
+            self._mutable_attrs.add(state)
     
     def _manage_removed_state(self, state):
         del state._instance_dict
+        self._mutable_attrs.discard(state)
+        self._modified.discard(state)
+    
+    def _dirty_states(self):
+        return self._modified.union(s for s in self._mutable_attrs if s.modified)
         
-        if state in self._mutable_attrs:
-            del self._mutable_attrs[state]
-            
     def check_modified(self):
         """return True if any InstanceStates present have been marked as 'modified'."""
         
-        if not self.modified:
-            for state in list(self._mutable_attrs):
-                if state.check_modified():
-                    return True
-            else:
-                return False
-        else:
+        if self._modified:
             return True
+        else:
+            for state in self._mutable_attrs:
+                if state.modified:
+                    return True
+        return False
             
     def has_key(self, key):
         return key in self
@@ -102,6 +106,17 @@ class WeakInstanceDict(IdentityMap):
     def contains_state(self, state):
         return dict.get(self, state.key) is state
         
+    def replace(self, state):
+        if dict.__contains__(self, state.key):
+            existing = dict.__getitem__(self, state.key)
+            if existing is not state:
+                self._manage_removed_state(existing)
+            else:
+                return
+                
+        dict.__setitem__(self, state.key, state)
+        self._manage_incoming_state(state)
+                 
     def add(self, state):
         if state.key in self:
             if dict.__getitem__(self, state.key) is not state:
@@ -176,12 +191,24 @@ class StrongInstanceDict(IdentityMap):
     def contains_state(self, state):
         return state.key in self and attributes.instance_state(self[state.key]) is state
     
+    def replace(self, state):
+        if dict.__contains__(self, state.key):
+            existing = dict.__getitem__(self, state.key)
+            existing = attributes.instance_state(existing)
+            if existing is not state:
+                self._manage_removed_state(existing)
+            else:
+                return
+
+        dict.__setitem__(self, state.key, state.obj())
+        self._manage_incoming_state(state)
+        
     def add(self, state):
         dict.__setitem__(self, state.key, state.obj())
         self._manage_incoming_state(state)
     
     def remove(self, state):
-        if dict.pop(self, state.key) is not state:
+        if attributes.instance_state(dict.pop(self, state.key)) is not state:
             raise AssertionError("State %s is not present in this identity map" % state)
         self._manage_removed_state(state)
     
@@ -191,7 +218,7 @@ class StrongInstanceDict(IdentityMap):
             self._manage_removed_state(state)
             
     def remove_key(self, key):
-        state = dict.__getitem__(self, key)
+        state = attributes.instance_state(dict.__getitem__(self, key))
         self.remove(state)
 
     def prune(self):
@@ -205,62 +232,3 @@ class StrongInstanceDict(IdentityMap):
         self.modified = bool(dirty)
         return ref_count - len(self)
         
-class IdentityManagedState(attributes.InstanceState):
-    def _instance_dict(self):
-        return None
-    
-    def modified_event(self, attr, should_copy, previous, passive=False):
-        attributes.InstanceState.modified_event(self, attr, should_copy, previous, passive)
-        
-        instance_dict = self._instance_dict()
-        if instance_dict:
-            instance_dict.modified = True
-    
-    def _is_really_none(self):
-        """do a check modified/resurrect.
-        
-        This would be called in the extremely rare
-        race condition that the weakref returned None but
-        the cleanup handler had not yet established the 
-        __resurrect callable as its replacement.
-        
-        """
-        if self.check_modified():
-            self.obj = self.__resurrect
-            return self.obj()
-        else:
-            return None
-            
-    def _cleanup(self, ref):
-        """weakref callback.
-        
-        This method may be called by an asynchronous
-        gc.
-        
-        If the state shows pending changes, the weakref
-        is replaced by the __resurrect callable which will
-        re-establish an object reference on next access,
-        else removes this InstanceState from the owning
-        identity map, if any.
-        
-        """
-        if self.check_modified():
-            self.obj = self.__resurrect
-        else:
-            instance_dict = self._instance_dict()
-            if instance_dict:
-                instance_dict.remove(self)
-            self.dispose()
-            
-    def __resurrect(self):
-        """A substitute for the obj() weakref function which resurrects."""
-        
-        # store strong ref'ed version of the object; will revert
-        # to weakref when changes are persisted
-        obj = self.manager.new_instance(state=self)
-        self.obj = weakref.ref(obj, self._cleanup)
-        self._strong_obj = obj
-        obj.__dict__.update(self.dict)
-        self.dict = obj.__dict__
-        self._run_on_load(obj)
-        return obj
index d36f51194e09aec35cd6bd08490998fca4e06f66..0ac771305833137686b1ac126a43cc4023c4a445 100644 (file)
@@ -359,7 +359,7 @@ class MapperProperty(object):
 
         Callables are of the following form::
 
-            def new_execute(state, row, **flags):
+            def new_execute(state, dict_, row, **flags):
                 # process incoming instance state and given row.  the instance is
                 # "new" and was just created upon receipt of this row.
                 # flags is a dictionary containing at least the following
@@ -368,7 +368,7 @@ class MapperProperty(object):
                 #           result of reading this row
                 #   instancekey - identity key of the instance
 
-            def existing_execute(state, row, **flags):
+            def existing_execute(state, dict_, row, **flags):
                 # process incoming instance state and given row.  the instance is
                 # "existing" and was created based on a previous row.
 
@@ -427,13 +427,23 @@ class MapperProperty(object):
     def register_dependencies(self, *args, **kwargs):
         """Called by the ``Mapper`` in response to the UnitOfWork
         calling the ``Mapper``'s register_dependencies operation.
-        Should register with the UnitOfWork all inter-mapper
-        dependencies as well as dependency processors (see UOW docs
-        for more details).
+        Establishes a topological dependency between two mappers
+        which will affect the order in which mappers persist data.
+        
         """
 
         pass
 
+    def register_processors(self, *args, **kwargs):
+        """Called by the ``Mapper`` in response to the UnitOfWork
+        calling the ``Mapper``'s register_processors operation.
+        Establishes a processor object between two mappers which
+        will link data and state between parent/child objects.
+        
+        """
+
+        pass
+        
     def is_primary(self):
         """Return True if this ``MapperProperty``'s mapper is the
         primary mapper for its class.
@@ -939,3 +949,7 @@ class InstrumentationManager(object):
 
     def state_getter(self, class_):
         return lambda instance: getattr(instance, '_default_state')
+
+    def dict_getter(self, class_):
+        return lambda inst: self.get_instance_dict(class_, inst)
+        
\ No newline at end of file
index e5dbb4d03978734a61ff3f93a91a202132b703fb..1502060f02577a14eb24cb2ff95316a708352e61 100644 (file)
@@ -23,7 +23,6 @@ deque = __import__('collections').deque
 from sqlalchemy import sql, util, log, exc as sa_exc
 from sqlalchemy.sql import expression, visitors, operators, util as sqlutil
 from sqlalchemy.orm import attributes, exc, sync
-from sqlalchemy.orm.identity import IdentityManagedState
 from sqlalchemy.orm.interfaces import (
     MapperProperty, EXT_CONTINUE, PropComparator
     )
@@ -255,7 +254,8 @@ class Mapper(object):
 
             for mapper in self.iterate_to_root():
                 util.reset_memoized(mapper, '_equivalent_columns')
-
+                util.reset_memoized(mapper, '_sorted_tables')
+                
             if self.order_by is False and not self.concrete and self.inherits.order_by is not False:
                 self.order_by = self.inherits.order_by
 
@@ -357,7 +357,6 @@ class Mapper(object):
 
         if manager is None:
             manager = attributes.register_class(self.class_, 
-                instance_state_factory = IdentityManagedState,
                 deferred_scalar_loader = _load_scalar_attributes
             )
 
@@ -372,6 +371,8 @@ class Mapper(object):
         event_registry = manager.events
         event_registry.add_listener('on_init', _event_on_init)
         event_registry.add_listener('on_init_failure', _event_on_init_failure)
+        event_registry.add_listener('on_resurrect', _event_on_resurrect)
+        
         for key, method in util.iterate_attributes(self.class_):
             if isinstance(method, types.FunctionType):
                 if hasattr(method, '__sa_reconstructor__'):
@@ -682,7 +683,7 @@ class Mapper(object):
         for key, prop in l:
             self._log("initialize prop " + key)
             
-            if not prop._compile_started:
+            if prop.parent is self and not prop._compile_started:
                 prop.init()
             
             if prop._compile_finished:
@@ -1173,6 +1174,19 @@ class Mapper(object):
 
     # persistence
 
+    @util.memoized_property
+    def _sorted_tables(self):
+        table_to_mapper = {}
+        for mapper in self.base_mapper.polymorphic_iterator():
+            for t in mapper.tables:
+                table_to_mapper[t] = mapper
+        
+        sorted_ = sqlutil.sort_tables(table_to_mapper.iterkeys())
+        ret = util.OrderedDict()
+        for t in sorted_:
+            ret[t] = table_to_mapper[t]
+        return ret
+
     def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False):
         """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
 
@@ -1198,16 +1212,37 @@ class Mapper(object):
 
         # if session has a connection callable,
         # organize individual states with the connection to use for insert/update
+        tups = []
         if 'connection_callable' in uowtransaction.mapper_flush_opts:
             connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
-            tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in _sort_states(states)]
+            for state in _sort_states(states):
+                m = _state_mapper(state)
+                tups.append(
+                    (
+                        state, 
+                        m, 
+                        connection_callable(self, state.obj()), 
+                        _state_has_identity(state), 
+                        state.key or m._identity_key_from_state(state)
+                    )
+                )
         else:
             connection = uowtransaction.transaction.connection(self)
-            tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in _sort_states(states)]
+            for state in _sort_states(states):
+                m = _state_mapper(state)
+                tups.append(
+                    (
+                        state, 
+                        m, 
+                        connection,
+                        _state_has_identity(state), 
+                        state.key or m._identity_key_from_state(state)
+                    )
+                )
 
         if not postupdate:
             # call before_XXX extensions
-            for state, mapper, connection, has_identity in tups:
+            for state, mapper, connection, has_identity, instance_key in tups:
                 if not has_identity:
                     if 'before_insert' in mapper.extension:
                         mapper.extension.before_insert(mapper, connection, state.obj())
@@ -1215,39 +1250,44 @@ class Mapper(object):
                     if 'before_update' in mapper.extension:
                         mapper.extension.before_update(mapper, connection, state.obj())
 
-        for state, mapper, connection, has_identity in tups:
-            # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
-            # and another instance with the same identity key already exists as persistent.  convert to an
-            # UPDATE if so.
-            instance_key = mapper._identity_key_from_state(state)
-            if not postupdate and not has_identity and instance_key in uowtransaction.session.identity_map:
-                instance = uowtransaction.session.identity_map[instance_key]
-                existing = attributes.instance_state(instance)
-                if not uowtransaction.is_deleted(existing):
-                    raise exc.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing)))
-                if self._should_log_debug:
-                    self._log_debug("detected row switch for identity %s.  will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing)))
-                uowtransaction.set_row_switch(existing)
-
-        table_to_mapper = {}
-        for mapper in self.base_mapper.polymorphic_iterator():
-            for t in mapper.tables:
-                table_to_mapper[t] = mapper
+        row_switches = set()
+        if not postupdate:
+            for state, mapper, connection, has_identity, instance_key in tups:
+                # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
+                # and another instance with the same identity key already exists as persistent.  convert to an
+                # UPDATE if so.
+                if not has_identity and instance_key in uowtransaction.session.identity_map:
+                    instance = uowtransaction.session.identity_map[instance_key]
+                    existing = attributes.instance_state(instance)
+                    if not uowtransaction.is_deleted(existing):
+                        raise exc.FlushError(
+                            "New instance %s with identity key %s conflicts with persistent instance %s" % 
+                            (state_str(state), instance_key, state_str(existing)))
+                    if self._should_log_debug:
+                        self._log_debug(
+                            "detected row switch for identity %s.  will update %s, remove %s from transaction", 
+                            instance_key, state_str(state), state_str(existing))
+                            
+                    # remove the "delete" flag from the existing element
+                    uowtransaction.set_row_switch(existing)
+                    row_switches.add(state)
+        
+        table_to_mapper = self._sorted_tables
 
-        for table in sqlutil.sort_tables(table_to_mapper.iterkeys()):
+        for table in table_to_mapper.iterkeys():
             insert = []
             update = []
 
-            for state, mapper, connection, has_identity in tups:
+            for state, mapper, connection, has_identity, instance_key in tups:
                 if table not in mapper._pks_by_table:
                     continue
+                    
                 pks = mapper._pks_by_table[table]
-                instance_key = mapper._identity_key_from_state(state)
-
+                
                 if self._should_log_debug:
                     self._log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key)))
 
-                isinsert = not instance_key in uowtransaction.session.identity_map and not postupdate and not has_identity
+                isinsert = not has_identity and not postupdate and state not in row_switches
                 
                 params = {}
                 value_params = {}
@@ -1257,10 +1297,6 @@ class Mapper(object):
                     for col in mapper._cols_by_table[table]:
                         if col is mapper.version_id_col:
                             params[col.key] = 1
-                        elif col in pks:
-                            value = mapper._get_state_attr_by_column(state, col)
-                            if value is not None:
-                                params[col.key] = value
                         elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col):
                             if self._should_log_debug:
                                 self._log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key))
@@ -1269,6 +1305,10 @@ class Mapper(object):
                                  col.server_default is None) or
                                 value is not None):
                                 params[col.key] = value
+                        elif col in pks:
+                            value = mapper._get_state_attr_by_column(state, col)
+                            if value is not None:
+                                params[col.key] = value
                         else:
                             value = mapper._get_state_attr_by_column(state, col)
                             if ((col.default is None and
@@ -1364,7 +1404,7 @@ class Mapper(object):
                             sync.populate(state, m, state, m, m._inherits_equated_pairs)
 
         if not postupdate:
-            for state, mapper, connection, has_identity in tups:
+            for state, mapper, connection, has_identity, instance_key in tups:
 
                 # expire readonly attributes
                 readonly = state.unmodified.intersection(
@@ -1434,12 +1474,9 @@ class Mapper(object):
             if 'before_delete' in mapper.extension:
                 mapper.extension.before_delete(mapper, connection, state.obj())
 
-        table_to_mapper = {}
-        for mapper in self.base_mapper.polymorphic_iterator():
-            for t in mapper.tables:
-                table_to_mapper[t] = mapper
+        table_to_mapper = self._sorted_tables
 
-        for table in reversed(sqlutil.sort_tables(table_to_mapper.iterkeys())):
+        for table in reversed(table_to_mapper.keys()):
             delete = {}
             for state, mapper, connection in tups:
                 if table not in mapper._pks_by_table:
@@ -1485,6 +1522,10 @@ class Mapper(object):
         for dep in self._props.values() + self._dependency_processors:
             dep.register_dependencies(uowcommit)
 
+    def _register_processors(self, uowcommit):
+        for dep in self._props.values() + self._dependency_processors:
+            dep.register_processors(uowcommit)
+
     # result set conversion
 
     def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_state=None, polymorphic_discriminator=None):
@@ -1514,7 +1555,7 @@ class Mapper(object):
         new_populators = []
         existing_populators = []
 
-        def populate_state(state, row, isnew, only_load_props, **flags):
+        def populate_state(state, dict_, row, isnew, only_load_props, **flags):
             if isnew:
                 if context.options:
                     state.load_options = context.options
@@ -1533,7 +1580,7 @@ class Mapper(object):
                 populators = [p for p in populators if p[0] in only_load_props]
 
             for key, populator in populators:
-                populator(state, row, isnew=isnew, **flags)
+                populator(state, dict_, row, isnew=isnew, **flags)
 
         session_identity_map = context.session.identity_map
 
@@ -1573,9 +1620,11 @@ class Mapper(object):
             if identitykey in session_identity_map:
                 instance = session_identity_map[identitykey]
                 state = attributes.instance_state(instance)
+                dict_ = attributes.instance_dict(instance)
 
                 if self._should_log_debug:
-                    self._log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), identitykey))
+                    self._log_debug("_instance(): using existing instance %s identity %s",
+                                        instance_str(instance), identitykey)
 
                 isnew = state.runid != context.runid
                 currentload = not isnew
@@ -1592,12 +1641,13 @@ class Mapper(object):
                 # when eager_defaults is True.
                 state = refresh_state
                 instance = state.obj()
+                dict_ = attributes.instance_dict(instance)
                 isnew = state.runid != context.runid
                 currentload = True
                 loaded_instance = False
             else:
                 if self._should_log_debug:
-                    self._log_debug("_instance(): identity key %s not in session" % str(identitykey))
+                    self._log_debug("_instance(): identity key %s not in session", identitykey)
 
                 if self.allow_null_pks:
                     for x in identitykey[1]:
@@ -1625,8 +1675,10 @@ class Mapper(object):
                     instance = self.class_manager.new_instance()
 
                 if self._should_log_debug:
-                    self._log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey)))
+                    self._log_debug("_instance(): created new instance %s identity %s",
+                                instance_str(instance), identitykey)
 
+                dict_ = attributes.instance_dict(instance)
                 state = attributes.instance_state(instance)
                 state.key = identitykey
 
@@ -1638,12 +1690,12 @@ class Mapper(object):
             if currentload or populate_existing:
                 if isnew:
                     state.runid = context.runid
-                    context.progress.add(state)
+                    context.progress[state] = dict_
 
                 if not populate_instance or \
                         populate_instance(self, context, row, instance, 
                             only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
-                    populate_state(state, row, isnew, only_load_props)
+                    populate_state(state, dict_, row, isnew, only_load_props)
 
             else:
                 # populate attributes on non-loading instances which have been expired
@@ -1652,16 +1704,16 @@ class Mapper(object):
 
                     if state in context.partials:
                         isnew = False
-                        attrs = context.partials[state]
+                        (d_, attrs) = context.partials[state]
                     else:
                         isnew = True
                         attrs = state.unloaded
-                        context.partials[state] = attrs  #<-- allow query.instances to commit the subset of attrs
+                        context.partials[state] = (dict_, attrs)  #<-- allow query.instances to commit the subset of attrs
 
                     if not populate_instance or \
                             populate_instance(self, context, row, instance, 
                                 only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
-                        populate_state(state, row, isnew, attrs, instancekey=identitykey)
+                        populate_state(state, dict_, row, isnew, attrs, instancekey=identitykey)
 
             if loaded_instance:
                 state._run_on_load(instance)
@@ -1759,6 +1811,14 @@ def _event_on_init_failure(state, instance, args, kwargs):
             instrumenting_mapper, instrumenting_mapper.class_,
             state.manager.events.original_init, instance, args, kwargs)
 
+def _event_on_resurrect(state, instance):
+    # re-populate the primary key elements
+    # of the dict based on the mapping.
+    instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
+    for col, val in zip(instrumenting_mapper.primary_key, state.key[1]):
+        instrumenting_mapper._set_state_attr_by_column(state, col, val)
+    
+    
 def _sort_states(states):
     return sorted(states, key=operator.attrgetter('sort_key'))
 
index 398cbe5d989a8a2c9447189e22d012619b1754bb..5605cdcd1e83fbbf1dc31a4f00a11d32603b068a 100644 (file)
@@ -96,13 +96,13 @@ class ColumnProperty(StrategizedProperty):
         return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
 
     def getattr(self, state, column):
-        return state.get_impl(self.key).get(state)
+        return state.get_impl(self.key).get(state, state.dict)
 
     def getcommitted(self, state, column, passive=False):
-        return state.get_impl(self.key).get_committed_value(state, passive=passive)
+        return state.get_impl(self.key).get_committed_value(state, state.dict, passive=passive)
 
     def setattr(self, state, value, column):
-        state.get_impl(self.key).set(state, value, None)
+        state.get_impl(self.key).set(state, state.dict, value, None)
 
     def merge(self, session, source, dest, dont_load, _recursive):
         value = attributes.instance_state(source).value_as_iterable(
@@ -159,7 +159,7 @@ class CompositeProperty(ColumnProperty):
         super(ColumnProperty, self).do_init()
 
     def getattr(self, state, column):
-        obj = state.get_impl(self.key).get(state)
+        obj = state.get_impl(self.key).get(state, state.dict)
         return self.get_col_value(column, obj)
 
     def getcommitted(self, state, column, passive=False):
@@ -168,7 +168,7 @@ class CompositeProperty(ColumnProperty):
 
     def setattr(self, state, value, column):
 
-        obj = state.get_impl(self.key).get(state)
+        obj = state.get_impl(self.key).get(state, state.dict)
         if obj is None:
             obj = self.composite_class(*[None for c in self.columns])
             state.get_impl(self.key).set(state, obj, None)
@@ -635,7 +635,7 @@ class RelationProperty(StrategizedProperty):
                     return
 
         source_state = attributes.instance_state(source)
-        dest_state = attributes.instance_state(dest)
+        dest_state, dest_dict = attributes.instance_state(dest), attributes.instance_dict(dest)
 
         if not "merge" in self.cascade:
             dest_state.expire_attributes([self.key])
@@ -658,7 +658,7 @@ class RelationProperty(StrategizedProperty):
                 for c in dest_list:
                     coll.append_without_event(c)
             else:
-                getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_list)
+                getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_dict, dest_list)
         else:
             current = instances[0]
             if current is not None:
@@ -839,8 +839,8 @@ class RelationProperty(StrategizedProperty):
                     if self._foreign_keys:
                         raise sa_exc.ArgumentError("Could not determine relation direction for "
                             "primaryjoin condition '%s', on relation %s. "
-                            "Are the columns in 'foreign_keys' present within the given "
-                            "join condition ?" % (self.primaryjoin, self))
+                            "Do the columns in 'foreign_keys' represent only the 'foreign' columns "
+                            "in this join condition ?" % (self.primaryjoin, self))
                     else:
                         raise sa_exc.ArgumentError("Could not determine relation direction for "
                             "primaryjoin condition '%s', on relation %s. "
@@ -1119,6 +1119,10 @@ class RelationProperty(StrategizedProperty):
         if not self.viewonly:
             self._dependency_processor.register_dependencies(uowcommit)
 
+    def register_processors(self, uowcommit):
+        if not self.viewonly:
+            self._dependency_processor.register_processors(uowcommit)
+
 PropertyLoader = RelationProperty
 log.class_logger(RelationProperty)
 
index 533ec9aa522e42dd2b558b90be161903a6cedb2c..be40b08c658325df289d6465c551f7e8c42ad721 100644 (file)
@@ -1330,7 +1330,7 @@ class Query(object):
             rowtuple.keys = labels.keys
 
         while True:
-            context.progress = set()
+            context.progress = {}
             context.partials = {}
 
             if self._yield_per:
@@ -1354,13 +1354,13 @@ class Query(object):
                 rows = filter(rows)
 
             if context.refresh_state and self._only_load_props and context.refresh_state in context.progress:
-                context.refresh_state.commit(self._only_load_props)
-                context.progress.remove(context.refresh_state)
+                context.refresh_state.commit(context.refresh_state.dict, self._only_load_props)
+                context.progress.pop(context.refresh_state)
 
             session._finalize_loaded(context.progress)
 
-            for ii, attrs in context.partials.iteritems():
-                ii.commit(attrs)
+            for ii, (dict_, attrs) in context.partials.iteritems():
+                ii.commit(dict_, attrs)
 
             for row in rows:
                 yield row
@@ -1687,14 +1687,14 @@ class Query(object):
                 evaluated_keys = value_evaluators.keys()
 
                 if issubclass(cls, target_cls) and eval_condition(obj):
-                    state = attributes.instance_state(obj)
+                    state, dict_ = attributes.instance_state(obj), attributes.instance_dict(obj)
 
                     # only evaluate unmodified attributes
                     to_evaluate = state.unmodified.intersection(evaluated_keys)
                     for key in to_evaluate:
-                        state.dict[key] = value_evaluators[key](obj)
+                        dict_[key] = value_evaluators[key](obj)
 
-                    state.commit(list(to_evaluate))
+                    state.commit(dict_, list(to_evaluate))
 
                     # expire attributes with pending changes (there was no autoflush, so they are overwritten)
                     state.expire_attributes(set(evaluated_keys).difference(to_evaluate))
index 1e3a750d950fb51864d34f6b4c770fd8178986b4..cbfb0c1d643a3a2d670b3a29b4968a7dbc38c7c8 100644 (file)
@@ -12,7 +12,7 @@ import sqlalchemy.exceptions as sa_exc
 from sqlalchemy import util, sql, engine, log
 from sqlalchemy.sql import util as sql_util, expression
 from sqlalchemy.orm import (
-    SessionExtension, attributes, exc, query, unitofwork, util as mapperutil,
+    SessionExtension, attributes, exc, query, unitofwork, util as mapperutil, state
     )
 from sqlalchemy.orm.util import object_mapper as _object_mapper
 from sqlalchemy.orm.util import class_mapper as _class_mapper
@@ -299,14 +299,14 @@ class SessionTransaction(object):
             self.session._expunge_state(s)
 
         for s in self.session.identity_map.all_states():
-            _expire_state(s, None)
+            _expire_state(s, None, instance_dict=self.session.identity_map)
 
     def _remove_snapshot(self):
         assert self._is_transaction_boundary
 
         if not self.nested and self.session.expire_on_commit:
             for s in self.session.identity_map.all_states():
-                _expire_state(s, None)
+                _expire_state(s, None, instance_dict=self.session.identity_map)
 
     def _connection_for_bind(self, bind):
         self._assert_is_active()
@@ -899,8 +899,8 @@ class Session(object):
             self.flush()
 
     def _finalize_loaded(self, states):
-        for state in states:
-            state.commit_all()
+        for state, dict_ in states.items():
+            state.commit_all(dict_, self.identity_map)
 
     def refresh(self, instance, attribute_names=None):
         """Refresh the attributes on the given instance.
@@ -935,7 +935,7 @@ class Session(object):
         """Expires all persistent instances within this Session."""
 
         for state in self.identity_map.all_states():
-            _expire_state(state, None)
+            _expire_state(state, None, instance_dict=self.identity_map)
 
     def expire(self, instance, attribute_names=None):
         """Expire the attributes on an instance.
@@ -956,14 +956,14 @@ class Session(object):
             raise exc.UnmappedInstanceError(instance)
         self._validate_persistent(state)
         if attribute_names:
-            _expire_state(state, attribute_names=attribute_names)
+            _expire_state(state, attribute_names=attribute_names, instance_dict=self.identity_map)
         else:
             # pre-fetch the full cascade since the expire is going to
             # remove associations
             cascaded = list(_cascade_state_iterator('refresh-expire', state))
-            _expire_state(state, None)
+            _expire_state(state, None, instance_dict=self.identity_map)
             for (state, m, o) in cascaded:
-                _expire_state(state, None)
+                _expire_state(state, None, instance_dict=self.identity_map)
 
     def prune(self):
         """Remove unreferenced instances cached in the identity map.
@@ -1020,12 +1020,10 @@ class Session(object):
                 # primary key switch
                 self.identity_map.remove(state)
                 state.key = instance_key
-
-            if state.key in self.identity_map and not self.identity_map.contains_state(state):
-                self.identity_map.remove_key(state.key)
-            self.identity_map.add(state)
-            state.commit_all()
-
+            
+            self.identity_map.replace(state)
+            state.commit_all(state.dict, self.identity_map)
+            
         # remove from new last, might be the last strong ref
         if state in self._new:
             if self._enable_transaction_accounting and self.transaction:
@@ -1213,7 +1211,7 @@ class Session(object):
             prop.merge(self, instance, merged, dont_load, _recursive)
 
         if dont_load:
-            attributes.instance_state(merged).commit_all()  # remove any history
+            attributes.instance_state(merged).commit_all(attributes.instance_dict(merged), self.identity_map)  # remove any history
 
         if new_instance:
             merged_state._run_on_load(merged)
@@ -1362,13 +1360,12 @@ class Session(object):
             not self._deleted and not self._new):
             return
 
-        
         dirty = self._dirty_states
         if not dirty and not self._deleted and not self._new:
-            self.identity_map.modified = False
+            self.identity_map._modified.clear()
             return
 
-        flush_context   = UOWTransaction(self)
+        flush_context = UOWTransaction(self)
 
         if self.extensions:
             for ext in self.extensions:
@@ -1391,15 +1388,19 @@ class Session(object):
                     raise exc.UnmappedInstanceError(o)
                 objset.add(state)
         else:
-            # or just everything
-            objset = set(self.identity_map.all_states()).union(new)
+            objset = None
 
         # store objects whose fate has been decided
         processed = set()
 
         # put all saves/updates into the flush context.  detect top-level
         # orphans and throw them into deleted.
-        for state in new.union(dirty).intersection(objset).difference(deleted):
+        if objset:
+            proc = new.union(dirty).intersection(objset).difference(deleted)
+        else:
+            proc = new.union(dirty).difference(deleted)
+            
+        for state in proc:
             is_orphan = _state_mapper(state)._is_orphan(state)
             if is_orphan and not _state_has_identity(state):
                 path = ", nor ".join(
@@ -1415,7 +1416,11 @@ class Session(object):
             processed.add(state)
 
         # put all remaining deletes into the flush context.
-        for state in deleted.intersection(objset).difference(processed):
+        if objset:
+            proc = deleted.intersection(objset).difference(processed)
+        else:
+            proc = deleted.difference(processed)
+        for state in proc:
             flush_context.register_object(state, isdelete=True)
 
         if len(flush_context.tasks) == 0:
@@ -1435,9 +1440,13 @@ class Session(object):
         
         flush_context.finalize_flush_changes()
 
-        if not objects:
-            self.identity_map.modified = False
-
+        # useful assertions:
+        #if not objects:
+        #    assert not self.identity_map._modified
+        #else:
+        #    assert self.identity_map._modified == self.identity_map._modified.difference(objects)
+        #self.identity_map._modified.clear()
+        
         for ext in self.extensions:
             ext.after_flush_postexec(self, flush_context)
 
@@ -1486,10 +1495,7 @@ class Session(object):
         those that were possibly deleted.
 
         """
-        return util.IdentitySet(
-            [state
-             for state in self.identity_map.all_states()
-             if state.check_modified()])
+        return self.identity_map._dirty_states()
 
     @property
     def dirty(self):
@@ -1528,7 +1534,7 @@ class Session(object):
 
         return util.IdentitySet(self._new.values())
 
-_expire_state = attributes.InstanceState.expire_attributes
+_expire_state = state.InstanceState.expire_attributes
     
 UOWEventHandler = unitofwork.UOWEventHandler
 
@@ -1548,16 +1554,19 @@ def _cascade_unknown_state_iterator(cascade, state, **kwargs):
         yield _state_for_unknown_persistence_instance(o), m
 
 def _state_for_unsaved_instance(instance, create=False):
-    manager = attributes.manager_of_class(instance.__class__)
-    if manager is None:
+    try:
+        state = attributes.instance_state(instance)
+    except AttributeError:
         raise exc.UnmappedInstanceError(instance)
-    if manager.has_state(instance):
-        state = manager.state_of(instance)
+    if state:
         if state.key is not None:
             raise sa_exc.InvalidRequestError(
                 "Instance '%s' is already persistent" %
                 mapperutil.state_str(state))
     elif create:
+        manager = attributes.manager_of_class(instance.__class__)
+        if manager is None:
+            raise exc.UnmappedInstanceError(instance)
         state = manager.setup_instance(instance)
     else:
         raise exc.UnmappedInstanceError(instance)
diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py
new file mode 100644 (file)
index 0000000..1b73a1b
--- /dev/null
@@ -0,0 +1,441 @@
+from sqlalchemy.util import EMPTY_SET
+import weakref
+from sqlalchemy import util
+from sqlalchemy.orm.attributes import PASSIVE_NORESULT, PASSIVE_OFF, NEVER_SET, NO_VALUE, manager_of_class, ATTR_WAS_SET
+from sqlalchemy.orm import attributes
+from sqlalchemy.orm import interfaces
+
+class InstanceState(object):
+    """tracks state information at the instance level."""
+
+    session_id = None
+    key = None
+    runid = None
+    expired_attributes = EMPTY_SET
+    load_options = EMPTY_SET
+    load_path = ()
+    insert_order = None
+    mutable_dict = None
+    
+    def __init__(self, obj, manager):
+        self.class_ = obj.__class__
+        self.manager = manager
+        self.obj = weakref.ref(obj, self._cleanup)
+        self.modified = False
+        self.callables = {}
+        self.expired = False
+        self.committed_state = {}
+        self.pending = {}
+        self.parents = {}
+        
+    def detach(self):
+        if self.session_id:
+            del self.session_id
+
+    def dispose(self):
+        if self.session_id:
+            del self.session_id
+        del self.obj
+    
+    def _cleanup(self, ref):
+        instance_dict = self._instance_dict()
+        if instance_dict:
+            instance_dict.remove(self)
+        self.dispose()
+    
+    def obj(self):
+        return None
+    
+    @property
+    def dict(self):
+        o = self.obj()
+        if o is not None:
+            return attributes.instance_dict(o)
+        else:
+            return {}
+        
+    @property
+    def sort_key(self):
+        return self.key and self.key[1] or (self.insert_order, )
+
+    def check_modified(self):
+        # TODO: deprecate
+        return self.modified
+
+    def initialize_instance(*mixed, **kwargs):
+        self, instance, args = mixed[0], mixed[1], mixed[2:]
+        manager = self.manager
+
+        for fn in manager.events.on_init:
+            fn(self, instance, args, kwargs)
+            
+        # LESSTHANIDEAL:
+        # adjust for the case where the InstanceState was created before
+        # mapper compilation, and this actually needs to be a MutableAttrInstanceState
+        if manager.mutable_attributes and self.__class__ is not MutableAttrInstanceState:
+            self.__class__ = MutableAttrInstanceState
+            self.obj = weakref.ref(self.obj(), self._cleanup)
+            self.mutable_dict = {}
+            
+        try:
+            return manager.events.original_init(*mixed[1:], **kwargs)
+        except:
+            for fn in manager.events.on_init_failure:
+                fn(self, instance, args, kwargs)
+            raise
+
+    def get_history(self, key, **kwargs):
+        return self.manager.get_impl(key).get_history(self, self.dict, **kwargs)
+
+    def get_impl(self, key):
+        return self.manager.get_impl(key)
+
+    def get_pending(self, key):
+        if key not in self.pending:
+            self.pending[key] = PendingCollection()
+        return self.pending[key]
+
+    def value_as_iterable(self, key, passive=PASSIVE_OFF):
+        """return an InstanceState attribute as a list,
+        regardless of it being a scalar or collection-based
+        attribute.
+
+        returns None if passive is not PASSIVE_OFF and the getter returns
+        PASSIVE_NORESULT.
+        """
+
+        impl = self.get_impl(key)
+        dict_ = self.dict
+        x = impl.get(self, dict_, passive=passive)
+        if x is PASSIVE_NORESULT:
+            return None
+        elif hasattr(impl, 'get_collection'):
+            return impl.get_collection(self, dict_, x, passive=passive)
+        elif isinstance(x, list):
+            return x
+        else:
+            return [x]
+
+    def _run_on_load(self, instance):
+        self.manager.events.run('on_load', instance)
+
+    def __getstate__(self):
+        return {'key': self.key,
+                'committed_state': self.committed_state,
+                'pending': self.pending,
+                'parents': self.parents,
+                'modified': self.modified,
+                'expired':self.expired,
+                'load_options':self.load_options,
+                'load_path':interfaces.serialize_path(self.load_path),
+                'instance': self.obj(),
+                'expired_attributes':self.expired_attributes,
+                'callables': self.callables}
+
+    def __setstate__(self, state):
+        self.committed_state = state['committed_state']
+        self.parents = state['parents']
+        self.key = state['key']
+        self.session_id = None
+        self.pending = state['pending']
+        self.modified = state['modified']
+        self.obj = weakref.ref(state['instance'])
+        self.load_options = state['load_options'] or EMPTY_SET
+        self.load_path = interfaces.deserialize_path(state['load_path'])
+        self.class_ = self.obj().__class__
+        self.manager = manager_of_class(self.class_)
+        self.callables = state['callables']
+        self.runid = None
+        self.expired = state['expired']
+        self.expired_attributes = state['expired_attributes']
+
+    def initialize(self, key):
+        self.manager.get_impl(key).initialize(self, self.dict)
+
+    def set_callable(self, key, callable_):
+        self.dict.pop(key, None)
+        self.callables[key] = callable_
+
+    def __call__(self):
+        """__call__ allows the InstanceState to act as a deferred
+        callable for loading expired attributes, which is also
+        serializable (picklable).
+
+        """
+        unmodified = self.unmodified
+        class_manager = self.manager
+        class_manager.deferred_scalar_loader(self, [
+            attr.impl.key for attr in class_manager.attributes if
+                attr.impl.accepts_scalar_loader and
+                attr.impl.key in self.expired_attributes and
+                attr.impl.key in unmodified
+            ])
+        for k in self.expired_attributes:
+            self.callables.pop(k, None)
+        del self.expired_attributes
+        return ATTR_WAS_SET
+
+    @property
+    def unmodified(self):
+        """a set of keys which have no uncommitted changes"""
+        
+        return set(self.manager).difference(self.committed_state)
+
+    @property
+    def unloaded(self):
+        """a set of keys which do not have a loaded value.
+
+        This includes expired attributes and any other attribute that
+        was never populated or modified.
+
+        """
+        return set(
+            key for key in self.manager.iterkeys()
+            if key not in self.committed_state and key not in self.dict)
+
+    def expire_attributes(self, attribute_names, instance_dict=None):
+        self.expired_attributes = set(self.expired_attributes)
+
+        if attribute_names is None:
+            attribute_names = self.manager.keys()
+            self.expired = True
+            if self.modified:
+                if not instance_dict:
+                    instance_dict = self._instance_dict()
+                    if instance_dict:
+                        instance_dict._modified.discard(self)
+                else:
+                    instance_dict._modified.discard(self)
+                    
+            self.modified = False
+            filter_deferred = True
+        else:
+            filter_deferred = False
+        dict_ = self.dict
+        
+        for key in attribute_names:
+            impl = self.manager[key].impl
+            if not filter_deferred or \
+                not impl.dont_expire_missing or \
+                key in dict_:
+                self.expired_attributes.add(key)
+                if impl.accepts_scalar_loader:
+                    self.callables[key] = self
+            dict_.pop(key, None)
+            self.pending.pop(key, None)
+            self.committed_state.pop(key, None)
+            if self.mutable_dict:
+                self.mutable_dict.pop(key, None)
+                
+    def reset(self, key, dict_):
+        """remove the given attribute and any callables associated with it."""
+
+        dict_.pop(key, None)
+        self.callables.pop(key, None)
+
+    def _instance_dict(self):
+        return None
+
+    def _is_really_none(self):
+        return self.obj()
+        
+    def modified_event(self, dict_, attr, should_copy, previous, passive=PASSIVE_OFF):
+        needs_committed = attr.key not in self.committed_state
+
+        if needs_committed:
+            if previous is NEVER_SET:
+                if passive:
+                    if attr.key in dict_:
+                        previous = dict_[attr.key]
+                else:
+                    previous = attr.get(self, dict_)
+
+            if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
+                previous = attr.copy(previous)
+
+            if needs_committed:
+                self.committed_state[attr.key] = previous
+
+        if not self.modified:
+            instance_dict = self._instance_dict()
+            if instance_dict:
+                instance_dict._modified.add(self)
+
+        self.modified = True
+        self._strong_obj = self.obj()
+
+    def commit(self, dict_, keys):
+        """Commit attributes.
+
+        This is used by a partial-attribute load operation to mark committed
+        those attributes which were refreshed from the database.
+
+        Attributes marked as "expired" can potentially remain "expired" after
+        this step if a value was not populated in state.dict.
+
+        """
+        class_manager = self.manager
+        for key in keys:
+            if key in dict_ and key in class_manager.mutable_attributes:
+                class_manager[key].impl.commit_to_state(self, dict_, self.committed_state)
+            else:
+                self.committed_state.pop(key, None)
+
+        self.expired = False
+        # unexpire attributes which have loaded
+        for key in self.expired_attributes.intersection(keys):
+            if key in dict_:
+                self.expired_attributes.remove(key)
+                self.callables.pop(key, None)
+
+    def commit_all(self, dict_, instance_dict=None):
+        """commit all attributes unconditionally.
+
+        This is used after a flush() or a full load/refresh
+        to remove all pending state from the instance.
+
+         - all attributes are marked as "committed"
+         - the "strong dirty reference" is removed
+         - the "modified" flag is set to False
+         - any "expired" markers/callables are removed.
+
+        Attributes marked as "expired" can potentially remain "expired" after this step
+        if a value was not populated in state.dict.
+
+        """
+        
+        self.committed_state = {}
+        self.pending = {}
+        
+        # unexpire attributes which have loaded
+        if self.expired_attributes:
+            for key in self.expired_attributes.intersection(dict_):
+                self.callables.pop(key, None)
+            self.expired_attributes.difference_update(dict_)
+
+        for key in self.manager.mutable_attributes:
+            if key in dict_:
+                self.manager[key].impl.commit_to_state(self, dict_, self.committed_state)
+
+        if instance_dict and self.modified:
+            instance_dict._modified.discard(self)
+
+        self.modified = self.expired = False
+        self._strong_obj = None
+
+class MutableAttrInstanceState(InstanceState):
+    def __init__(self, obj, manager):
+        self.mutable_dict = {}
+        InstanceState.__init__(self, obj, manager)
+        
+    def _get_modified(self, dict_=None):
+        if self.__dict__.get('modified', False):
+            return True
+        else:
+            if dict_ is None:
+                dict_ = self.dict
+            for key in self.manager.mutable_attributes:
+                if self.manager[key].impl.check_mutable_modified(self, dict_):
+                    return True
+            else:
+                return False
+    
+    def _set_modified(self, value):
+        self.__dict__['modified'] = value
+        
+    modified = property(_get_modified, _set_modified)
+    
+    @property
+    def unmodified(self):
+        """a set of keys which have no uncommitted changes"""
+
+        dict_ = self.dict
+        return set(
+            key for key in self.manager.iterkeys()
+            if (key not in self.committed_state or
+                (key in self.manager.mutable_attributes and
+                 not self.manager[key].impl.check_mutable_modified(self, dict_))))
+
+    def _is_really_none(self):
+        """do a check modified/resurrect.
+        
+        This would be called in the extremely rare
+        race condition that the weakref returned None but
+        the cleanup handler had not yet established the 
+        __resurrect callable as its replacement.
+        
+        """
+        if self.modified:
+            self.obj = self.__resurrect
+            return self.obj()
+        else:
+            return None
+
+    def reset(self, key, dict_):
+        self.mutable_dict.pop(key, None)
+        InstanceState.reset(self, key, dict_)
+    
+    def _cleanup(self, ref):
+        """weakref callback.
+        
+        This method may be called by an asynchronous
+        gc.
+        
+        If the state shows pending changes, the weakref
+        is replaced by the __resurrect callable which will
+        re-establish an object reference on next access,
+        else removes this InstanceState from the owning
+        identity map, if any.
+        
+        """
+        if self._get_modified(self.mutable_dict):
+            self.obj = self.__resurrect
+        else:
+            instance_dict = self._instance_dict()
+            if instance_dict:
+                instance_dict.remove(self)
+            self.dispose()
+            
+    def __resurrect(self):
+        """A substitute for the obj() weakref function which resurrects."""
+        
+        # store strong ref'ed version of the object; will revert
+        # to weakref when changes are persisted
+        
+        obj = self.manager.new_instance(state=self)
+        self.obj = weakref.ref(obj, self._cleanup)
+        self._strong_obj = obj
+        obj.__dict__.update(self.mutable_dict)
+
+        # re-establishes identity attributes from the key
+        self.manager.events.run('on_resurrect', self, obj)
+        
+        # TODO: don't really think we should run this here.
+        # resurrect is only meant to preserve the minimal state needed to
+        # do an UPDATE, not to produce a fully usable object
+        self._run_on_load(obj)
+        
+        return obj
+
+class PendingCollection(object):
+    """A writable placeholder for an unloaded collection.
+
+    Stores items appended to and removed from a collection that has not yet
+    been loaded. When the collection is loaded, the changes stored in
+    PendingCollection are applied to it to produce the final result.
+
+    """
+    def __init__(self):
+        self.deleted_items = util.IdentitySet()
+        self.added_items = util.OrderedIdentitySet()
+
+    def append(self, value):
+        if value in self.deleted_items:
+            self.deleted_items.remove(value)
+        self.added_items.add(value)
+
+    def remove(self, value):
+        if value in self.added_items:
+            self.added_items.remove(value)
+        self.deleted_items.add(value)
+
index 1aeb311e1c99e573fcd2711918b0fc9377ba5859..20cbb8f4dcdeb09300203d25cbd6779626a297f2 100644 (file)
@@ -115,8 +115,8 @@ class ColumnLoader(LoaderStrategy):
         if adapter:
             col = adapter.columns[col]
         if col in row:
-            def new_execute(state, row, **flags):
-                state.dict[key] = row[col]
+            def new_execute(state, dict_, row, **flags):
+                dict_[key] = row[col]
                 
             if self._should_log_debug:
                 new_execute = self.debug_callable(new_execute, self.logger,
@@ -125,7 +125,7 @@ class ColumnLoader(LoaderStrategy):
                 )
             return (new_execute, None)
         else:
-            def new_execute(state, row, isnew, **flags):
+            def new_execute(state, dict_, row, isnew, **flags):
                 if isnew:
                     state.expire_attributes([key])
             if self._should_log_debug:
@@ -171,15 +171,15 @@ class CompositeColumnLoader(ColumnLoader):
             columns = [adapter.columns[c] for c in columns]
         for c in columns:
             if c not in row:
-                def new_execute(state, row, isnew, **flags):
+                def new_execute(state, dict_, row, isnew, **flags):
                     if isnew:
                         state.expire_attributes([key])
                 if self._should_log_debug:
                     self.logger.debug("%s deferring load" % self)
                 return (new_execute, None)
         else:
-            def new_execute(state, row, **flags):
-                state.dict[key] = composite_class(*[row[c] for c in columns])
+            def new_execute(state, dict_, row, **flags):
+                dict_[key] = composite_class(*[row[c] for c in columns])
 
             if self._should_log_debug:
                 new_execute = self.debug_callable(new_execute, self.logger,
@@ -202,13 +202,13 @@ class DeferredColumnLoader(LoaderStrategy):
             return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, path, mapper, row, adapter)
 
         elif not self.is_class_level:
-            def new_execute(state, row, **flags):
+            def new_execute(state, dict_, row, **flags):
                 state.set_callable(self.key, LoadDeferredColumns(state, self.key))
         else:
-            def new_execute(state, row, **flags):
+            def new_execute(state, dict_, row, **flags):
                 # reset state on the key so that deferred callables
                 # fire off on next access.
-                state.reset(self.key)
+                state.reset(self.key, dict_)
 
         if self._should_log_debug:
             new_execute = self.debug_callable(new_execute, self.logger, None,
@@ -340,7 +340,7 @@ class NoLoader(AbstractRelationLoader):
         )
 
     def create_row_processor(self, selectcontext, path, mapper, row, adapter):
-        def new_execute(state, row, **flags):
+        def new_execute(state, dict_, row, **flags):
             self._init_instance_attribute(state)
 
         if self._should_log_debug:
@@ -437,7 +437,7 @@ class LazyLoader(AbstractRelationLoader):
 
     def create_row_processor(self, selectcontext, path, mapper, row, adapter):
         if not self.is_class_level:
-            def new_execute(state, row, **flags):
+            def new_execute(state, dict_, row, **flags):
                 # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
                 # which will override the class-level behavior.
                 # this currently only happens when using a "lazyload" option on a "no load" attribute -
@@ -451,11 +451,11 @@ class LazyLoader(AbstractRelationLoader):
 
             return (new_execute, None)
         else:
-            def new_execute(state, row, **flags):
+            def new_execute(state, dict_, row, **flags):
                 # we are the primary manager for this attribute on this class - reset its per-instance attribute state, 
                 # so that the class-level lazy loader is executed when next referenced on this instance.
                 # this is needed in populate_existing() types of scenarios to reset any existing state.
-                state.reset(self.key)
+                state.reset(self.key, dict_)
 
             if self._should_log_debug:
                 new_execute = self.debug_callable(new_execute, self.logger, None,
@@ -735,24 +735,24 @@ class EagerLoader(AbstractRelationLoader):
             _instance = self.mapper._instance_processor(context, path + (self.mapper.base_mapper,), eager_adapter)
             
             if not self.uselist:
-                def execute(state, row, isnew, **flags):
+                def execute(state, dict_, row, isnew, **flags):
                     if isnew:
                         # set a scalar object instance directly on the
                         # parent object, bypassing InstrumentedAttribute
                         # event handlers.
-                        state.dict[key] = _instance(row, None)
+                        dict_[key] = _instance(row, None)
                     else:
                         # call _instance on the row, even though the object has been created,
                         # so that we further descend into properties
                         _instance(row, None)
             else:
-                def execute(state, row, isnew, **flags):
+                def execute(state, dict_, row, isnew, **flags):
                     if isnew or (state, key) not in context.attributes:
                         # appender_key can be absent from context.attributes with isnew=False
                         # when self-referential eager loading is used; the same instance may be present
                         # in two distinct sets of result columns
 
-                        collection = attributes.init_state_collection(state, key)
+                        collection = attributes.init_state_collection(state, dict_, key)
                         appender = util.UniqueAppender(collection, 'append_without_event')
 
                         context.attributes[(state, key)] = appender
index dd979e1a808e7bc698cf903a217ba47131abb1e2..c12f17aff5e9c5f971ae7e33d490e702904fa137 100644 (file)
@@ -50,26 +50,18 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs):
 
         dict_[r.key] = value
 
-def source_changes(uowcommit, source, source_mapper, synchronize_pairs):
+def source_modified(uowcommit, source, source_mapper, synchronize_pairs):
+    """return true if the source object has changes from an old to a new value on the given
+    synchronize pairs
+    
+    """
     for l, r in synchronize_pairs:
         try:
             prop = source_mapper._get_col_to_prop(l)
         except exc.UnmappedColumnError:
             _raise_col_to_prop(False, source_mapper, l, None, r)
         history = uowcommit.get_attribute_history(source, prop.key, passive=True)
-        if history.has_changes():
-            return True
-    else:
-        return False
-
-def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs):
-    for l, r in synchronize_pairs:
-        try:
-            prop = dest_mapper._get_col_to_prop(r)
-        except exc.UnmappedColumnError:
-            _raise_col_to_prop(True, None, l, dest_mapper, r)
-        history = uowcommit.get_attribute_history(dest, prop.key, passive=True)
-        if history.has_changes():
+        if len(history.deleted):
             return True
     else:
         return False
index 4ac9c765e03bb9643f2e1c09f9e7b26f7c8d02c3..da26c8d7b38f464a5017eb09e790167b9a7cd6e2 100644 (file)
@@ -96,6 +96,8 @@ class UOWTransaction(object):
         # information.
         self.attributes = {}
         
+        self.processors = set()
+        
     def get_attribute_history(self, state, key, passive=True):
         hashkey = ("history", state, key)
 
@@ -119,6 +121,7 @@ class UOWTransaction(object):
             return history.as_state()
 
     def register_object(self, state, isdelete=False, listonly=False, postupdate=False, post_update_cols=None):
+        
         # if object is not in the overall session, do nothing
         if not self.session._contains_state(state):
             if self._should_log_debug:
@@ -136,6 +139,16 @@ class UOWTransaction(object):
         else:
             task.append(state, listonly=listonly, isdelete=isdelete)
 
+        # ensure the mapper for this object has had its 
+        # DependencyProcessors added.
+        if mapper not in self.processors:
+            mapper._register_processors(self)
+            self.processors.add(mapper)
+
+            if mapper.base_mapper not in self.processors:
+                mapper.base_mapper._register_processors(self)
+                self.processors.add(mapper.base_mapper)
+            
     def set_row_switch(self, state):
         """mark a deleted object as a 'row switch'.
 
@@ -147,7 +160,7 @@ class UOWTransaction(object):
         task = self.get_task_by_mapper(mapper)
         taskelement = task._objects[state]
         taskelement.isdelete = "rowswitch"
-
+    
     def is_deleted(self, state):
         """return true if the given state is marked as deleted within this UOWTransaction."""
 
@@ -201,9 +214,9 @@ class UOWTransaction(object):
         self.dependencies.add((mapper, dependency))
 
     def register_processor(self, mapper, processor, mapperfrom):
-        """register a dependency processor, corresponding to dependencies between
-        the two given mappers.
-
+        """register a dependency processor, corresponding to 
+        operations which occur between two mappers.
+        
         """
         # correct for primary mapper
         mapper = mapper.primary_mapper()
index 4ecc7a06786faf008a5d1446202eb56342af2c23..3fd95642e6b1154f678961804524ed410085c0ee 100644 (file)
@@ -1608,7 +1608,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
     def shares_lineage(self, othercolumn):
         """Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``."""
 
-        return len(self.proxy_set.intersection(othercolumn.proxy_set)) > 0
+        return bool(self.proxy_set.intersection(othercolumn.proxy_set))
 
     def _make_proxy(self, selectable, name=None):
         """Create a new ``ColumnElement`` representing this
index 36357faf505c5a171bb0a3896b6fd23737070c36..f1f329b5e27a31b60b52bd022b2ab69bedb903e2 100644 (file)
@@ -343,14 +343,14 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re
             return
 
         if consider_as_foreign_keys:
-            if binary.left in consider_as_foreign_keys:
+            if binary.left in consider_as_foreign_keys and (binary.right is binary.left or binary.right not in consider_as_foreign_keys):
                 pairs.append((binary.right, binary.left))
-            elif binary.right in consider_as_foreign_keys:
+            elif binary.right in consider_as_foreign_keys and (binary.left is binary.right or binary.left not in consider_as_foreign_keys):
                 pairs.append((binary.left, binary.right))
         elif consider_as_referenced_keys:
-            if binary.left in consider_as_referenced_keys:
+            if binary.left in consider_as_referenced_keys and (binary.right is binary.left or binary.right not in consider_as_referenced_keys):
                 pairs.append((binary.left, binary.right))
-            elif binary.right in consider_as_referenced_keys:
+            elif binary.right in consider_as_referenced_keys and (binary.left is binary.right or binary.left not in consider_as_referenced_keys):
                 pairs.append((binary.right, binary.left))
         else:
             if isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
index f11461949764985bea1069f89ca5607a6d405c18..23c038955001fad16eee9955f9db73c196d6665a 100644 (file)
@@ -11,6 +11,28 @@ from testlib import *
 class TestTypes(TestBase, AssertsExecutionResults):
     __only_on__ = 'sqlite'
 
+    def test_boolean(self):
+        """Test that the boolean only treats 1 as True
+
+        """
+
+        meta = MetaData(testing.db)
+        t = Table('bool_table', meta,
+                  Column('id', Integer, primary_key=True),
+                  Column('boo', sqlite.SLBoolean))
+
+        try:
+            meta.create_all()
+            testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (1, 'false');")
+            testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (2, 'true');")
+            testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (3, '1');")
+            testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (4, '0');")
+            testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (5, 1);")
+            testing.db.execute("INSERT INTO bool_table (id, boo) VALUES (6, 0);")
+            assert t.select(t.c.boo).order_by(t.c.id).execute().fetchall() == [(3, True,), (5, True,)]
+        finally:
+            meta.drop_all()
+
     def test_string_dates_raise(self):
         self.assertRaises(TypeError, testing.db.execute, select([1]).where(bindparam("date", type_=Date)), date=str(datetime.date(2007, 10, 30)))
     
index 46d944cbc3a4c2e1a42ed93bdd0e33d8d0d7dcf5..0f15d5136f77d24a4e5d0c44a9eebf05b7288dea 100644 (file)
@@ -38,7 +38,7 @@ class AttributesTest(_base.ORMTest):
         u.email_address = 'lala@123.com'
 
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
-        attributes.instance_state(u).commit_all()
+        attributes.instance_state(u).commit_all(attributes.instance_dict(u))
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
 
         u.user_name = 'heythere'
@@ -158,7 +158,7 @@ class AttributesTest(_base.ORMTest):
         eq_(f.a, None)
         eq_(f.b, 12)
 
-        attributes.instance_state(f).commit_all()
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
         eq_(f.a, None)
         eq_(f.b, 12)
 
@@ -205,7 +205,7 @@ class AttributesTest(_base.ORMTest):
         u.addresses.append(a)
 
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
-        u, attributes.instance_state(a).commit_all()
+        u, attributes.instance_state(a).commit_all(attributes.instance_dict(a))
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
 
         u.user_name = 'heythere'
@@ -272,7 +272,7 @@ class AttributesTest(_base.ORMTest):
         p1 = Post()
         attributes.instance_state(b).set_callable('posts', lambda:[p1])
         attributes.instance_state(p1).set_callable('blog', lambda:b)
-        p1, attributes.instance_state(b).commit_all()
+        p1, attributes.instance_state(b).commit_all(attributes.instance_dict(b))
 
         # no orphans (called before the lazy loaders fire off)
         assert attributes.has_parent(Blog, p1, 'posts', optimistic=True)
@@ -353,7 +353,7 @@ class AttributesTest(_base.ORMTest):
         x = Bar()
         x.element = el
         eq_(attributes.get_history(attributes.instance_state(x), 'element'), ([el], (), ()))
-        attributes.instance_state(x).commit_all()
+        attributes.instance_state(x).commit_all(attributes.instance_dict(x))
 
         (added, unchanged, deleted) = attributes.get_history(attributes.instance_state(x), 'element')
         assert added == ()
@@ -381,7 +381,7 @@ class AttributesTest(_base.ORMTest):
         attributes.register_attribute(Bar, 'id', uselist=False, useobject=True)
 
         x = Foo()
-        attributes.instance_state(x).commit_all()
+        attributes.instance_state(x).commit_all(attributes.instance_dict(x))
         x.col2.append(bar4)
         eq_(attributes.get_history(attributes.instance_state(x), 'col2'), ([bar4], [bar1, bar2, bar3], []))
 
@@ -427,7 +427,7 @@ class AttributesTest(_base.ORMTest):
         attributes.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False)
         x = Foo()
         x.element = ['one', 'two', 'three']
-        attributes.instance_state(x).commit_all()
+        attributes.instance_state(x).commit_all(attributes.instance_dict(x))
         x.element[1] = 'five'
         assert attributes.instance_state(x).check_modified()
 
@@ -437,7 +437,7 @@ class AttributesTest(_base.ORMTest):
         attributes.register_attribute(Foo, 'element', uselist=False, useobject=False)
         x = Foo()
         x.element = ['one', 'two', 'three']
-        attributes.instance_state(x).commit_all()
+        attributes.instance_state(x).commit_all(attributes.instance_dict(x))
         x.element[1] = 'five'
         assert not attributes.instance_state(x).check_modified()
 
@@ -699,8 +699,8 @@ class PendingBackrefTest(_base.ORMTest):
 
         b = Blog("blog 1")
         p1.blog = b
-        attributes.instance_state(b).commit_all()
-        attributes.instance_state(p1).commit_all()
+        attributes.instance_state(b).commit_all(attributes.instance_dict(b))
+        attributes.instance_state(p1).commit_all(attributes.instance_dict(p1))
         assert b.posts == [Post("post 1")]
 
 class HistoryTest(_base.ORMTest):
@@ -713,17 +713,17 @@ class HistoryTest(_base.ORMTest):
         attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False)
 
         f = Foo()
-        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
+        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None)
 
         f.someattr = 3
-        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
+        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None)
 
         f = Foo()
         f.someattr = 3
-        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
+        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), None)
         
-        attributes.instance_state(f).commit(['someattr'])
-        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), 3)
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
+        eq_(Foo.someattr.impl.get_committed_value(attributes.instance_state(f), attributes.instance_dict(f)), 3)
 
     def test_scalar(self):
         class Foo(_base.BasicEntity):
@@ -739,13 +739,13 @@ class HistoryTest(_base.ORMTest):
         f.someattr = "hi"
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['hi'], (), ()))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['hi'], ()))
 
         f.someattr = 'there'
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['there'], (), ['hi']))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['there'], ()))
 
@@ -760,7 +760,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr = 'old'
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['old'], (), ['new']))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), ['old'], ()))
 
         # setting None on uninitialized is currently a change for a scalar attribute
@@ -778,7 +778,7 @@ class HistoryTest(_base.ORMTest):
 
         # set same value twice
         f = Foo()
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         f.someattr = 'one'
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ()))
         f.someattr = 'two'
@@ -799,7 +799,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr = {'foo':'hi'}
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'hi'}], (), ()))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'hi'}], ()))
         eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'})
 
@@ -807,7 +807,7 @@ class HistoryTest(_base.ORMTest):
         eq_(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'})
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'there'}], (), [{'foo':'hi'}]))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'there'}], ()))
 
@@ -819,7 +819,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr = {'foo':'old'}
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'old'}], (), [{'foo':'new'}]))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [{'foo':'old'}], ()))
 
 
@@ -847,13 +847,13 @@ class HistoryTest(_base.ORMTest):
         f.someattr = hi
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], (), ()))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ()))
 
         f.someattr = there
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], (), [hi]))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ()))
 
@@ -868,7 +868,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr = old
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], (), ['new']))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ()))
 
         # setting None on uninitialized is currently not a change for an object attribute
@@ -887,7 +887,7 @@ class HistoryTest(_base.ORMTest):
 
         # set same value twice
         f = Foo()
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         f.someattr = 'one'
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], (), ()))
         f.someattr = 'two'
@@ -915,13 +915,13 @@ class HistoryTest(_base.ORMTest):
         f.someattr = [hi]
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ()))
 
         f.someattr = [there]
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi]))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [there], ()))
 
@@ -935,13 +935,13 @@ class HistoryTest(_base.ORMTest):
         f = Foo()
         collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
         collection.append_without_event(new)
-        attributes.instance_state(f).commit_all()
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ()))
 
         f.someattr = [old]
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], [new]))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [old], ()))
 
     def test_dict_collections(self):
@@ -969,7 +969,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr['there'] = there
         eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([hi, there]), set(), set()))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set(), set([hi, there]), set()))
 
     def test_object_collections_mutate(self):
@@ -994,13 +994,13 @@ class HistoryTest(_base.ORMTest):
         f.someattr.append(hi)
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi], ()))
 
         f.someattr.append(there)
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [hi], []))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
 
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there], ()))
 
@@ -1010,7 +1010,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr.append(old)
         f.someattr.append(new)
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [hi], [there]))
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, old, new], ()))
 
         f.someattr.pop(0)
@@ -1021,19 +1021,19 @@ class HistoryTest(_base.ORMTest):
         f.__dict__['id'] = 1
         collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
         collection.append_without_event(new)
-        attributes.instance_state(f).commit_all()
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ()))
 
         f.someattr.append(old)
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [new], []))
 
-        attributes.instance_state(f).commit(['someattr'])
+        attributes.instance_state(f).commit(attributes.instance_dict(f), ['someattr'])
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new, old], ()))
 
         f = Foo()
         collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
         collection.append_without_event(new)
-        attributes.instance_state(f).commit_all()
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [new], ()))
 
         f.id = 1
@@ -1056,7 +1056,7 @@ class HistoryTest(_base.ORMTest):
         f.someattr.append(hi)
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there, hi], [], []))
 
-        attributes.instance_state(f).commit_all()
+        attributes.instance_state(f).commit_all(attributes.instance_dict(f))
         eq_(attributes.get_history(attributes.instance_state(f), 'someattr'), ((), [hi, there, hi], ()))
         
         f.someattr = []
index 69164ebafb41f638ca6b60fa4884c918b31e840e..aec6c181f26c6af1c72828b566f6268b33478892 100644 (file)
@@ -117,7 +117,7 @@ class UserDefinedExtensionTest(_base.ORMTest):
         u.user_id = 7
         u.user_name = 'john'
         u.email_address = 'lala@123.com'
-        self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}})
+        self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}}, u.__dict__)
         
     def test_basic(self):
         for base in (object, MyBaseClass, MyClass):
@@ -135,7 +135,7 @@ class UserDefinedExtensionTest(_base.ORMTest):
             u.email_address = 'lala@123.com'
 
             self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
-            attributes.instance_state(u).commit_all()
+            attributes.instance_state(u).commit_all(attributes.instance_dict(u))
             self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
 
             u.user_name = 'heythere'
@@ -182,7 +182,7 @@ class UserDefinedExtensionTest(_base.ORMTest):
             self.assertEquals(f.a, None)
             self.assertEquals(f.b, 12)
 
-            attributes.instance_state(f).commit_all()
+            attributes.instance_state(f).commit_all(attributes.instance_dict(f))
             self.assertEquals(f.a, None)
             self.assertEquals(f.b, 12)
 
@@ -272,8 +272,8 @@ class UserDefinedExtensionTest(_base.ORMTest):
             f1.bars.append(b1)
             self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
 
-            attributes.instance_state(f1).commit_all()
-            attributes.instance_state(b1).commit_all()
+            attributes.instance_state(f1).commit_all(attributes.instance_dict(f1))
+            attributes.instance_state(b1).commit_all(attributes.instance_dict(b1))
 
             self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ()))
             self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ((), [b1], ()))
index ddb4fa4ba5f74f361c98983a06f9e712ccdb29c5..d7f19a2cc0c55caece6a2a93e9dfae60a96de836 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy.orm import *
 from sqlalchemy.orm import exc as orm_exc
 from testlib import *
 from testlib import fixtures
+from orm import _base, _fixtures
 
 class O2MTest(ORMTest):
     """deals with inheritance and one-to-many relationships"""
@@ -924,6 +925,49 @@ class OptimizedLoadTest(ORMTest):
         # the optimized load needs to return "None" so regular full-row loading proceeds
         s1 = sess.query(Base).get(s1.id)
         assert s1.sub == 's1sub'
+
+class PKDiscriminatorTest(_base.MappedTest):
+    def define_tables(self, metadata):
+        parents = Table('parents', metadata,
+                           Column('id', Integer, primary_key=True),
+                           Column('name', String(60)))
+                           
+        children = Table('children', metadata,
+                        Column('id', Integer, ForeignKey('parents.id'), primary_key=True),
+                        Column('type', Integer,primary_key=True),
+                        Column('name', String(60)))
+
+    @testing.resolve_artifact_names
+    def test_pk_as_discriminator(self):
+        class Parent(object):
+                def __init__(self, name=None):
+                    self.name = name
+
+        class Child(object):
+            def __init__(self, name=None):
+                self.name = name
+
+        class A(Child):
+            pass
+            
+        mapper(Parent, parents, properties={
+            'children': relation(Child, backref='parent'),
+        })
+        mapper(Child, children, polymorphic_on=children.c.type,
+            polymorphic_identity=1)
+            
+        mapper(A, inherits=Child, polymorphic_identity=2)
+
+        s = create_session()
+        p = Parent('p1')
+        a = A('a1')
+        p.children.append(a)
+        s.add(p)
+        s.flush()
+
+        assert a.id
+        assert a.type == 2
+        
         
 class DeleteOrphanTest(ORMTest):
     def define_tables(self, metadata):
index 081c46cdd886540738d916e335e71fc7462f6653..fd15420d0ad8b20dd929dde1072bcba7aba9de1b 100644 (file)
@@ -1,8 +1,8 @@
 import testenv; testenv.configure_for_tests()
 
 from testlib import sa
-from testlib.sa import MetaData, Table, Column, Integer, ForeignKey
-from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper
+from testlib.sa import MetaData, Table, Column, Integer, ForeignKey, util
+from testlib.sa.orm import mapper, relation, create_session, attributes, class_mapper, clear_mappers
 from testlib.testing import eq_, ne_
 from testlib.compat import _function_named
 from orm import _base
@@ -458,25 +458,9 @@ class MapperInitTest(_base.ORMTest):
 
         m = mapper(A, self.fixture())
 
-        a = attributes.instance_state(A())
-        assert isinstance(a, attributes.InstanceState)
-        assert type(a) is not attributes.InstanceState
-
-        b = attributes.instance_state(B())
-        assert isinstance(b, attributes.InstanceState)
-        assert type(b) is not attributes.InstanceState
-
         # B is not mapped in the current implementation
         self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, B)
 
-        # the constructor of C is decorated too.  
-        # we don't support unmapped subclasses in any case,
-        # users should not be expecting any particular behavior
-        # from this scenario.
-        c = attributes.instance_state(C(3))
-        assert isinstance(c, attributes.InstanceState)
-        assert type(c) is not attributes.InstanceState
-
         # C is not mapped in the current implementation
         self.assertRaises(sa.orm.exc.UnmappedClassError, class_mapper, C)
 
@@ -573,6 +557,10 @@ class OnLoadTest(_base.ORMTest):
         finally:
             del A
 
+    def tearDownAll(self):
+        clear_mappers()
+        attributes._install_lookup_strategy(util.symbol('native'))
+
 
 class ExtendedEventsTest(_base.ORMTest):
     """Allow custom Events implementations."""
@@ -593,6 +581,7 @@ class ExtendedEventsTest(_base.ORMTest):
         assert isinstance(manager.events, MyEvents)
 
 
+
 class NativeInstrumentationTest(_base.ORMTest):
     @with_lookup_strategy(sa.util.symbol('native'))
     def test_register_reserved_attribute(self):
index 8192b195aed36798e498813eab9af0eecf647fe2..26a76301f29f92bbe600b1132a9419c43c7c2bfa 100644 (file)
@@ -1754,9 +1754,9 @@ class CompositeTypesTest(_base.MappedTest):
                 return [self.x, self.y]
             __hash__ = None
             def __eq__(self, other):
-                return other.x == self.x and other.y == self.y
+                return isinstance(other, Point) and other.x == self.x and other.y == self.y
             def __ne__(self, other):
-                return not self.__eq__(other)
+                return not isinstance(other, Point) or not self.__eq__(other)
 
         class Graph(object):
             pass
@@ -1822,6 +1822,12 @@ class CompositeTypesTest(_base.MappedTest):
         # query by columns
         eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, 19, 5)])
 
+        e = g.edges[1]
+        e.end.x = e.end.y = None
+        sess.flush()
+        eq_(sess.query(Edge.start, Edge.end).all(), [(3, 4, 5, 6), (14, 5, None, None)])
+
+
     @testing.resolve_artifact_names
     def test_pk(self):
         """Using a composite type as a primary key"""
index 02f8563c18f1d15910287c23d2373a08ec61b64c..3f832e33bb2e1ff3aacc86f2b4c542470ddecd36 100644 (file)
@@ -221,6 +221,15 @@ class MergeTest(_fixtures.FixtureTest):
             Address(email_address='hoho@bar.com')]))
         eq_(on_load.called, 6)
 
+    @testing.resolve_artifact_names
+    def test_merge_empty_attributes(self):
+        mapper(User, dingalings)
+        u1 = User(id=1)
+        sess = create_session()
+        sess.merge(u1)
+        sess.flush()
+        assert u1.address_id is u1.data is None
+        
     @testing.resolve_artifact_names
     def test_attribute_cascade(self):
         """Merge of a persistent entity with two child persistent entities."""
index 980165fc0b51effa84bc208d408211289fa65347..8efce660c37a777f8e4828409bd850fb856526ec 100644 (file)
@@ -14,20 +14,23 @@ class NaturalPKTest(_base.MappedTest):
     def define_tables(self, metadata):
         users = Table('users', metadata,
             Column('username', String(50), primary_key=True),
-            Column('fullname', String(100)))
+            Column('fullname', String(100)),
+            test_needs_fk=True)
 
         addresses = Table('addresses', metadata,
             Column('email', String(50), primary_key=True),
-            Column('username', String(50), ForeignKey('users.username', onupdate="cascade")))
+            Column('username', String(50), ForeignKey('users.username', onupdate="cascade")),
+            test_needs_fk=True)
 
         items = Table('items', metadata,
             Column('itemname', String(50), primary_key=True),
-            Column('description', String(100)))
+            Column('description', String(100)), 
+            test_needs_fk=True)
 
         users_to_items = Table('users_to_items', metadata,
             Column('username', String(50), ForeignKey('users.username', onupdate='cascade'), primary_key=True),
             Column('itemname', String(50), ForeignKey('items.itemname', onupdate='cascade'), primary_key=True),
-        )
+            test_needs_fk=True)
 
     def setup_classes(self):
         class User(_base.ComparableEntity):
@@ -101,8 +104,7 @@ class NaturalPKTest(_base.MappedTest):
         assert sess.query(User).get('ed').fullname == 'jack'
         
 
-    @testing.fails_on('mysql', 'FIXME: unknown')
-    @testing.fails_on('sqlite', 'FIXME: unknown')
+    @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
     def test_onetomany_passive(self):
         self._test_onetomany(True)
 
@@ -153,8 +155,7 @@ class NaturalPKTest(_base.MappedTest):
         self.assertEquals(User(username='fred', fullname='jack'), u1)
         
 
-    @testing.fails_on('sqlite', 'FIXME: unknown')
-    @testing.fails_on('mysql', 'FIXME: unknown')
+    @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
     def test_manytoone_passive(self):
         self._test_manytoone(True)
 
@@ -181,8 +182,6 @@ class NaturalPKTest(_base.MappedTest):
 
         u1.username = 'ed'
 
-        print id(a1), id(a2), id(u1)
-        print sa.orm.attributes.instance_state(u1).parents
         def go():
             sess.flush()
         if passive_updates:
@@ -198,8 +197,48 @@ class NaturalPKTest(_base.MappedTest):
         sess.expunge_all()
         self.assertEquals([Address(username='ed'), Address(username='ed')], sess.query(Address).all())
 
-    @testing.fails_on('sqlite', 'FIXME: unknown')
-    @testing.fails_on('mysql', 'FIXME: unknown')
+    @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
+    def test_onetoone_passive(self):
+        self._test_onetoone(True)
+
+    def test_onetoone_nonpassive(self):
+        self._test_onetoone(False)
+
+    @testing.resolve_artifact_names
+    def _test_onetoone(self, passive_updates):
+        mapper(User, users, properties={
+            "address":relation(Address, passive_updates=passive_updates, uselist=False)
+        })
+        mapper(Address, addresses)
+
+        sess = create_session()
+        u1 = User(username='jack', fullname='jack')
+        sess.add(u1)
+        sess.flush()
+        
+        a1 = Address(email='jack1')
+        u1.address = a1
+        sess.add(a1)
+        sess.flush()
+
+        u1.username = 'ed'
+
+        def go():
+            sess.flush()
+        if passive_updates:
+            sess.expire(u1, ['address'])
+            self.assert_sql_count(testing.db, go, 1)
+        else:
+            self.assert_sql_count(testing.db, go, 2)
+
+        def go():
+            sess.flush()
+        self.assert_sql_count(testing.db, go, 0)
+
+        sess.expunge_all()
+        self.assertEquals([Address(username='ed')], sess.query(Address).all())
+        
+    @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
     def test_bidirectional_passive(self):
         self._test_bidirectional(True)
 
@@ -230,6 +269,7 @@ class NaturalPKTest(_base.MappedTest):
         def go():
             sess.flush()
         if passive_updates:
+            sess.expire(u1, ['addresses'])
             self.assert_sql_count(testing.db, go, 1)
         else:
             self.assert_sql_count(testing.db, go, 3)
@@ -240,11 +280,11 @@ class NaturalPKTest(_base.MappedTest):
         u1 = sess.query(User).get('ed')
         assert len(u1.addresses) == 2    # load addresses
         u1.username = 'fred'
-        print "--------------------------------"
         def go():
             sess.flush()
         # check that the passive_updates is on on the other side
         if passive_updates:
+            sess.expire(u1, ['addresses'])
             self.assert_sql_count(testing.db, go, 1)
         else:
             self.assert_sql_count(testing.db, go, 3)
@@ -252,11 +292,11 @@ class NaturalPKTest(_base.MappedTest):
         self.assertEquals([Address(username='fred'), Address(username='fred')], sess.query(Address).all())
 
 
-    @testing.fails_on('sqlite', 'FIXME: unknown')
-    @testing.fails_on('mysql', 'FIXME: unknown')
+    @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
     def test_manytomany_passive(self):
         self._test_manytomany(True)
 
+    @testing.fails_on('mysql', 'the executemany() of the association table fails to report the correct row count')
     def test_manytomany_nonpassive(self):
         self._test_manytomany(False)
 
@@ -355,13 +395,16 @@ class NonPKCascadeTest(_base.MappedTest):
         Table('users', metadata,
             Column('id', Integer, primary_key=True),
             Column('username', String(50), unique=True),
-            Column('fullname', String(100)))
+            Column('fullname', String(100)),
+            test_needs_fk=True)
 
         Table('addresses', metadata,
               Column('id', Integer, primary_key=True),
               Column('email', String(50)),
               Column('username', String(50),
-                     ForeignKey('users.username', onupdate="cascade")))
+                     ForeignKey('users.username', onupdate="cascade")),
+                     test_needs_fk=True
+                     )
 
     def setup_classes(self):
         class User(_base.ComparableEntity):
@@ -369,8 +412,7 @@ class NonPKCascadeTest(_base.MappedTest):
         class Address(_base.ComparableEntity):
             pass
 
-    @testing.fails_on('sqlite', 'FIXME: unknown')
-    @testing.fails_on('mysql', 'FIXME: unknown')
+    @testing.fails_on('sqlite', 'sqlite doesnt support ON UPDATE CASCADE')
     def test_onetomany_passive(self):
         self._test_onetomany(True)
 
index 1ed3dcc61948f73de98574a6be4b618a73e1f08f..be0375e48b7628a5b25df76d8ecd216ae1ff9bb1 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 from testlib import sa, testing
 from testlib.sa import Table, Column, Integer, String, ForeignKey
-from testlib.sa.orm import mapper, relation
+from testlib.sa.orm import mapper, relation, create_session
 from orm import _base
 
 
@@ -19,50 +19,56 @@ class O2OTest(_base.MappedTest):
               Column('description', String(100)),
               Column('jack_id', Integer, ForeignKey("jack.id")))
 
+    @testing.resolve_artifact_names
     def setup_mappers(self):
         class Jack(_base.BasicEntity):
             pass
         class Port(_base.BasicEntity):
             pass
 
-    @testing.resolve_artifact_names
-    def test_1(self):
-        ctx = sa.orm.scoped_session(sa.orm.create_session)
 
-        mapper(Port, port, extension=ctx.extension)
+    @testing.resolve_artifact_names
+    def test_basic(self):
+        mapper(Port, port)
         mapper(Jack, jack,
                order_by=[jack.c.number],
                properties=dict(
                    port=relation(Port, backref='jack',
-                                 uselist=False, lazy=True)),
-               extension=ctx.extension)
+                                 uselist=False,
+                                 )),
+               )
+
+        session = create_session()
 
         j = Jack(number='101')
+        session.add(j)
         p = Port(name='fa0/1')
+        session.add(p)
+        
         j.port=p
-        ctx.flush()
+        session.flush()
         jid = j.id
         pid = p.id
 
-        j=ctx.query(Jack).get(jid)
-        p=ctx.query(Port).get(pid)
+        j=session.query(Jack).get(jid)
+        p=session.query(Port).get(pid)
         assert p.jack is not None
         assert p.jack is  j
         assert j.port is not None
         p.jack = None
         assert j.port is None
 
-        ctx.expunge_all()
+        session.expunge_all()
 
-        j = ctx.query(Jack).get(jid)
-        p = ctx.query(Port).get(pid)
+        j = session.query(Jack).get(jid)
+        p = session.query(Port).get(pid)
 
         j.port=None
         self.assert_(p.jack is None)
-        ctx.flush()
+        session.flush()
 
-        ctx.delete(j)
-        ctx.flush()
+        session.delete(j)
+        session.flush()
 
 if __name__ == "__main__":
     testenv.main()
index 6531b234c6da8c8329752e2f5750d1d3dfd39e22..07705c925638893ce70e4e8fc02f9a3b33309d48 100644 (file)
@@ -366,7 +366,7 @@ class OperatorTest(QueryTest, AssertsCompiledSQL):
                     )
 
         u7 = User(id=7)
-        attributes.instance_state(u7).commit_all()
+        attributes.instance_state(u7).commit_all(attributes.instance_dict(u7))
         
         self._test(Address.user == u7, ":param_1 = addresses.user_id")
 
index 88f132eae25d31add9eaacf138e821f4a6b2c01d..a0a8900b2c07a2d25072bb278555eba00e191f16 100644 (file)
@@ -835,7 +835,7 @@ class JoinConditionErrorTest(testing.TestBase):
         mapper(C2, t3)
         
         self.assertRaises(sa.exc.NoReferencedColumnError, compile_mappers)
-
+    
     def test_join_error_raised(self):
         m = MetaData()
         t1 = Table('t1', m, 
@@ -1640,6 +1640,53 @@ class InvalidRelationEscalationTest(_base.MappedTest):
             "Could not locate any equated, locally mapped column pairs "
             "for primaryjoin condition", sa.orm.compile_mappers)
 
+    @testing.resolve_artifact_names
+    def test_ambiguous_fks(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar,
+                            primaryjoin=foos.c.id==bars.c.fid,
+                            foreign_keys=[foos.c.id, bars.c.fid])})
+        mapper(Bar, bars)
+
+        self.assertRaisesMessage(
+            sa.exc.ArgumentError, 
+                "Do the columns in 'foreign_keys' represent only the "
+                "'foreign' columns in this join condition ?", 
+                sa.orm.compile_mappers)
+
+    @testing.resolve_artifact_names
+    def test_ambiguous_remoteside_o2m(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar,
+                            primaryjoin=foos.c.id==bars.c.fid,
+                            foreign_keys=[bars.c.fid],
+                            remote_side=[foos.c.id, bars.c.fid],
+                            viewonly=True
+                            )})
+        mapper(Bar, bars)
+
+        self.assertRaisesMessage(
+            sa.exc.ArgumentError, 
+                "could not determine any local/remote column pairs",
+                sa.orm.compile_mappers)
+
+    @testing.resolve_artifact_names
+    def test_ambiguous_remoteside_m2o(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar,
+                            primaryjoin=foos.c.id==bars.c.fid,
+                            foreign_keys=[foos.c.id],
+                            remote_side=[foos.c.id, bars.c.fid],
+                            viewonly=True
+                            )})
+        mapper(Bar, bars)
+
+        self.assertRaisesMessage(
+            sa.exc.ArgumentError, 
+                "could not determine any local/remote column pairs",
+                sa.orm.compile_mappers)
+        
+    
     @testing.resolve_artifact_names
     def test_no_equated_self_ref(self):
         mapper(Foo, foos, properties={
index 5a2229b16c42b50f802ffda493aee5b01ff7a983..41c3fe755225783e3be619ba017348ab7956dcbd 100644 (file)
@@ -5,7 +5,7 @@ from sqlalchemy.orm import create_session, sessionmaker, attributes
 from testlib import engines, sa, testing, config
 from testlib.compat import gc_collect
 from testlib.sa import Table, Column, Integer, String, Sequence
-from testlib.sa.orm import mapper, relation, backref
+from testlib.sa.orm import mapper, relation, backref, eagerload
 from testlib.testing import eq_
 from engine import _base as engine_base
 from orm import _base, _fixtures
@@ -776,7 +776,66 @@ class SessionTest(_fixtures.FixtureTest):
         user = s.query(User).one()
         assert user.name == 'fred'
         assert s.identity_map
+    
+    @testing.resolve_artifact_names
+    def test_weakref_with_cycles_o2m(self):
+        s = sessionmaker()()
+        mapper(User, users, properties={
+            "addresses":relation(Address, backref="user")
+        })
+        mapper(Address, addresses)
+        s.add(User(name="ed", addresses=[Address(email_address="ed1")]))
+        s.commit()
+        
+        user = s.query(User).options(eagerload(User.addresses)).one()
+        user.addresses[0].user # lazyload
+        eq_(user, User(name="ed", addresses=[Address(email_address="ed1")]))
+        
+        del user
+        gc_collect()
+        assert len(s.identity_map) == 0
 
+        user = s.query(User).options(eagerload(User.addresses)).one()
+        user.addresses[0].email_address='ed2'
+        user.addresses[0].user # lazyload
+        del user
+        gc_collect()
+        assert len(s.identity_map) == 2
+        
+        s.commit()
+        user = s.query(User).options(eagerload(User.addresses)).one()
+        eq_(user, User(name="ed", addresses=[Address(email_address="ed2")]))
+        
+    @testing.resolve_artifact_names
+    def test_weakref_with_cycles_o2o(self):
+        s = sessionmaker()()
+        mapper(User, users, properties={
+            "address":relation(Address, backref="user", uselist=False)
+        })
+        mapper(Address, addresses)
+        s.add(User(name="ed", address=Address(email_address="ed1")))
+        s.commit()
+
+        user = s.query(User).options(eagerload(User.address)).one()
+        user.address.user
+        eq_(user, User(name="ed", address=Address(email_address="ed1")))
+
+        del user
+        gc_collect()
+        assert len(s.identity_map) == 0
+
+        user = s.query(User).options(eagerload(User.address)).one()
+        user.address.email_address='ed2'
+        user.address.user # lazyload
+
+        del user
+        gc_collect()
+        assert len(s.identity_map) == 2
+        
+        s.commit()
+        user = s.query(User).options(eagerload(User.address)).one()
+        eq_(user, User(name="ed", address=Address(email_address="ed2")))
+    
     @testing.resolve_artifact_names
     def test_strong_ref(self):
         s = create_session(weak_identity_map=False)
@@ -792,9 +851,9 @@ class SessionTest(_fixtures.FixtureTest):
         assert len(s.identity_map) == 1
 
         user = s.query(User).one()
-        assert not s.identity_map.modified
+        assert not s.identity_map._modified
         user.name = 'u2'
-        assert s.identity_map.modified
+        assert s.identity_map._modified
         s.flush()
         eq_(users.select().execute().fetchall(), [(user.id, 'u2')])
         
index dd1b9b766ce5d6e804a8f52a8b017b165e62e921..f1b912313558126149f7985c55d8f5a2d6ef9c9c 100644 (file)
@@ -14,6 +14,7 @@ from orm import _base, _fixtures
 from engine import _base as engine_base
 import pickleable
 from testlib.assertsql import AllOf, CompiledSQL
+import gc
 
 class UnitOfWorkTest(object):
     pass
@@ -366,6 +367,26 @@ class MutableTypesTest(_base.MappedTest):
              "WHERE mutable_t.id = :mutable_t_id",
              {'mutable_t_id': f1.id, 'val': u'hi', 'data':f1.data})])
 
+    @testing.resolve_artifact_names
+    def test_resurrect(self):
+        f1 = Foo()
+        f1.data = pickleable.Bar(4,5)
+        f1.val = u'hi'
+
+        session = create_session(autocommit=False)
+        session.add(f1)
+        session.commit()
+
+        f1.data.y = 19
+        del f1
+
+        gc.collect()
+        assert len(session.identity_map) == 1
+
+        session.commit()
+
+        assert session.query(Foo).one().data == pickleable.Bar(4, 19)
+
     @testing.resolve_artifact_names
     def test_unicode(self):
         """Equivalent Unicode values are not flagged as changed."""
index 7a189f87313ed027268cbd93f28dbe25d4c11b8b..5d7192261d61867845f362eca3786ac0f85eb9ec 100644 (file)
@@ -290,11 +290,11 @@ class ZooMarkTest(TestBase):
     def test_profile_1_create_tables(self):
         self.test_baseline_1_create_tables()
 
-    @profiling.function_call_count(12925, {'2.4':12478})
+    @profiling.function_call_count(12178, {'2.4':12178})
     def test_profile_1a_populate(self):
         self.test_baseline_1a_populate()
 
-    @profiling.function_call_count(1185, {'2.4':1184})
+    @profiling.function_call_count(903, {'2.4':903})
     def test_profile_2_insert(self):
         self.test_baseline_2_insert()
 
@@ -310,7 +310,7 @@ class ZooMarkTest(TestBase):
     def test_profile_5_aggregates(self):
         self.test_baseline_5_aggregates()
 
-    @profiling.function_call_count(3545)
+    @profiling.function_call_count(3343)
     def test_profile_6_editing(self):
         self.test_baseline_6_editing()